机器学习笔记 - EANet 外部注意论文简读及代码实现

2023-10-14 16:20

本文主要是介绍机器学习笔记 - EANet 外部注意论文简读及代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、论文简述

 

        论文作者提出了一种新的轻量级注意力机制,称之为外部注意力。如图所示,计算自注意力需要首先通过计算自查询向量和自关键字向量之间的仿射关系来计算注意力图,然后通过用该注意力图加权自值向量来生成新的特征图。外部关注的作用不同。我们首先通过计算自查询向量和外部可学习密钥存储器之间的亲和力来计算注意力图,然后通过将该注意力图乘以另一个外部可学习值存储器来生成细化的特征图。

         在实践中,这两个存储器是用线性层实现的,因此可以通过端到端的反向传播来优化。它们独立于单个样本,并在整个数据集中共享,这起到了很强的正则化作用,提高了注意力机制的泛化能力。外部注意力的轻量级本质的关键在于,存储器中的元素数量远小于输入特征中的元素数目,从而产生输入中元素数目线性的计算复杂度。外部存储器旨在学习整个数据集中最具鉴别力的特征,捕捉信息量最大的部分,并排除其他样本中的干扰信息。类似的想法可以在稀疏编码[9]或字典学习中找到。然而,与这些方法不同的是,我们既没有尝试重建输入特征,也没有对注意力图应用任何显式稀疏正则化。

        尽管所提出的外部注意力方法很简单,但它对各种视觉任务都是有效的。由于其简单性,它可以很容易地融入现有流行的基于自注意的架构中,如DANet、SAGAN和T2T Transformer。图3展示了一个典型的架构,该架构将图像语义分割任务的自我注意力替换为外部注意力。我们使用不同的输入模式(图像和点云),对分类、对象检测、语义分割、实例分割和生成等基本视觉任务进行了广泛的实验。结果表明,我们的方法获得的结果与原始的自注意机制及其一些变体相当或更好,以低得多的计算成本。

使用我们提出的外部注意力进行语义分割的EANet架构。

二、相关参考代码

1、基于torch的实现

        外部注意力实现

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ExternalAttention(nn.Module):def __init__(self, d_model,S=64):super().__init__()self.mk=nn.Linear(d_model,S,bias=False)self.mv=nn.Linear(S,d_model,bias=False)self.softmax=nn.Softmax(dim=1)self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, queries):attn=self.mk(queries) #bs,n,Sattn=self.softmax(attn) #bs,n,Sattn=attn/torch.sum(attn,dim=2,keepdim=True) #bs,n,Sout=self.mv(attn) #bs,n,d_modelreturn outif __name__ == '__main__':input=torch.randn(50,49,512)ea = ExternalAttention(d_model=512,S=8)output=ea(input)print(output.shape)

        调用参考

input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)

2、基于tensorflow的EANet实现

        EANet只是用外部注意力代替了Vit中的自我注意力。

        这里的数据集是采用cifar100数据集,首先加载划分数据集,然后配置超参数

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as pltnum_classes = 100
input_shape = (32, 32, 3)(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patch
embedding_dim = 64  # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the transformer layerprint(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")

        进行数据增强

data_augmentation = keras.Sequential([layers.Normalization(),layers.RandomFlip("horizontal"),layers.RandomRotation(factor=0.1),layers.RandomContrast(factor=0.1),layers.RandomZoom(height_factor=0.2, width_factor=0.2),],name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

        实现补丁提取和编码层‘

class PatchExtract(layers.Layer):def __init__(self, patch_size, **kwargs):super().__init__(**kwargs)self.patch_size = patch_sizedef call(self, images):batch_size = tf.shape(images)[0]patches = tf.image.extract_patches(images=images,sizes=(1, self.patch_size, self.patch_size, 1),strides=(1, self.patch_size, self.patch_size, 1),rates=(1, 1, 1, 1),padding="VALID",)patch_dim = patches.shape[-1]patch_num = patches.shape[1]return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))class PatchEmbedding(layers.Layer):def __init__(self, num_patch, embed_dim, **kwargs):super().__init__(**kwargs)self.num_patch = num_patchself.proj = layers.Dense(embed_dim)self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)def call(self, patch):pos = tf.range(start=0, limit=self.num_patch, delta=1)return self.proj(patch) + self.pos_embed(pos)

        实现外部注意力块

