PyTorch 简单易懂的 Embedding 和 EmbeddingBag - 解析与实践

2024-01-08 11:20

本文主要是介绍PyTorch 简单易懂的 Embedding 和 EmbeddingBag - 解析与实践,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

torch.nn子模块Sparse Layers详解

nn.Embedding

用途

主要参数

注意事项

使用示例

从预训练权重创建嵌入

nn.EmbeddingBag

功能和用途

主要参数

使用示例

从预训练权重创建

总结


torch.nn子模块Sparse Layers详解

nn.Embedding

torch.nn.Embedding 是 PyTorch 中一个重要的模块,用于创建一个简单的查找表,它存储固定字典和大小的嵌入(embeddings)。这个模块通常用于存储单词嵌入并使用索引检索它们。接下来,我将详细解释 Embedding 模块的用途、用法、特点以及如何使用它。

用途

  • 单词嵌入:在自然语言处理中,Embedding 模块用于将单词(或其他类型的标记)映射到一个高维空间,其中相似的单词在嵌入空间中彼此靠近。
  • 特征表示:在非自然语言处理任务中,嵌入可以用于任何类型的分类特征的密集表示。

主要参数

  • num_embeddings(int):嵌入字典的大小。
  • embedding_dim(int):每个嵌入向量的大小。
  • padding_idx(int,可选):如果指定,padding_idx 处的嵌入不会在训练中更新。
  • max_norm(float,可选):如果指定,将重新归一化超过此范数的嵌入向量。
  • norm_type(float,可选):用于max_norm选项的p-范数的p值,默认为2。
  • scale_grad_by_freq(bool,可选):如果为True,将按单词在批次中的频率的倒数来缩放梯度。
  • sparse(bool,可选):如果为True,权重矩阵的梯度将是一个稀疏张量。

注意事项

  • 当使用max_norm参数时,Embedding的前向方法会就地修改权重张量。如果需要对Embedding.weight进行梯度计算,则在调用前向方法前,需要在max_norm不为None时克隆它。
  • 仅有少数优化器支持稀疏梯度。

使用示例

import torch
import torch.nn as nn# 创建一个包含10个大小为3的嵌入的Embedding模块
embedding = nn.Embedding(10, 3)# 一个包含4个索引的2个样本的批次
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])# 通过Embedding模块获取嵌入
output = embedding(input)

此示例创建了一个嵌入字典大小为10、每个嵌入维度为3的 Embedding 模块。然后它接受一个包含索引的输入张量,并返回对应的嵌入向量。

从预训练权重创建嵌入

还可以使用from_pretrained类方法从预先训练的权重创建Embedding实例:

# 预训练的权重
weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])# 从预训练权重创建Embedding
embedding = nn.Embedding.from_pretrained(weight)# 获取索引1的嵌入
input = torch.LongTensor([1])
output = embedding(input)

在这个示例中,Embedding 模块是从一个给定的预训练权重张量创建的。这种方法在迁移学习或使用预先训练好的嵌入时非常有用。

nn.EmbeddingBag

torch.nn.EmbeddingBag 是 PyTorch 中一个高效的模块,用于计算“bags”(即序列或集合)的嵌入的总和或平均值,而无需实例化中间的嵌入。这个模块特别适用于处理具有不同长度的序列,如在自然语言处理任务中处理不同长度的句子或文档。下面我将详细介绍 EmbeddingBag 的功能、用法以及特点。

功能和用途

  • 高效计算EmbeddingBag 直接计算整个包的总和或平均值,比逐个嵌入后再求和或取平均更加高效。
  • 支持不同聚合方式:可以选择 "sum", "mean" 或 "max" 模式来聚合每个包中的嵌入。
  • 支持加权聚合EmbeddingBag 还支持为每个样本指定权重,在 "sum" 模式下进行加权求和。

主要参数

  • num_embeddings(int):嵌入字典的大小。
  • embedding_dim(int):每个嵌入向量的大小。
  • max_norm(float,可选):如果给定,将重新规范化超过此范数的嵌入向量。
  • mode(str,可选):聚合模式,可以是 "sum"、"mean" 或 "max"。
  • sparse(bool,可选):如果为True,权重矩阵的梯度将是一个稀疏张量。
  • padding_idx(int,可选):如果指定,padding_idx 处的嵌入将不会在训练中更新。

使用示例

import torch
import torch.nn as nn# 创建一个包含10个大小为3的嵌入的EmbeddingBag模块
embedding_bag = nn.EmbeddingBag(10, 3, mode='mean')# 一个示例包含4个索引的输入
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)# 指定每个包的开始索引
offsets = torch.tensor([0, 4], dtype=torch.long)# 通过EmbeddingBag模块获取嵌入
output = embedding_bag(input, offsets)

在这个示例中,创建了一个嵌入字典大小为10、每个嵌入维度为3的 EmbeddingBag 模块,并设置为 "mean" 模式。输入是一个索引序列,offsets 指定了每个包的开始位置。EmbeddingBag 会计算每个包的平均嵌入向量。

从预训练权重创建

EmbeddingBag 也可以从预训练的权重创建:

