持续学习动态架构算法LwF(Learning without Forgetting )解读总结与代码注释

本文主要是介绍持续学习动态架构算法LwF(Learning without Forgetting )解读总结与代码注释,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

0.持续学习

  • 持续学习相关文章汇总,包含论文地址、代码地址、具体分析解读地址

1.LwF算法相关链接

  • 论文地址
  • 代码地址

2.基本想法

  • 针对问题:在无法获得原始任务训练数据的情况下,适合使视觉系统适应新任务,并且保证其在旧任务上的性能
  • 问题建模:学习对新任务具有判别能力的参数,同时保留训练数据上原始任务的输出
  • 将网络分为所有任务共享部分和特定任务独享部分,网络架构如下:
    图片

3.损失函数

  • 待学习参数有三种:共享部分参数、旧任务们的独享参数、新任务独享参数
  • 由三部分组成:旧任务损失、新任务损失、正则化项
  • 旧任务损失:增长后的网络的输出与增长前的输出尽可能相同,采用知识蒸馏损失,类似交叉熵损失,只不过加大了较小概率的惩罚权重(其中关键参数T,要大于1来加大小概率的权重,文中通过网格搜索将其定位2)
  • 新任务损失:对于新任务的预测与真实值尽可能相同,使用交叉熵损失或者NLL损失
  • 正则化项:限制网络中所有参数,权重0.0005
  • 新旧任务权衡:在新任务损失前面有一个系数来表示对新旧任务性能的权衡,文中取1,参数越大,在新任务上的性能越好,在旧任务上的性能越差。通过改变该参数可以获得新旧任务性能曲线。

4.训练流程

  • 热身阶段(warm-up step):冻结共享部分参数、旧任务们的独享参数,单独训练新任务独享参数
  • 联合优化阶段(joint-optimize step):优化所有参数

5.特点

  • 与传统联合调优方法相比:无需存储旧任务的数据,新任务只需要通过一次共享层便可以用来进行旧任务和新任务的更新,却具有了联合调优的优点。但因为不同任务的分布会不相同,所以文中的方法效果会不如传统联合调优,传统联合调优的效果可以视为本文方法的上限。
  • 效率分析
    • 最慢:共享参数的正反向传播
    • 最快:特征提取层,因为只需要训练新任务的参数
    • 与传统微调相比:多了一步旧任务的独享参数更新,效率稍微低一点
    • 与传统联合调优相比:新旧任务共享的参数只需要进行一次前后向传播,效率更高

6.具体细节

  • 使用动量0.9的随机梯度下降
  • 在全连接层使用了dropout
  • 用旧任务的信息对新任务进行归一化
  • 数据增强:
    • 5X5的网格上对调整过大小的图像进行随机的固定尺寸裁剪
    • 随机镜像裁剪
    • RGB值上添加方差
  • 使用Xavier初始化新任务独享参数
  • 学习率是原网络学习率的0.1-0.02倍
  • 由于任务独享的特征提取部分参数量少,所以使用5倍学习率
  • 对于学习速度相似的方法,使用相同的训练epoch来进行公平比较
  • 有时为了防止过拟合、提升学习速度,会接近平稳在的时候将学习率变为0.1倍
  • 为了公平比较,将热身阶段后的共享网络作为联合训练和微调训练的起始点

7.实验

  • 添加单个新任务
  • 添加多个新任务
  • 数据集大小的影响
  • 网络设计的影响
  • 不同损失
  • 扩展网络结构的效用
  • 小学习率微调来保证旧任务的影响
  • 改变任务专属部分的网络层数

8.结论

  • 对于增长节点式的任务专属网络,其性能与原本的LwF性能相近,但是计算开销却大很多
  • 仅仅降低共享网络的学习率对保留旧任务性能的帮助并不大,但却会很大程度影响新任务
  • 用网络输出的变化来现在旧任务的变化要优于用网络参数的变化来衡量,因为网络参数一点小小的改变就可能引起输出巨大的改变
  • 知识蒸馏损失略优于L1、L2、交叉熵损失,但优势很小
  • 训练速度优于联合优化,对新任务的性能优于微调
  • 本文针对旧任务的损失对旧任务性能上的表现更可解释