def external_attention(x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
):_, num_patch, channel = x.shapeassert dim % num_heads == 0num_heads = num_heads * dim_coefficientx = layers.Dense(dim * dim_coefficient)(x)# create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]x = tf.reshape(x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads))x = tf.transpose(x, perm=[0, 2, 1, 3])# a linear layer M_kattn = layers.Dense(dim // dim_coefficient)(x)# normalize attention mapattn = layers.Softmax(axis=2)(attn)# dobule-normalizationattn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))attn = layers.Dropout(attention_dropout)(attn)# a linear layer M_vx = layers.Dense(dim * dim_coefficient // num_heads)(attn)x = tf.transpose(x, perm=[0, 2, 1, 3])x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])# a linear layer to project original dimx = layers.Dense(dim)(x)x = layers.Dropout(projection_dropout)(x)return x

        实施 MLP

def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)x = layers.Dropout(drop_rate)(x)x = layers.Dense(embedding_dim)(x)x = layers.Dropout(drop_rate)(x)return x

        实现变压器模块,基于参数配置选择外部注意或是自我关注。

def transformer_encoder(x,embedding_dim,mlp_dim,num_heads,dim_coefficient,attention_dropout,projection_dropout,attention_type="external_attention",
):residual_1 = xx = layers.LayerNormalization(epsilon=1e-5)(x)if attention_type == "external_attention":x = external_attention(x,embedding_dim,num_heads,dim_coefficient,attention_dropout,projection_dropout,)elif attention_type == "self_attention":x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout)(x, x)x = layers.add([x, residual_1])residual_2 = xx = layers.LayerNormalization(epsilon=1e-5)(x)x = mlp(x, embedding_dim, mlp_dim)x = layers.add([x, residual_2])return x

        实施 EANet 模型

def get_model(attention_type="external_attention"):inputs = layers.Input(shape=input_shape)# Image augmentx = data_augmentation(inputs)# Extract patches.x = PatchExtract(patch_size)(x)# Create patch embedding.x = PatchEmbedding(num_patches, embedding_dim)(x)# Create Transformer block.for _ in range(num_transformer_blocks):x = transformer_encoder(x,embedding_dim,mlp_dim,num_heads,dim_coefficient,attention_dropout,projection_dropout,attention_type,)x = layers.GlobalAvgPool1D()(x)outputs = layers.Dense(num_classes, activation="softmax")(x)model = keras.Model(inputs=inputs, outputs=outputs)return model

        进行训练并可视化

model = get_model(attention_type="external_attention")model.compile(loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),optimizer=tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay),metrics=[keras.metrics.CategoricalAccuracy(name="accuracy"),keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),],
)history = model.fit(x_train,y_train,batch_size=batch_size,epochs=num_epochs,validation_split=validation_split,
)model.save('eanet_cifar100.h5')

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

         进行验证

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

三、相关参考

arxiv.org/pdf/2105.02358.pdfhttps://arxiv.org/pdf/2105.02358.pdfExternal-Attention-pytorch/ExternalAttention.py at master · xmu-xiaoma666/External-Attention-pytorch · GitHubhttps://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/ExternalAttention.py

这篇关于机器学习笔记 - EANet 外部注意论文简读及代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/211690

相关文章

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

springboot循环依赖问题案例代码及解决办法

《springboot循环依赖问题案例代码及解决办法》在SpringBoot中,如果两个或多个Bean之间存在循环依赖(即BeanA依赖BeanB,而BeanB又依赖BeanA),会导致Spring的... 目录1. 什么是循环依赖?2. 循环依赖的场景案例3. 解决循环依赖的常见方法方法 1:使用 @La

Java枚举类实现Key-Value映射的多种实现方式

《Java枚举类实现Key-Value映射的多种实现方式》在Java开发中,枚举(Enum)是一种特殊的类,本文将详细介绍Java枚举类实现key-value映射的多种方式,有需要的小伙伴可以根据需要... 目录前言一、基础实现方式1.1 为枚举添加属性和构造方法二、http://www.cppcns.co

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

使用C#代码在PDF文档中添加、删除和替换图片

《使用C#代码在PDF文档中添加、删除和替换图片》在当今数字化文档处理场景中,动态操作PDF文档中的图像已成为企业级应用开发的核心需求之一,本文将介绍如何在.NET平台使用C#代码在PDF文档中添加、... 目录引言用C#添加图片到PDF文档用C#删除PDF文档中的图片用C#替换PDF文档中的图片引言在当

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java实现文件图片的预览和下载功能

《Java实现文件图片的预览和下载功能》这篇文章主要为大家详细介绍了如何使用Java实现文件图片的预览和下载功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... Java实现文件(图片)的预览和下载 @ApiOperation("访问文件") @GetMapping("