pytorch深度学习和入门实战(五)如何进行fine-tuning

标签: pytorch入门到熟练  pytorch  深度学习  机器学习  人工智能  神经网络

1.基本内容

1.1.什么是fine-tuning?

在实践中,由于数据集不够大,很少有人从头开始训练网络。常见的做法是使用预训练的网络(例如在ImageNet上训练的分类1000类的网络)来重新fine-tuning(也叫微调),或者当做特征提取器。

以下是常见的两类迁移学习场景:

1) 卷积网络当做特征提取器。使用在ImageNet上预训练的网络,去掉最后的全连接层,剩余部分当做特征提取器(例如AlexNet在最后分类器前,是4096维的特征向量)。这样提取的特征叫做CNN codes。得到这样的特征后,可以使用线性分类器(Liner SVM、Softmax等)来分类图像。

2 )Fine-tuning卷积网络。替换掉网络的输入层(数据),使用新的数据继续训练。Fine-tune时可以选择fine-tune全部层或部分层。通常,前面的层提取的是图像的通用特征(generic features)(例如边缘检测,色彩检测),这些特征对许多任务都有用。后面的层提取的是与特定类别有关的特征,因此fine-tune时常常只需要Fine-tuning后面的层。

预训练模型

在ImageNet上训练一个网络,即使使用多GPU也要花费很长时间。因此人们通常共享他们预训练好的网络,这样有利于其他人再去使用。例如,Caffe有预训练好的网络地址Model Zoo。

1.2.何时使用Fine-tune、如何使用?

决定如何使用迁移学习的因素有很多,这是最重要的只有两个:新数据集的大小、以及新数据和原数据集的相似程度。 有一点一定记住:网络前几层学到的是通用特征,后面几层学到的是与类别相关的特征。这里有使用的四个场景:

1、新数据集比较小且和原数据集相似。因为新数据集比较小,如果fine-tune可能会过拟合;又因为新旧数据集类似,我们期望他们高层特征类似,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。

2、新数据集大且和原数据集相似。因为新数据集足够大,可以fine-tune整个网络。

3、新数据集小且和原数据集不相似。新数据集小,最好不要fine-tune,和原数据集不类似,最好也不使用高层特征。这时可是使用前面层的特征来训练SVM分类器。

4、新数据集大且和原数据集不相似。因为新数据集足够大,可以重新训练。但是实践中fine-tune预训练模型还是有益的。新数据集足够大,可以fine-tine整个网络。

1.3 实践建议

预训练模型的限制。使用预训练模型,受限于其网络架构。例如,你不能随意从预训练模型取出卷积层。但是因为参数共享,可以输入任意大小图像;卷积层和池化层对输入数据大小没有要求(只要步长stride fit),其输出大小和属于大小相关;全连接层对输入大小没有要求,输出大小固定。

学习率。与重新训练相比,fine-tune要使用更小的学习率。因为训练好的网络模型权重已经平滑,我们不希望太快扭曲(distort)它们(尤其是当随机初始化线性分类器来分类预训练模型提取的特征时)。

2.基本过程

以下以resnet网络为例,详解如何构建分类网络模型:
1 选择pytorch中已有model和预训练weight的模型。
2 观察模型最后full connect layer or classifaction layer 情况(一般都是以imagenet比赛的weight,所以模型class num = 1000)
3 构建自己的fc层,修改成自己分类的数据
4 一般训练的时候,采用pretrained=True, 下载、并利用已有的已训练好的模型weight
5 可以先让backbone的部分不训练,只训练最后的top layer(fc 层).,先进行一个粗略的训练。
6 在保存了新训练的weight后,再降权重全部编程trainable。load新的权重进行,全部layer的训练,从而提高自己的准确率

3.pytorch提供哪些model

参见官网
对于分类模型来说,主要有以下几种:(pytorch=1.6.0)
在这里插入图片描述
注意:

1.pytorch 不同版本之间提供的内容可能不一样
2.resnet只是一个大类,里面还有resnet18, resnet34, resnet152等不同的小类,具体情况需要核对一下。
3.有一些比较新的model是暂时没有 pretrained weight的,例如MNASNet
4. 不同模型的检测效果、体积大小也是各不相同,需要均衡考虑的。

4.基本代码

1 查看一下resnet152网络的fc层情况
from torchvision import  models

new_model = models.resnet152(pretrained=True)
print(new_model)

在这里插入图片描述
可以看到最后layer 是 fc layer, 并且 in_num = 2048, out_num=1000
我们为了让其更好的降维和训练你自己的分类,增加2个nn.Linear和1个drop,并替换到原来模型的fc layer。