# 预训练的权重
weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])# 从预训练权重创建EmbeddingBag
embedding_bag = nn.EmbeddingBag.from_pretrained(weight)# 获取索引1的嵌入
input = torch.LongTensor([[1, 0]])
output = embedding_bag(input)

 这种方法在需要使用预先训练好的嵌入或在迁移学习中非常有用。EmbeddingBag 通过高效地处理不同长度的序列数据,在自然语言处理等领域中发挥着重要作用。

总结

 本篇博客探讨了 PyTorch 中的 nn.Embeddingnn.EmbeddingBag 两个关键模块,它们是处理和表示离散数据特征的强大工具。nn.Embedding 提供了一种有效的方式来将单词或其他类型的标记映射到高维空间中,而 nn.EmbeddingBag 以其独特的方式处理变长序列,通过聚合嵌入来提高计算效率。这两个模块不仅在自然语言处理中发挥关键作用,也适用于其他需要稠密特征表示的任务。此外,这些模块支持从预训练权重初始化,使其在迁移学习和复杂模型训练中极为重要。综上所述,nn.Embeddingnn.EmbeddingBag 是理解和应用 PyTorch 中嵌入层的基础。

这篇关于PyTorch 简单易懂的 Embedding 和 EmbeddingBag - 解析与实践的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/583294

相关文章

Agent开发核心技术解析以及现代Agent架构设计

《Agent开发核心技术解析以及现代Agent架构设计》在人工智能领域,Agent并非一个全新的概念,但在大模型时代,它被赋予了全新的生命力,简单来说,Agent是一个能够自主感知环境、理解任务、制定... 目录一、回归本源:到底什么是Agent?二、核心链路拆解:Agent的"大脑"与"四肢"1. 规划模

SpringBoot简单整合ElasticSearch实践

《SpringBoot简单整合ElasticSearch实践》Elasticsearch支持结构化和非结构化数据检索,通过索引创建和倒排索引文档,提高搜索效率,它基于Lucene封装,分为索引库、类型... 目录一:ElasticSearch支持对结构化和非结构化的数据进行检索二:ES的核心概念Index:

Python数据验证神器Pydantic库的使用和实践中的避坑指南

《Python数据验证神器Pydantic库的使用和实践中的避坑指南》Pydantic是一个用于数据验证和设置的库,可以显著简化API接口开发,文章通过一个实际案例,展示了Pydantic如何在生产环... 目录1️⃣ 崩溃时刻:当你的API接口又双叒崩了!2️⃣ 神兵天降:3行代码解决验证难题3️⃣ 深度

C++ move 的作用详解及陷阱最佳实践

《C++move的作用详解及陷阱最佳实践》文章详细介绍了C++中的`std::move`函数的作用,包括为什么需要它、它的本质、典型使用场景、以及一些常见陷阱和最佳实践,感兴趣的朋友跟随小编一起看... 目录C++ move 的作用详解一、一句话总结二、为什么需要 move?C++98/03 的痛点⚡C++

MySQL字符串转数值的方法全解析

《MySQL字符串转数值的方法全解析》在MySQL开发中,字符串与数值的转换是高频操作,本文从隐式转换原理、显式转换方法、典型场景案例、风险防控四个维度系统梳理,助您精准掌握这一核心技能,需要的朋友可... 目录一、隐式转换:自动但需警惕的&ld编程quo;双刃剑”二、显式转换:三大核心方法详解三、典型场景

GO语言实现串口简单通讯

《GO语言实现串口简单通讯》本文分享了使用Go语言进行串口通讯的实践过程,详细介绍了串口配置、数据发送与接收的代码实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 目录背景串口通讯代码代码块分解解析完整代码运行结果背景最近再学习 go 语言,在某宝用5块钱买了个

SQL 注入攻击(SQL Injection)原理、利用方式与防御策略深度解析

《SQL注入攻击(SQLInjection)原理、利用方式与防御策略深度解析》本文将从SQL注入的基本原理、攻击方式、常见利用手法,到企业级防御方案进行全面讲解,以帮助开发者和安全人员更系统地理解... 目录一、前言二、SQL 注入攻击的基本概念三、SQL 注入常见类型分析1. 基于错误回显的注入(Erro

SpringBoot整合Apache Spark实现一个简单的数据分析功能

《SpringBoot整合ApacheSpark实现一个简单的数据分析功能》ApacheSpark是一个开源的大数据处理框架,它提供了丰富的功能和API,用于分布式数据处理、数据分析和机器学习等任务... 目录第一步、添加android依赖第二步、编写配置类第三步、编写控制类启动项目并测试总结ApacheS

C++ 多态性实战之何时使用 virtual 和 override的问题解析

《C++多态性实战之何时使用virtual和override的问题解析》在面向对象编程中,多态是一个核心概念,很多开发者在遇到override编译错误时,不清楚是否需要将基类函数声明为virt... 目录C++ 多态性实战:何时使用 virtual 和 override?引言问题场景判断是否需要多态的三个关

C++简单日志系统实现代码示例

《C++简单日志系统实现代码示例》日志系统是成熟软件中的一个重要组成部分,其记录软件的使用和运行行为,方便事后进行故障分析、数据统计等,:本文主要介绍C++简单日志系统实现的相关资料,文中通过代码... 目录前言Util.hppLevel.hppLogMsg.hppFormat.hppSink.hppBuf