9.未来工作

  • 应用到图像分类、跟踪等更多领域:分割、检测、视觉外的任务
  • 探索根据任务分布针对性地保留一些过去的任务数据和输出(由于是面对重尾分布)

10.代码解读

  • 参考文章
  • 含有备注的model.py
import torch
torch.backends.cudnn.benchmark=True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from PIL import Image
from tqdm import tqdm
import time
import copyimport torchvision.models as models
import torchvision.transforms as transformsdef MultiClassCrossEntropy(logits, labels, T):# Ld = -1/N * sum(N) sum(C) softmax(label) * log(softmax(logit))labels = Variable(labels.data, requires_grad=False).cuda()outputs = torch.log_softmax(logits/T, dim=1)   # compute the log of softmax valueslabels = torch.softmax(labels/T, dim=1)# print('outputs: ', outputs)# print('labels: ', labels.shape)outputs = torch.sum(outputs * labels, dim=1, keepdim=False)outputs = -torch.mean(outputs, dim=0, keepdim=False)# print('OUT: ', outputs)return Variable(outputs.data, requires_grad=True).cuda()def kaiming_normal_init(m):if isinstance(m, nn.Conv2d):#判断m是不是nn.Conv2d的类型或子类nn.init.kaiming_normal_(m.weight, nonlinearity='relu')#一种初始化方法,要指明激活函数,保证输出有一定方差https://zhuanlan.zhihu.com/p/536483424elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')class Model(nn.Module):'''分为超参数、网络架构、类增加三个部分前向传播里没有softmax'''def __init__(self, classes, classes_map, args):# Hyper Parametersself.init_lr = args.init_lrself.num_epochs = args.num_epochsself.batch_size = args.batch_sizeself.lower_rate_epoch = [int(0.7 * self.num_epochs), int(0.9 * self.num_epochs)] #hardcoded decay scheduleself.lr_dec_factor = 10self.pretrained = Falseself.momentum = 0.9self.weight_decay = 0.0001# Constant to provide numerical stability while normalizingself.epsilon = 1e-16# Network architecturesuper(Model, self).__init__()self.model = models.resnet34(pretrained=self.pretrained)self.model.apply(kaiming_normal_init)"""独享层:一层全连接层,与classes数量有关,且没有偏置"""num_features = self.model.fc.in_featuresself.model.fc = nn.Linear(num_features, classes, bias=False)self.fc = self.model.fc'''共享层:resnet34除去最后一层'''#nn.Sequential按序列构建模型https://blog.csdn.net/hxxjxw/article/details/106231242#.children()返回模型的最外层,与.model()的区别类似于attend和extendself.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])#*用于迭代地取出list中的内容#用nn.DataParallel包装模型,可以在多GPU上运行https://zhuanlan.zhihu.com/p/647169457self.feature_extractor = nn.DataParallel(self.feature_extractor) # n_classes is incremented(递增) before processing new data in an iteration# n_known is set to n_classes after all data for an iteration has been processed数据处理完后n_known设为n_classesself.n_classes = 0self.n_known = 0self.classes_map = classes_mapdef forward(self, x):x = self.feature_extractor(x)x = x.view(x.size(0), -1)x = self.fc(x)return xdef increment_classes(self, new_classes):"""Add n classes in the final fc layer"""n = len(new_classes)print('new classes: ', n)in_features = self.fc.in_featuresout_features = self.fc.out_featuresweight = self.fc.weight.data#保存旧任务的网络权重if self.n_known == 0:new_out_features = nelse:new_out_features = out_features + nprint('new out features: ', new_out_features)self.model.fc = nn.Linear(in_features, new_out_features, bias=False)self.fc = self.model.fckaiming_normal_init(self.fc.weight)#所有任务网络统一初始化self.fc.weight.data[:out_features] = weight#还原旧任务网络权重self.n_classes += ndef classify(self, images):"""Classify images by softmaxArgs:x: input image batchReturns:preds: Tensor of size (batch_size,)"""_, preds = torch.max(torch.softmax(self.forward(images), dim=1), dim=1, keepdim=False)return predsdef update(self, dataset, class_map, args):self.compute_means = True# Save a copy to compute distillation outputs保存旧网络来计算旧任务原始输出prev_model = copy.deepcopy(self)prev_model.cuda()classes = list(set(dataset.train_labels))#print("Classes: ", classes)print('Known: ', self.n_known)if self.n_classes == 1 and self.n_known == 0:#self.n_classes初始值是1不是0吗?!new_classes = [classes[i] for i in range(1,len(classes))]else:new_classes = [cl for cl in classes if class_map[cl] >= self.n_known]#有新任务就动态调整网络if len(new_classes) > 0:self.increment_classes(new_classes)self.cuda()loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,shuffle=True, num_workers=12)print("Batch Size (for n_classes classes) : ", len(dataset))optimizer = optim.SGD(self.parameters(), lr=self.init_lr, momentum = self.momentum, weight_decay=self.weight_decay)with tqdm(total=self.num_epochs) as pbar:for epoch in range(self.num_epochs):# Modify learning rate# if (epoch+1) in lower_rate_epoch:# 	self.lr = self.lr * 1.0/lr_dec_factor# 	for param_group in optimizer.param_groups:# 		param_group['lr'] = self.lrfor i, (indices, images, labels) in enumerate(loader):seen_labels = []images = Variable(torch.FloatTensor(images)).cuda()seen_labels = torch.LongTensor([class_map[label] for label in labels.numpy()])labels = Variable(seen_labels).cuda()# indices = indices.cuda()optimizer.zero_grad()logits = self.forward(images)cls_loss = nn.CrossEntropyLoss()(logits, labels)if self.n_classes//len(new_classes) > 1:dist_target = prev_model.forward(images)logits_dist = logits[:,:-(self.n_classes-self.n_known)]dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, 2)loss = dist_loss+cls_losselse:loss = cls_lossloss.backward()optimizer.step()if (i+1) % 1 == 0:tqdm.write('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' %(epoch+1, self.num_epochs, i+1, np.ceil(len(dataset)/self.batch_size), loss.data))pbar.update(1)