2 详细代码
def model_define(fc_num=256, class_num=3, train_all =False):
	new_model = models.resnet152(pretrained=True)
	# 前面的backbone保持不变
	for param in new_model.parameters():
		param.requires_grad = False

	# 只是修改输出fc层,新加层是trainable
	fc_inputs = new_model.fc.in_features
	new_model.fc = nn.Sequential(
		nn.Linear(fc_inputs, fc_num),
		nn.ReLU(),
		nn.Dropout(0.4),
		nn.Linear(fc_num, class_num)
	)

	#  修改所有参数层
	if train_all:
		for param in new_model.parameters():
			param.requires_grad = True
		torch.load("./models/best_loss.pt")
	new_model = new_model.to(device)
	print("[INFO] Model Layer:  ", summary(new_model, (3, 224, 224)))
	return new_model

修改class_num等于你需要分类的数目就可以了。
另外,这边train_all=False 时候,只是训练fc layer, 并且保存权重为best_loss.pt
train_all=True的时候,加载best_loss.pt,并且放开所有layer进行训练,提高精读。

完整代码参见github:https://github.com/ztfmars/pytorch_practise

版权声明:本文为weixin_42237113原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_42237113/article/details/108969592

智能推荐

2018.8.27

2018.8.27...

HTML 表单元素的基本样式

HTML 表单元素的基本样式 原创 ixygj197875 发布于2018-02-22 17:48:53 阅读数 2296 收藏 更新于2018-05-20 15:35:58 分类专栏: 揭秘 CSS 揭秘 CSS 收起 表单元素主要包括 label、input、textarea、select、datalist、******、progress、meter、output等,以及对表单元素进行分组的 ...

php输出语句

php输出语句 常见的输出语句 echo(): 可以一次输出多个值,多个值之间用逗号分隔。echo是语言结构(language construct),而并不是真正的函数,因此不能作为表达式的一部分使用。 print(): 函数print()打印一个值(它的参数),如果字符串成功显示则返回true,否则返回false。 print_r(): 可以把字符串和数字简单地打印出来,而数组则以括起来的键和值...

工厂模式

简介 常见的实例化对象模式。 用工厂方法替代new操作的一种模式。 当我们使用new操作实例化对象时,调用构造函数完成初始化。若初始化仅是进行赋值等简单的操作,写入构造函数即可。但如果初始化时需要执行一长串复杂的代码,将多个工作装入一个方法,是不妥的。 创建实例与使用实例分离。将创建实例所需的大量初始化工作从基类的构造函数中分离出去。 简单工厂模式、工厂方法模式针对的是一个产品等级结构;而抽象工厂...

B1105 Spiral Matrix (画图)

B1105 Spiral Matrix (25分) //第一次只拿了21分 矩阵的长和宽,求最大因子,从sqrt(num)开始枚举. 每次循环一次,s++,t--,d--,r++ 测试点四运行超时,是因为输入一个数字的时候,需要直接输出这个数字。//1分 测试点二运行超时,最后一个数字不必再while循环一次,直接输出即可。//3分 最后一个测试点卡了好久/(ㄒoㄒ)/~~ 螺旋矩阵...

猜你喜欢

Java基础=>String,StringBuffer与StringBuilder的区别

字符串常量池 什么是字符串常量池? JVM为了减少字符串对象的重复创建,其维护了一块特殊的内存,这段内存被称为字符串常量池(存储在方法区中)。 具体实现 当代码中出现字符串时,JVM首先会对其进行检查。 如果字符串常量池中存在相同内容的字符串对象,如果有,则不再创建,直接返回这个对象的地址返回。 如果字符串常量池中不存在相同内容的字符串对象,则创建一个新的字符串对象并放入常量池,并返回新创建的字符...

java调用其他java项目的Https接口

项目中是这样的: 用户拿出二维码展示,让机器识别二维码, 机器调用开门的后台系统接口, 然后开门的后台系统接口需要调用管理系统的接口, 管理系统需要判断能不能开门.这两个系统是互相独立的.当时使用http调用是没有问题的.当时后来要求必须用https.废话不说,直接代码: 我的项目中调用的是 HttpsUtils.Get(utlStr) 这个接口 开门系统接口如下图:   管理系统的接口...

Hadoop1.2.1全分布式模式配置

一 集群规划 主机名            IP                               安装的软件 &nbs...

Go语言gin框架的安装

尝试安装了一下gin,把遇到的一些小问题来记录一下 安装步骤 首先来看看官方文档,链接点这里 可以看到安装步骤很简单,就一句话 在命令行中输入这句话运行等待就好。 问题来了,因为墙的问题,go get会很慢,所以命令行里面半天什么反应也没有,不要急,慢慢等着就会看到gin-gonic/gin这个目录出现 这个时候命令行还是没有结束,表示还在下一些东西。有的时候可能心急的人就停了(比如我),然后写个...

uni-app表单组件二

input(输入框) 属性名 类型 说明 平台差异 value String 输入框的初始内容 type String input 的类型 password Boolean(默认false) 是否是密码类型 placeholder String 输入框为空时占位符 placeholder-style String 指定 placeholder 的样式 placeholder-class Strin...