Universal Transformer with Adaptive Computation Time (ACT)

2024-01-04 01:52

This post introduces the Universal Transformer with Adaptive Computation Time (ACT) and walks through a PyTorch reference implementation. The Universal Transformer applies the same recurrent block repeatedly to every position of the sequence, and ACT lets each position learn how many refinement steps it needs before halting.



Original paper: https://arxiv.org/abs/1807.03819


Main code

import torch
import numpy as np


class PositionTimestepEmbedding(torch.nn.Module):
    # Adds the standard sinusoidal position embedding plus a sinusoidal
    # embedding of the current ACT timestep t.
    def forward(self, x, t):
        device = x.device
        sequence_length = x.size(1)
        d_model = x.size(2)
        position_embedding = np.array(
            [[pos / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
             for pos in range(sequence_length)])
        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])
        timestep_embedding = np.array(
            [[t / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]])
        timestep_embedding[:, 0::2] = np.sin(timestep_embedding[:, 0::2])
        # Odd dimensions use cos; the original code applied sin twice here.
        timestep_embedding[:, 1::2] = np.cos(timestep_embedding[:, 1::2])
        embedding = position_embedding + timestep_embedding
        return x + torch.tensor(embedding, dtype=torch.float, requires_grad=False, device=device)


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim * num_heads == self.d_model, "d_model must be divisible by num_heads"
        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.output = torch.nn.Linear(d_model, d_model)
        self.layer_norm = torch.nn.LayerNorm(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)
        scores = scores.softmax(dim=-1)
        scores = self.dropout(scores)
        return torch.matmul(scores, v), scores

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        residual = q
        if mask is not None:
            mask = mask.unsqueeze(1)  # [batch, 1, len_q, len_k], broadcast over heads
        q = self.query(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        out, scores = self.scaled_dot_product_attention(q, k, v, mask)
        out = (out.transpose(1, 2).contiguous()
               .view(batch_size, -1, self.num_heads * self.head_dim))
        out = self.output(out)
        out += residual
        return self.layer_norm(out)


class TransitionFunction(torch.nn.Module):
    # Position-wise feed-forward transition with a residual connection.
    def __init__(self, d_model, dim_transition, dropout=0.):
        super().__init__()
        self.linear1 = torch.nn.Linear(d_model, dim_transition)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(dim_transition, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.layer_norm = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        y = self.linear1(x)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout(y)
        y = y + x
        return self.layer_norm(y)


class EncoderBasicLayer(torch.nn.Module):
    def __init__(self, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.transition = TransitionFunction(d_model, dim_transition, dropout)

    def forward(self, block_inputs, enc_self_attn_mask=None):
        self_attention_outputs = self.self_attention(
            block_inputs, block_inputs, block_inputs, enc_self_attn_mask)
        block_outputs = self.transition(self_attention_outputs)
        return block_outputs


class DecoderBasicLayer(torch.nn.Module):
    def __init__(self, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.attention_enc_dec = MultiHeadAttention(d_model, num_heads, dropout)
        self.transition = TransitionFunction(d_model, dim_transition, dropout)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask=None, dec_enc_attn_mask=None):
        dec_query = self.self_attention(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        block_outputs = self.attention_enc_dec(dec_query, enc_outputs, enc_outputs, dec_enc_attn_mask)
        block_outputs = self.transition(block_outputs)
        return block_outputs


class RecurrentEncoderBlock(torch.nn.Module):
    def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            EncoderBasicLayer(d_model, dim_transition, num_heads, dropout)
            for _ in range(num_layers)])

    def forward(self, x, enc_self_attn_mask=None):
        for layer in self.layers:
            x = layer(x, enc_self_attn_mask)
        return x


class RecurrentDecoderBlock(torch.nn.Module):
    def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            DecoderBasicLayer(d_model, dim_transition, num_heads, dropout)
            for _ in range(num_layers)])

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        for layer in self.layers:
            dec_inputs = layer(dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
        return dec_inputs


class AdaptiveNetwork(torch.nn.Module):
    # ACT controller: repeats the recurrent block until every position has
    # accumulated a halting probability above 1 - epsilon or used max_hop steps.
    def __init__(self, d_model, dim_transition, epsilon, max_hop):
        super().__init__()
        self.threshold = 1.0 - epsilon
        self.max_hop = max_hop
        self.halting_predict = torch.nn.Sequential(
            torch.nn.Linear(d_model, dim_transition),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_transition, 1),
            torch.nn.Sigmoid())

    def forward(self, x, mask, pos_time_embed, recurrent_block, encoder_output=None):
        device = x.device
        halting_probability = torch.zeros((x.size(0), x.size(1)), device=device)
        remainders = torch.zeros((x.size(0), x.size(1)), device=device)
        n_updates = torch.zeros((x.size(0), x.size(1)), device=device)
        previous = torch.zeros_like(x, device=device)
        step = 0
        while ((halting_probability < self.threshold) & (n_updates < self.max_hop)).any():
            # The embedding module already adds x internally (the original added x twice).
            x = pos_time_embed(x, step)
            p = self.halting_predict(x).squeeze(-1)
            still_running = (halting_probability < 1.0).float()
            # Positions that would cross the threshold this step halt here.
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
            halting_probability = halting_probability + p * still_running
            # Halted positions contribute their leftover probability mass.
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted
            update_weights = p * still_running + new_halted * remainders
            if encoder_output is not None:
                x = recurrent_block(x, encoder_output, mask[0], mask[1])
            else:
                x = recurrent_block(x, mask)
            # Per-position weighted running average of the states.
            previous = ((x * update_weights.unsqueeze(-1))
                        + (previous * (1 - update_weights.unsqueeze(-1))))
            step += 1
        return previous


class Encoder(torch.nn.Module):
    def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        assert 0 < epsilon < 1, "epsilon must satisfy 0 < epsilon < 1"
        self.pos_time_embedding = PositionTimestepEmbedding()
        self.recurrent_block = RecurrentEncoderBlock(num_layers, d_model, dim_transition, num_heads, dropout)
        self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)

    def forward(self, x, enc_self_attn_mask=None):
        return self.adaptive_network(x, enc_self_attn_mask, self.pos_time_embedding, self.recurrent_block)


class Decoder(torch.nn.Module):
    def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):
        super().__init__()
        assert 0 < epsilon < 1, "epsilon must satisfy 0 < epsilon < 1"
        self.pos_time_embedding = PositionTimestepEmbedding()
        self.recurrent_block = RecurrentDecoderBlock(num_layers, d_model, dim_transition, num_heads, dropout)
        self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        return self.adaptive_network(
            dec_inputs, (dec_self_attn_mask, dec_enc_attn_mask),
            self.pos_time_embedding, self.recurrent_block, enc_outputs)


class AdaptiveComputationTimeUniversalTransformer(torch.nn.Module):
    def __init__(self, d_model, dim_transition, num_heads, enc_attn_layers, dec_attn_layers,
                 epsilon, max_hop, dropout=0.):
        super().__init__()
        self.encoder = Encoder(epsilon, max_hop, enc_attn_layers, d_model, dim_transition, num_heads, dropout)
        self.decoder = Decoder(epsilon, max_hop, dec_attn_layers, d_model, dim_transition, num_heads, dropout)

    def forward(self, src, tgt, enc_self_attn_mask=None, dec_self_attn_mask=None, dec_enc_attn_mask=None):
        enc_outputs = self.encoder(src, enc_self_attn_mask)
        return self.decoder(tgt, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
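
To check that the pieces fit together, here is a minimal smoke test on random, already-embedded inputs; every hyperparameter value below is an illustrative choice, not something from the original post:

# Minimal smoke test; all hyperparameter values are illustrative.
model = AdaptiveComputationTimeUniversalTransformer(
    d_model=64, dim_transition=128, num_heads=4,
    enc_attn_layers=1, dec_attn_layers=1,
    epsilon=0.01, max_hop=6)
src = torch.randn(2, 10, 64)  # [batch_size, src_len, d_model], already embedded
tgt = torch.randn(2, 8, 64)   # [batch_size, tgt_len, d_model], already embedded
out = model(src, tgt)         # all masks are optional and omitted here
print(out.shape)              # torch.Size([2, 8, 64])

Note that the model consumes pre-embedded inputs: the token-embedding lookup and the final projection to vocabulary logits are outside its scope.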

Mask

# from https://zhuanlan.zhihu.com/p/403433120
def get_attn_subsequence_mask(seq):  # seq: [batch_size, tgt_len]
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Upper-triangular matrix: True where a query may not attend (future positions).
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # [batch_size, tgt_len, tgt_len]
    subsequence_mask = torch.from_numpy(subsequence_mask).bool()
    return subsequence_mask


def get_attn_pad_mask(seq_q, seq_k):  # seq_q: [batch_size, len_q], seq_k: [batch_size, len_k]
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # Mark key positions holding the padding token (id 0) with True. [batch_size, 1, len_k]
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

That concludes this introduction to the Universal Transformer with Adaptive Computation Time (ACT); hopefully the code above serves as a useful reference.



