【机器学习基础】线性回归

                                                                                     线性回归

    1.线性回归简介

    线性回归(Linear Regression)是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。这句话对线性回归的解释是出自百度百科,简单一点来说,回归在数学上来说就是给定一组点集,能够用一条曲线去拟合,如果这条曲线是直线,那么就被称为线性回归,如果曲线是一条二次曲线,那么就被称为二次回归,当然还有一些其他的回归,比如局部加权回归、Logistic回归等等。

    本文针对ML中最基础的线性回归来展开,展开公式推导的每一个步骤,并且附上Python3的代码实现和使用Matplotlib进行可视化展示。

 

    2.线性回归公式推导

    线性回归首先假设数据是线性的关系,并且可以使用线性方程去拟合这些数据。回归的目的在于预测数值型的目标值,去拟合这些数据的线性方程就是所谓的回归方程(Regression Equation),自变量前面的系数为回归系数。

    先考虑一种最简单的情况,然后我们再推广到更一般的情况。比如我们需要预测房屋的售价,然后我们现在只有一个特征,就是房屋的居住面积。那么我们就可以构造一个估计函数:

                                                                               

    其中 是根据参数 来估计的值,训练的过程中,我们的目的就是要不断地调整两个参数,然后让这个估计值尽可能地接近真实值 ,所以基于这个原则,现在可以写出损失函数:

                                                                              

    将fxi 带入,

                                                                             

    从这个损失函数可以看出,这其实就是预测值与实际值的差值的平方,即计算的欧式距离。计算这个最小值的方法我们可以用最小二乘法,直观上来理解,就是想找到那么一条直线,使得所有的点离这条直线的距离之和最小。接下来就是对ω,b 求偏导数,然后使其等于0,求解方程。

                                                                             

                                                                             

                                                                             

                                                                             

 

    得到 的方程之后,我们就可以使用训练集的数据来求得 的值。

    现在我们推广到更一般的情况:假设有一组训练集(m个样本):   ,对于其中每一个样本 ,有n个特征,我们用 来描述其中的特征,用矩阵X来表示这m个样本:

                                                                             

      用Y表示标签矩阵:

                                                                             

     为了构造线性模型,还需要一些参数:

                                                                             

    有的为了添加偏移值会加一个 ,这里为了方便没有添加,实际结果是一样的。回归方程可以表达成如下形式,其中的 就是回归系数,直观上可以看出 的大小决定了这个特征的所占权重的大小。

                                                                            

     和上面一样,写出损失函数:

                                                                             

     这个公式和上面的相比没有求和符号是因为这里的 本身就是矩阵,这里多了一个12 也是为了在求导的时候方便处理。接下来,为了方便对矩阵进行求导,这里先把这个式子进行展开,过程如下:

                                                                             

                                                                             

                                                                             

     接下来对 求偏导数,在这里是对矩阵求偏导数,跟之前的对变量求导数的运算法则有一些不一样,下面给出几个矩阵求导常用的公式,如果对矩阵求导不熟悉的可以先记一下:

                                                                              

     求偏导,过程如下:

                                                                              

                                                                                       

                                                                                        

     令偏导数为0,解得:

                                                                               

     有了这个公式,我们就可以用训练集来计算出 的值,但是这里有一个前提是 必须是可逆的,在实际情况下,这个矩阵有可能不可逆,就算是可逆的,如果数据量比较大的话,要求解矩阵的逆的计算量也非常大,因此在实际应用中往往应用梯度下降法来更新 的值。梯度下降法的求解过程后续我再补充。

 

    3.代码实现

       在Git仓库的LinearRegression目录下有测试该算法的训练集数据,文件名为ex0.txt。下列代码使用的python版本为python3.7.0,numpy版本为1.5.1。

import numpy as np
import matplotlib.pyplot as plt

# 获取数据
def loadData(filePath):
    if (filePath == ''): return
    fileHandler = open(filePath)
    dataSet = []
    labels = []
    for line in fileHandler.readlines():
        lineAttr = line.strip().split("\t")
        dataSet.append(lineAttr[0:lineAttr.__len__()-1])
        labels.append(lineAttr[-1])
    return dataSet, labels

# linear regression
def StandRegress(dataSet, labels):
    dataMatrix = np.mat(dataSet).astype("float64")    #先转换成矩阵的形式
    labelsMatrix = np.mat(labels).T.astype("float64")
    xTx = dataMatrix.T*dataMatrix
    if np.linalg.det(xTx) == 0:     #若矩阵的行列式为0, 这一步要保证行列式不为0,因为要求矩阵的逆
        print("该矩阵行列式为0, 不可求逆。")
    W = xTx.I*(dataMatrix.T*labelsMatrix)
    return W

# 绘图
def drawResult(dataSet, labels, W):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    dataMatrix1 = np.mat(dataSet).astype("float64")  #这个是为了绘制训练样本数据
    dataMatrix2 = np.mat(dataSet).astype("float64")  #这个是为了绘制拟合的直线
    dataMatrix2.sort(0)                              #为了保证绘制直线时不出现混乱,先对点进行排序
    labelsMatrix1 = np.mat(labels).astype("float64") #这个是为了绘制训练样本数据
    labelsMatrix2 = dataMatrix2*W                    #这个是为了绘制拟合的直线
    ax.scatter(np.array(dataMatrix1[:,1].flatten().A[0]), labelsMatrix1.T.flatten().A[0])
    ax.plot(np.array(dataMatrix2[:,1]), labelsMatrix2.tolist())
    plt.show()