这篇关于持续学习动态架构算法LwF(Learning without Forgetting )解读总结与代码注释的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/481975

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

mybatis的整体架构

mybatis的整体架构分为三层: 1.基础支持层 该层包括:数据源模块、事务管理模块、缓存模块、Binding模块、反射模块、类型转换模块、日志模块、资源加载模块、解析器模块 2.核心处理层 该层包括:配置解析、参数映射、SQL解析、SQL执行、结果集映射、插件 3.接口层 该层包括:SqlSession 基础支持层 该层保护mybatis的基础模块,它们为核心处理层提供了良好的支撑。

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

百度/小米/滴滴/京东,中台架构比较

小米中台建设实践 01 小米的三大中台建设:业务+数据+技术 业务中台--从业务说起 在中台建设中,需要规范化的服务接口、一致整合化的数据、容器化的技术组件以及弹性的基础设施。并结合业务情况,判定是否真的需要中台。 小米参考了业界优秀的案例包括移动中台、数据中台、业务中台、技术中台等,再结合其业务发展历程及业务现状,整理了中台架构的核心方法论,一是企业如何共享服务,二是如何为业务提供便利。

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

第10章 中断和动态时钟显示

第10章 中断和动态时钟显示 从本章开始,按照书籍的划分,第10章开始就进入保护模式(Protected Mode)部分了,感觉从这里开始难度突然就增加了。 书中介绍了为什么有中断(Interrupt)的设计,中断的几种方式:外部硬件中断、内部中断和软中断。通过中断做了一个会走的时钟和屏幕上输入字符的程序。 我自己理解中断的一些作用: 为了更好的利用处理器的性能。协同快速和慢速设备一起工作

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个