Shark源码分析(五):线性回归算法与Lasso回归

2024-04-27 00:48

本文主要是介绍Shark源码分析(五):线性回归算法与Lasso回归,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Shark源码分析(五):线性回归算法与Lasso回归

为什么上一篇还是三,这一篇就跳到五了呢?其实我们原来提到过:

=++

这里的模型与算法我们之前都已经提到过了,虽然只是介绍了一个基类,并没有涉及到其具体的实现。在这里我们就会揭开其真正面目了。『策略』我们还没有介绍过,其实就是目标函数,在前面一些较为简单的算法中并没有涉及到这块。为了整个逻辑的完整性,我还是打算将其放在前面来介绍。

这里我们所介绍的算法是线性回归算法。它是机器学习算法中非常基本的一个算法。这里就不对它进行过多的介绍了,之后应该会写一个博客来叙述。

首先给出一个示例代码,使得有一个整体的映像。

#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/Algorithms/Trainers/LinearRegression.h>#include <iostream>using namespace shark;
using namespace std;int main(int argc, char **argv) {if(argc < 3) {cerr << "usage: " << argv[0] << " (file with inputs/independent variables) (file with outputs/dependent variables)" << endl;exit(EXIT_FAILURE);}Data<RealVector> inputs;Data<RealVector> labels;try {importCSV(inputs, argv[1], ' ');} catch (...) {cerr << "unable to read input data from file " <<  argv[1] << endl;exit(EXIT_FAILURE);}try {importCSV(labels, argv[2]);}catch (...) {cerr << "unable to read labels from file " <<  argv[2] << endl;exit(EXIT_FAILURE);}RegressionDataset data(inputs, labels);// trainer and modelLinearRegression trainer;LinearModel<> model;// train modeltrainer.train(model, data);// show model parameterscout << "intercept: " << model.offset() << endl;cout << "matrix: " << model.matrix() << endl;SquaredLoss<> loss;Data<RealVector> prediction = model(data.inputs()); cout << "squared loss: " << loss(data.labels(), prediction) << endl;
}

首先读取算法所需要的数据集,这里是存储在LabeledData所特化的RegressionDataset中。之后就是初始化算法所对应的模型类,以及算法的训练方法类。利用训练方法类对训练数据进行训练,将训练所得的参数写回到对应的模型中。这里的prediction就是对于数据的预测值。模型重载了括号运算符,里面包含的内容是eval函数,就是计算其输出值。最后利用了平方损失函数来衡量模型的性能。

LinearModel类

Shark中将线性回归算法归于线性模型这一大类中。线性模型是使用线性函数 f(x)=Ax+b 来进行预测的。存在两个特殊的情况是:一是输出可能只是一个单独的数;二是,偏移b可能会被省略。

该文件位于<include/shark/Models/LinearModel.h>中。

template <class InputType = RealVector>
class LinearModel : public AbstractModel<InputType,RealVector>
{
private:typedef AbstractModel<InputType,RealVector> base_type;typedef LinearModel<InputType> self_type;RealMatrix m_matrix; // 权值矩阵RealVector m_offset; // 偏置向量
public:typedef typename base_type::BatchInputType BatchInputType;typedef typename base_type::BatchOutputType BatchOutputType;Lin

这篇关于Shark源码分析(五):线性回归算法与Lasso回归的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/939110

相关文章

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig

poj 3974 and hdu 3068 最长回文串的O(n)解法(Manacher算法)

求一段字符串中的最长回文串。 因为数据量比较大,用原来的O(n^2)会爆。 小白上的O(n^2)解法代码:TLE啦~ #include<stdio.h>#include<string.h>const int Maxn = 1000000;char s[Maxn];int main(){char e[] = {"END"};while(scanf("%s", s) != EO

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL