OpenAI提出Reptile:可扩展的元学习算法

2024-04-12 20:18

本文主要是介绍OpenAI提出Reptile:可扩展的元学习算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

选自OpenAI Blog

作者:ALEX NICHOL & JOHN SCHULMAN

机器之心编译


近日,OpenAI 发布了简单元学习算法 Reptile,该算法对一项任务进行重复采样、执行随机梯度下降、更新初始参数直到习得最终参数。该方法的性能可与 MAML(一种广泛应用的元学习算法)媲美,且比后者更易实现,计算效率更高。


元学习是学习如何学习的过程。元学习算法会学习任务的一个分布,每项任务都是学习问题,并输出快速学习器,学习器可从少量样本中学习并进行泛化。一个得到充分研究的元学习问题是 few-shot 分类,其中每项任务都是分类问题,学习器只能看到 1-5 个输入-输出样本(每个类别),之后学习器必须对新输入进行分类。下面,你可以尝试 OpenAI 的 1-shot 分类交互 Demo,其使用了 Reptile。



点击「Edit All」按钮,绘制三种不同的形状或符号,然后在后侧的输入区域绘制其中一个形状,就可以看到 Reptile 的分类效果。前三个图是标注样本:每个定义一类。最后的图表示未知样本,Reptile 输出其属于每个类别的概率。(请点击原文链接体验交互)


Reptile 的工作原理


和 MAML 类似,Reptile 会学习神经网络的参数初始化方法,以使神经网络可使用少量新任务数据进行调整。但是 MAML 通过梯度下降算法的计算图来展开微分计算过程,而 Reptile 在每个任务中执行标准形式的随机梯度下降(SGD):它不用展开计算图或计算任意二阶导数。因此 Reptile 比 MAML 所需的计算量和内存都更少。伪代码如下:



最后一步也可以把 Φ−W 作为梯度,将其插入如 Adam 等更复杂的优化器。


很令人震惊,该方法运行效果很好。如果 k=1,该算法对应「联合训练」(joint training):在多项任务上执行 SGD。尽管联合训练在很多情况下可以学到有用的初始化,但在 zero-shot 学习不可能出现的情况下(如输出标签是随机排列的)它能学习的很少。Reptile 要求 k>1,更新依赖于损失函数的高阶导数。正如 OpenAI 在论文中展示的那样,k>1 时 Reptile 的行为与 k=1(联合训练)时截然不同。


为了分析 Reptile 的工作原理,OpenAI 使用泰勒级数逼近更新。Reptile 更新最大化同一任务中不同小批量的梯度内积,以改善泛化效果。该发现可能在元学习之外也有影响,如解释 SGD 的泛化性能。OpenAI 的分析结果表明 Reptile 和 MAML 可执行类似的更新,包括具备不同权重的相同两个项。


在 OpenAI 的实验中,他们展示了 Reptile 和 MAML 在 Omniglot 和 Mini-ImageNet 基准上执行 few-shot 分类任务时具备类似的性能。Reptile 收敛速度更快,因为其更新具备更低的方差。OpenAI 关于 Reptile 的分析表明,我们可以使用不同的 SGD 梯度组合获取大量不同的算法。在下图中,假设我们在不同任务中使用不同批量大小的 SGD 执行 K 个更新步,产生 g_1,g_2,…,g_k k 个梯度。下图展示了在 Omniglot 上的学习曲线,且它由梯度的和作为元梯度而绘制出。g_2 对应一阶 MAML,即原版 MAML 论文提出的算法。由于方差缩减,使用更多的梯度会导致更快的学习或收敛。注意仅使用 g_1(对应 k=1)如预测那样在这个任务中没有什么提升,因为我们无法改进 zero-shot 的性能。



实现


实现的 GitHub 地址:https://github.com/openai/supervised-reptile


