MAML Source Code Explained (Part 1)

2024-04-03 08:38


Meta-learning article series

  1. optimization based meta-learning
    1. Translation notes for the paper "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks"
    2. A detailed walkthrough of MAML, an optimization-based meta-learning method
    3. MAML Source Code Explained (Part 1): this post
    4. MAML Source Code Explained (Part 2)
    5. A detailed walkthrough of "On First-Order Meta-Learning Algorithms"
    6. A detailed walkthrough of "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING"
  2. metric based meta-learning: to be updated…
  3. model based meta-learning: to be updated…

Preface

This post explains the MAML source code. The authors open-sourced the code for the paper, but it contains very few comments and the overall flow is hard to follow at first, so the key parts of the code are annotated here. The core is the construct_model() function, which builds the entire MAML training procedure; reading the implementation makes the authors' idea much clearer.
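Before diving into the TensorFlow graph, it may help to see the same two-level structure written out imperatively. The following is a minimal NumPy sketch, not code from the MAML repository: the task sampler make_task, the names inner_lr and meta_lr, and the 1-D linear model are invented for illustration, and the outer step uses a first-order approximation of the meta-gradient (the real code differentiates through the inner updates unless stop_grad is set).

# Minimal first-order sketch of the MAML update on a 1-D linear model.
# Plain NumPy, for illustration only: make_task, inner_lr, meta_lr and the
# linear model are invented here; the outer step below is a first-order
# approximation, whereas maml.py backpropagates through the inner updates.
import numpy as np

rng = np.random.default_rng(0)

def forward(theta, x):
    w, b = theta
    return w * x + b

def mse_grad(theta, x, y):
    """Gradient of the mean squared error w.r.t. (w, b) for y_hat = w*x + b."""
    err = forward(theta, x) - y
    return np.array([np.mean(2 * err * x), np.mean(2 * err)])

def make_task(rng):
    """Hypothetical task sampler: a random linear function, split into
    a support set (a) and a query set (b)."""
    w, b = rng.uniform(-1.0, 1.0, size=2)
    x_a, x_b = rng.uniform(-5.0, 5.0, size=(2, 10))
    return (x_a, w * x_a + b), (x_b, w * x_b + b)

theta = np.zeros(2)               # meta-parameters shared by all tasks
inner_lr, meta_lr = 0.01, 0.001   # play the roles of update_lr and meta_lr in maml.py
num_updates, meta_batch_size = 5, 4

for it in range(1000):
    meta_grad = np.zeros_like(theta)
    for _ in range(meta_batch_size):
        (x_a, y_a), (x_b, y_b) = make_task(rng)
        fast = theta.copy()
        for _ in range(num_updates):              # inner loop: adapt on the support set a
            fast = fast - inner_lr * mse_grad(fast, x_a, y_a)
        meta_grad += mse_grad(fast, x_b, y_b)     # evaluate the adapted weights on the query set b
    theta = theta - meta_lr * meta_grad / meta_batch_size   # outer (meta) update

The TensorFlow code below builds exactly this structure as a static graph: the inner loop lives in task_metalearn(), and the outer update is the metatrain_op.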

maml.py

""" Code for the MAML algorithm and network definitions. """
from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
try:
    import special_grads
except KeyError as e:
    print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
          file=sys.stderr)

from tensorflow.python.platform import flags
from utils import mse, xent, conv_block, normalize

FLAGS = flags.FLAGS

class MAML:
    def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
        """ must call construct_model() after initializing MAML! """
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.update_lr = FLAGS.update_lr
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.classification = False
        self.test_num_updates = test_num_updates
        if FLAGS.datasource == 'sinusoid':
            self.dim_hidden = [40, 40]
            self.loss_func = mse
            self.forward = self.forward_fc
            self.construct_weights = self.construct_fc_weights
        elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet':
            self.loss_func = xent
            self.classification = True
            if FLAGS.conv:
                self.dim_hidden = FLAGS.num_filters
                self.forward = self.forward_conv
                self.construct_weights = self.construct_conv_weights
            else:
                self.dim_hidden = [256, 128, 64, 64]
                self.forward = self.forward_fc
                self.construct_weights = self.construct_fc_weights
            if FLAGS.datasource == 'miniimagenet':
                self.channels = 3
            else:
                self.channels = 1
            self.img_size = int(np.sqrt(self.dim_input/self.channels))
        else:
            raise ValueError('Unrecognized data source.')

    # ************** Construction of the training graph; this function is the core of the code **************
    def construct_model(self, input_tensors=None, prefix='metatrain_'):
        # a: training data for inner gradient, b: test data for meta gradient
        if input_tensors is None:
            self.inputa = tf.placeholder(tf.float32)
            self.inputb = tf.placeholder(tf.float32)
            self.labela = tf.placeholder(tf.float32)
            self.labelb = tf.placeholder(tf.float32)
        else:
            self.inputa = input_tensors['inputa']
            self.inputb = input_tensors['inputb']
            self.labela = input_tensors['labela']
            self.labelb = input_tensors['labelb']

        # Computation graph of the training process
        with tf.variable_scope('model', reuse=None) as training_scope:
            # If the graph has been built before, self already has self.weights,
            # and all tasks share this same set of weights
            if 'weights' in dir(self):
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                # First run: 'weights' is not yet in dir(self), so initialize it here
                self.weights = weights = self.construct_weights()

            # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
            lossesa, outputas, lossesb, outputbs = [], [], [], []
            accuraciesa, accuraciesb = [], []
            num_updates = max(self.test_num_updates, FLAGS.num_updates)
            outputbs = [[]]*num_updates
            lossesb = [[]]*num_updates
            accuraciesb = [[]]*num_updates

            def task_metalearn(inp, reuse=True):
                """ Perform gradient descent for one task in the meta-batch.
                There are meta_batch_size tasks; task_metalearn is run (in parallel via
                tf.map_fn) for each of them, and each call handles the training of one
                concrete task.
                """
                inputa, inputb, labela, labelb = inp
                task_outputbs, task_lossesb = [], []
                if self.classification:
                    task_accuraciesb = []

                # inputa: [inner_batch, 1], task_outputa: [inner_batch, 1]
                task_outputa = self.forward(inputa, weights, reuse=reuse)  # only reuse on the first iter
                task_lossa = self.loss_func(task_outputa, labela)

                grads = tf.gradients(task_lossa, list(weights.values()))
                if FLAGS.stop_grad:
                    grads = [tf.stop_gradient(grad) for grad in grads]
                # w1: g1, w2: g2
                gradients = dict(zip(weights.keys(), grads))
                # w1: w1 - α*g1, w2: w2 - α*g2
                fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
                # Use the updated weights to run the forward pass again on this task's inputb
                output = self.forward(inputb, fast_weights, reuse=True)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb))

                # Each task performs num_updates inner updates; one was done above, so num_updates-1 remain here
                for j in range(num_updates - 1):
                    loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
                    grads = tf.gradients(loss, list(fast_weights.values()))
                    if FLAGS.stop_grad:
                        grads = [tf.stop_gradient(grad) for grad in grads]
                    gradients = dict(zip(fast_weights.keys(), grads))
                    fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
                    output = self.forward(inputb, fast_weights, reuse=True)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb))

                # inputa is the task's training (support) set, inputb is the task's test (query) set.
                # task_outputa is the first forward pass on inputa; task_lossa is the loss of
                # task_outputa under the original weights.
                # task_outputbs are the outputs on inputb after each gradient update;
                # task_lossesb are the losses computed from each task_outputb.
                task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]

                if self.classification:
                    task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1))
                    for j in range(num_updates):
                        task_accuraciesb.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1)))
                    task_output.extend([task_accuracya, task_accuraciesb])

                return task_output

            if FLAGS.norm is not 'None':
                # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
                unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

            # [task_outputa, task_outputbs, task_lossa, task_lossesb]
            out_dtype = [tf.float32, [tf.float32]*num_updates, tf.float32, [tf.float32]*num_updates]
            if self.classification:
                out_dtype.extend([tf.float32, [tf.float32]*num_updates])
            result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
            if self.classification:
                outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
            else:
                outputas, outputbs, lossesa, lossesb = result

        ## Performance & Optimization
        ## Aggregate the task losses into the meta-training objectives
        if 'train' in prefix:
            # lossesa holds, for each of the meta_batch_size tasks, the loss of the first forward pass on inputa
            self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            # lossesb[j] holds the loss on inputb of the meta_batch_size tasks after j+1 inner updates
            self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            # after the map_fn
            self.outputas, self.outputbs = outputas, outputbs
            if self.classification:
                self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            # pretrain_op minimizes the summed first-pass loss on inputa; this pretraining is
            # analogous to the pre-training stage of transfer learning
            self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)

            # metatrain: use the average loss of the final forward pass on inputb as the optimization objective
            if FLAGS.metatrain_iterations > 0:
                optimizer = tf.train.AdamOptimizer(self.meta_lr)
                # metatrain_op minimizes the average over tasks of the loss after the last inner update
                self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates-1])
                if FLAGS.datasource == 'miniimagenet':
                    gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
                self.metatrain_op = optimizer.apply_gradients(gvs)
        else:
            self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            if self.classification:
                self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.metaval_total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]

        ## Summaries
        # total_loss1 is the average per-task loss before the meta-update
        tf.summary.scalar(prefix+'Pre-update loss', total_loss1)
        if self.classification:
            tf.summary.scalar(prefix+'Pre-update accuracy', total_accuracy1)

        for j in range(num_updates):
            tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), total_losses2[j])
            if self.classification:
                tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), total_accuracies2[j])

    ### Network construction functions (fc networks and conv networks)
    # Build the weights of the fully connected network
    def construct_fc_weights(self):
        weights = {}
        weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
        for i in range(1, len(self.dim_hidden)):
            weights['w'+str(i+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[i-1], self.dim_hidden[i]], stddev=0.01))
            weights['b'+str(i+1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
        weights['w'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
        weights['b'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.zeros([self.dim_output]))
        return weights

    # Run the forward pass of the fully connected network
    def forward_fc(self, inp, weights, reuse=False):
        hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
        for i in range(1, len(self.dim_hidden)):
            hidden = normalize(tf.matmul(hidden, weights['w'+str(i+1)]) + weights['b'+str(i+1)], activation=tf.nn.relu, reuse=reuse, scope=str(i+1))
        return tf.matmul(hidden, weights['w'+str(len(self.dim_hidden)+1)]) + weights['b'+str(len(self.dim_hidden)+1)]

    # Build the weights of the convolutional network
    def construct_conv_weights(self):
        weights = {}
        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        k = 3
        weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
        if FLAGS.datasource == 'miniimagenet':
            # assumes max pooling
            weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output], initializer=fc_initializer)
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        else:
            weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        return weights

    # Run the forward pass of the convolutional network
    def forward_conv(self, inp, weights, reuse=False, scope=''):
        # reuse is for the normalization parameters.
        channels = self.channels
        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])

        hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope+'0')
        hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope+'1')
        hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope+'2')
        hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope+'3')

        if FLAGS.datasource == 'miniimagenet':
            # last hidden layer is 6x6x64-ish, reshape to a vector
            hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
        else:
            hidden4 = tf.reduce_mean(hidden4, [1, 2])

        return tf.matmul(hidden4, weights['w5']) + weights['b5']
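For completeness, here is a rough sketch of how this class is typically driven for the sinusoid regression task (roughly what the repository's main.py does). It is not runnable on its own: it assumes the command-line flags (datasource, update_lr, meta_lr, num_updates, meta_batch_size, stop_grad, norm, metatrain_iterations, ...) have already been defined as in main.py, and sample_sinusoid_batch is a hypothetical stand-in for the repository's data generator. The input arrays follow the shape convention [meta_batch_size, num_samples_per_task, dim_input].

# Rough usage sketch (assumptions: FLAGS defined as in main.py;
# sample_sinusoid_batch is a hypothetical stand-in for data_generator.py).
model = MAML(dim_input=1, dim_output=1, test_num_updates=5)
model.construct_model(prefix='metatrain_')   # builds placeholders, pretrain_op and metatrain_op

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

for step in range(FLAGS.metatrain_iterations):
    inputa, labela, inputb, labelb = sample_sinusoid_batch()   # hypothetical helper
    feed_dict = {model.inputa: inputa, model.labela: labela,
                 model.inputb: inputb, model.labelb: labelb}
    _, pre_loss, post_loss = sess.run(
        [model.metatrain_op,
         model.total_loss1,                            # loss before adaptation
         model.total_losses2[FLAGS.num_updates - 1]],  # loss after the last inner update
        feed_dict)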

This concludes MAML Source Code Explained (Part 1); the walkthrough continues in Part 2.