# 测试
dataSet, labels = loadData("ex0.txt")
W = StandRegress(dataSet, labels)
drawResult(dataSet, labels, W)
print(W)

 

  1. 结果可视化

拟合结果如下图所示,蓝色圆点为训练集数据,直线为拟合出来的最终结果。

 

  1. 参考材料

(1)《机器学习实战》 Peter Harrington

(2) 线性回归详解https://blog.csdn.net/qq_36330643/article/details/77649896

(3) 线性回归方程推导 https://blog.csdn.net/joob000/article/details/81295144

版权声明:本文为u010834867原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u010834867/article/details/89286549

智能推荐

bireme数据源同步工具--debezium+kafka+bireme

1、介绍 Bireme 是一个 Greenplum / HashData 数据仓库的增量同步工具。目前支持 MySQL、PostgreSQL 和 MongoDB 数据源 官方介绍文档:https://github.com/HashDataInc/bireme/blob/master/README_zh-cn.md 1、数据流 Bireme 采用 DELETE + COPY 的方式,将数据源的修改记...

一致性hash算法

散列(hash)在我看来就是一个数组,而与数组不同的点在于数组是按顺序写入的,而hash是按照一定的hash算法确定元素在数组中的位置的。hash最难的问题在于会有冲突出现,如果两个object根据相应的hash算法得出的值一样便产生了hash冲突。在所有解决hash冲突的方法中,我最欣赏的是链式解决法,即将hash到同一位置的元素用链表连接。当然还有其它几种处理hash冲突的算法,比如建立公共溢...

OpenCV-Python learning-1.安装,图片读取显示

1. OpenCV与OpenGL区别 https://www.zhihu.com/question/20212016 一个是让机器识别东西的,OpenCV是给电脑做眼睛的。 一个是让机器计算出更好画面的,OpenGL用在游戏渲染方面很多。 OpenCV(Open Source Computer Vision Library)是一个基于(开源)发行的跨平台计算机视觉库,OpenGL(全写Open G...

Mycat+Mysql分布式架构改造和性能压力测试

架构实现 Mycat作为数据库高可用中间件具备很多的功能,如负载均衡,分库分表,读写分离,故障迁移等。结合项目的实际情况,分库分表功能对于关联查询有很高的要求,需要从业务角度考虑分库分表后的关联查询SQL的分析,业务代码动作较大,所以在此方案中我们不考虑分库分表。主要应用Mycat的负载均衡及故障迁移的功能即可。 整个架构改造包括两个部分,第一是单例Mysql改为多个Mysql,同时负载均衡,并且...

人脸识别之疲劳检测(二)阈值法、KNN分类和K-means聚类

Table of Contents 1、均值法 2、中值法 3、KNN 4、K-means 结合上一节在获得人眼特征点后需要对睁眼闭眼状态做出判断,方法的选择需要经验结合公平的评价方法,使用大量测试集得到不同方法下的精确度并做出比较: 1、均值法 50帧睁眼数据取均值,得到不同阈值下精确度。 2、中值法 50帧睁眼数据取中值,得到不同阈值下精确度。 3、KNN KNN是一种ML常用分类算法,通过测...

猜你喜欢

CodeForce Tic-Tac-Toe

Two bears are playing tic-tac-toe via mail. It's boring for them to play usual tic-tac-toe game, so they are a playing modified version of this game. Here are its rules. The game is played on the foll...

Python雾里看花-抽象类ABC (abstract base class)

首先认识模块 abc,python中没有提供抽象类与抽象方法,然而提供了内置模块abc来模拟实现抽象类,例如提供泛映射类型的抽象类 abc.MutableMapping 继承abc.MutableMapping构造一个泛映射类型(类似python中的dict) 当然继承abc.Mapping 也可以,毕竟MutableMapping是其子类 dict是python中典型的映射类型数据结构,其接口的...

python 文件操作

2, with open (‘xx.txt’,‘w’,encoding=‘utf-8’) as f: f.write(‘文件内容或对象’)...

【Python基础】使用统计函数绘制简单图形

机器学习算法与自然语言处理出品 @公众号原创专栏作者 冯夏冲 学校 | 哈工大SCIR实验室在读博士生 2.1 函数bar 用于绘制柱状图 2.2 函数barh 用于绘制条形图 2.3 函数hist 用于绘制直方图 直方图与柱状图的区别 函数pie 用于绘制饼图 2.5 函数polor 用于绘制极线图 极线图是在极坐标系上绘出的一种图。在极坐标系中,要确定一个点,需要指明这个点距原点的角...

css:顶部按钮固定,上面内容滑动

这种需求我们平时见到很多的,实现方法也多的参差不齐,下面我说一种简单的。如图: 可以看到只有红线部分滚动,底下按钮是固定的。 代码...