该实现应用 TensorFlow 进行相关的计算,代码可在 Omniglot 和 Mini-ImageNet 上复现。此外,OpenAI 也发布了一个更小的基于 JavaScript 的实现(https://github.com/openai/supervised-reptile/tree/master/web),其对使用 TensorFlow 预训练的模型进行了调整——以上 demo 就是基于此实现的。


最后,下面是一个 few-shot 回归的简单示例,预测 10(x,y) 对的随机正弦波。该示例基于 PyTorch:


  
  1. import numpy as np

  2. import torch

  3. from torch import nn, autograd as ag

  4. import matplotlib.pyplot as plt

  5. from copy import deepcopy

  6. seed = 0

  7. plot = True

  8. innerstepsize = 0.02 # stepsize in inner SGD

  9. innerepochs = 1 # number of epochs of each inner SGD

  10. outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization

  11. niterations = 30000 # number of outer updates; each iteration we sample one task and update on it

  12. rng = np.random.RandomState(seed)

  13. torch.manual_seed(seed)

  14. # Define task distribution

  15. x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points

  16. ntrain = 10 # Size of training minibatches

  17. def gen_task():

  18.    "Generate classification problem"

  19.    phase = rng.uniform(low=0, high=2*np.pi)

  20.    ampl = rng.uniform(0.1, 5)

  21.    f_randomsine = lambda x : np.sin(x + phase) * ampl

  22.    return f_randomsine

  23. # Define model. Reptile paper uses ReLU, but Tanh gives slightly better results

  24. model = nn.Sequential(

  25.    nn.Linear(1, 64),

  26.    nn.Tanh(),

  27.    nn.Linear(64, 64),

  28.    nn.Tanh(),

  29.    nn.Linear(64, 1),

  30. )

  31. def totorch(x):

  32.    return ag.Variable(torch.Tensor(x))

  33. def train_on_batch(x, y):

  34.    x = totorch(x)

  35.    y = totorch(y)

  36.    model.zero_grad()

  37.    ypred = model(x)

  38.    loss = (ypred - y).pow(2).mean()

  39.    loss.backward()

  40.    for param in model.parameters():

  41.        param.data -= innerstepsize * param.grad.data

  42. def predict(x):

  43.    x = totorch(x)

  44.    return model(x).data.numpy()

  45. # Choose a fixed task and minibatch for visualization

  46. f_plot = gen_task()

  47. xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

  48. # Reptile training loop

  49. for iteration in range(niterations):

  50.    weights_before = deepcopy(model.state_dict())

  51.    # Generate task

  52.    f = gen_task()

  53.    y_all = f(x_all)

  54.    # Do SGD on this task

  55.    inds = rng.permutation(len(x_all))

  56.    for _ in range(innerepochs):

  57.        for start in range(0, len(x_all), ntrain):

  58.            mbinds = inds[start:start+ntrain]

  59.            train_on_batch(x_all[mbinds], y_all[mbinds])

  60.    # Interpolate between current weights and trained weights from this task

  61.    # I.e. (weights_before - weights_after) is the meta-gradient

  62.    weights_after = model.state_dict()

  63.    outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule

  64.    model.load_state_dict({name :

  65.        weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize

  66.        for name in weights_before})

  67.    # Periodically plot the results on a particular task and minibatch

  68.    if plot and iteration==0 or (iteration+1) % 1000 == 0:

  69.        plt.cla()

  70.        f = f_plot

  71.        weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation

  72.        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))

  73.        for inneriter in range(32):

  74.            train_on_batch(xtrain_plot, f(xtrain_plot))

  75.            if (inneriter+1) % 8 == 0:

  76.                frac = (inneriter+1) / 32

  77.                plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))

  78.        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))

  79.        lossval = np.square(predict(x_all) - f(x_all)).mean()

  80.        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")

  81.        plt.ylim(-4,4)

  82.        plt.legend(loc="lower right")

  83.        plt.pause(0.01)

  84.        model.load_state_dict(weights_before) # restore from snapshot

  85.        print(f"-----------------------------")

  86.        print(f"iteration               {iteration+1}")

  87.        print(f"loss on plotted curve   {lossval:.3f}") # would be better to average loss ove


论文:Reptile: a Scalable Metalearning Algorithm 



地址:https://d4mucfpksywv.cloudfront.net/research-covers/reptile/reptile_update.pdf


摘要:本论文讨论了元学习问题,即存在任务的一个分布,我们希望找到能在该分布所采样的任务(模型未见过的任务)中快速学习的智能体。我们提出了一种简单元学习算法 Reptile,它会学习一种能在新任务中快速精调的参数初始化方法。Reptile 会重复采样一个任务,并在该任务上执行训练,且将初始化朝该任务的已训练权重方向移动。Reptile 不像同样学习初始化的 MAML,它并不要求在优化过程中是可微的,因此它更适合于需要很多更新步的优化问题。我们的研究发现,Reptile 在一些有具备完整基准的 few-shot 分类任务上表现良好。此外,我们还提供了一些理论性分析,以帮助理解 Reptile 的工作原理。


原文链接:https://blog.openai.com/reptile/



点击下方“阅读原文”了解更多信息
↓↓↓

这篇关于OpenAI提出Reptile:可扩展的元学习算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/898135

相关文章

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

golang字符串匹配算法解读

《golang字符串匹配算法解读》文章介绍了字符串匹配算法的原理,特别是Knuth-Morris-Pratt(KMP)算法,该算法通过构建模式串的前缀表来减少匹配时的不必要的字符比较,从而提高效率,在... 目录简介KMP实现代码总结简介字符串匹配算法主要用于在一个较长的文本串中查找一个较短的字符串(称为

通俗易懂的Java常见限流算法具体实现

《通俗易懂的Java常见限流算法具体实现》:本文主要介绍Java常见限流算法具体实现的相关资料,包括漏桶算法、令牌桶算法、Nginx限流和Redis+Lua限流的实现原理和具体步骤,并比较了它们的... 目录一、漏桶算法1.漏桶算法的思想和原理2.具体实现二、令牌桶算法1.令牌桶算法流程:2.具体实现2.1

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;