Molecule Attention Transformer(一)

2024-01-13 21:40

本文主要是介绍Molecule Attention Transformer(一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

应用Transformer框架对分子属性进行预测,代码:MAT,原文:Molecule Attention Transformer。变量名,函数名很多来自The Annotated Transformer,在《深入浅出Embedding》一书中也做了讲解。本文主要从实例运行开始一步步看代码具体内容,整体模型如下:


请添加图片描述

文章目录

  • 1.数据准备
    • 1.1.load_data_from_df
      • 1.1.1.load_data_from_smiles
      • 1.1.2.featurize_mol
      • 1.1.3.get_atom_features
      • 1.1.4.one_hot_vector
    • 1.2.construct_loader
      • 1.2.1.construct_dataset
      • 1.2.2.Molecule
      • 1.2.3.MolDataset
      • 1.2.4.mol_collate_func
      • 1.2.5.pad_array
    • 1.3.summary

1.数据准备

from featurization.data_utils import load_data_from_df, construct_loader
batch_size = 64# Formal charges are one-hot encoded to keep compatibility with the pre-trained weights.
# If you do not plan to use the pre-trained weights, we recommend to set one_hot_formal_charge to False.
X, y = load_data_from_df('../data/freesolv/freesolv.csv', one_hot_formal_charge=True)
data_loader = construct_loader(X, y, batch_size)
  • 利用 load_data_from_df 读入原始数据,再用 construct_loader 将数据转化为 torch.utils.data.DataLoader 对象

1.1.load_data_from_df

def load_data_from_df(dataset_path, add_dummy_node=True, one_hot_formal_charge=False, use_data_saving=True):"""Load and featurize data stored in a CSV file.Args:dataset_path (str): A path to the CSV file containing the data. It should have two columns:the first one contains SMILES strings of the compounds,the second one contains labels.add_dummy_node (bool): If True, a dummy node will be added to the molecular graph. Defaults to True.one_hot_formal_charge (bool): If True, formal charges on atoms are one-hot encoded. Defaults to False.use_data_saving (bool): If True, saved features will be loaded from the dataset directory; if no feature fileis present, the features will be saved after calculations. Defaults to True.Returns:A tuple (X, y) in which X is a list of graph descriptors (node features, adjacency matrices, distance matrices),and y is a list of the corresponding labels."""feat_stamp = f'{"_dn" if add_dummy_node else ""}{"_ohfc" if one_hot_formal_charge else ""}'feature_path = dataset_path.replace('.csv', f'{feat_stamp}.p')if use_data_saving and os.path.exists(feature_path):logging.info(f"Loading features stored at '{feature_path}'")x_all, y_all = pickle.load(open(feature_path, "rb"))return x_all, y_alldata_df = pd.read_csv(dataset_path)data_x = data_df.iloc[:, 0].valuesdata_y = data_df.iloc[:, 1].valuesif data_y.dtype == np.float64:data_y = data_y.astype(np.float32)x_all, y_all = load_data_from_smiles(data_x, data_y, add_dummy_node=add_dummy_node,one_hot_formal_charge=one_hot_formal_charge)if use_data_saving and not os.path.exists(feature_path):logging.info(f"Saving features at '{feature_path}'")pickle.dump((x_all, y_all), open(feature_path, "wb"))return x_all, y_all
  • feature_path 主要是判断是否利用已经保存的数据,可以跳过
  • data_x 是 smiles 的序列数据,data_f 是标量数值,示例如下:
data_xdata_y
0CN©C(=O)c1ccc(cc1)OC-1.874467
1CS(=O)(=O)Cl-0.277514
2CC©C=C1.465089
3CCc1cnccn1-0.428367
4CCCCCCCO-0.105855
  • load_data_from_smiles 将 data_x 的 smiles 数据处理成 graph descriptors (node features, adjacency matrices, distance matrices),data_y 不变
  • 得到特征 x_all 和 y_all 后返回,示例如下:
import numpy as np
np.asarray(X).shape,np.asarray(y).shape #((642, 3), (642, 1))
X[0],y[0]
"""
([array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.]]),array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]]),array([[1.00000000e+06, 1.00000000e+06, 1.00000000e+06, 1.00000000e+06,1.00000000e+06, 1.00000000e+06, 1.00000000e+06, 1.00000000e+06,1.00000000e+06, 1.00000000e+06, 1.00000000e+06, 1.00000000e+06,1.00000000e+06, 1.00000000e+06],...[-1.8744674])"""
X[0][0].shape,X[0][1].shape,X[0][2].shape #((14, 28), (14, 14), (14, 14))
X[1][0].shape,X[1][1].shape,X[1][2].shape #((6, 28), (6, 6), (6, 6))
  • 每个分子原子数不同导致维度不一致,这里没有统一。每个原子用28维特征表示,可以在 featurize_mol 看出

1.1.1.load_data_from_smiles

def load_data_from_smiles(x_smiles, labels, add_dummy_node=True, one_hot_formal_charge=False):"""Load and featurize data from lists of SMILES strings and labels.Args:x_smiles (list[str]): A list of SMILES strings.labels (list[float]): A list of the corresponding labels.add_dummy_node (bool): If True, a dummy node will be added to the molecular graph. Defaults to True.one_hot_formal_charge (bool): If True, formal charges on atoms are one-hot encoded. Defaults to False.Returns:A tuple (X, y) in which X is a list of graph descriptors (node features, adjacency matrices, distance matrices),and y is a list of the corresponding labels."""x_all, y_all = [], []for smiles, label in zip(x_smiles, labels):try:mol = MolFromSmiles(smiles)try:mol = Chem.AddHs(mol)AllChem.EmbedMolecule(mol, maxAttempts=5000)AllChem.UFFOptimizeMolecule(mol)mol = Chem.RemoveHs(mol)except:AllChem.Compute2DCoords(mol)afm, adj, dist = featurize_mol(mol, add_dummy_node, one_hot_formal_charge)x_all.append([afm, adj, dist])y_all.append([label])except ValueError as e:logging.warning('the SMILES ({}) can not be converted to a graph.\nREASON: {}'.format(smiles, e))return x_all, y_all
  • 先产生 mol 的3D构象,利用UFF力场优化,这是为了计算原子间距离,具体计算在 featurize_mol

1.1.2.featurize_mol

def featurize_mol(mol, add_dummy_node, one_hot_formal_charge):"""Featurize molecule.Args:mol (rdchem.Mol): An RDKit Mol object.add_dummy_node (bool): If True, a dummy node will be added to the molecular graph.one_hot_formal_charge (bool): If True, formal charges on atoms are one-hot encoded.Returns:A tuple of molecular graph descriptors (node features, adjacency matrix, distance matrix)."""node_features = np.array([get_atom_features(atom, one_hot_formal_charge)for atom in mol.GetAtoms()])adj_matrix = np.eye(mol.GetNumAtoms())for bond in mol.GetBonds():begin_atom = bond.GetBeginAtom().GetIdx()end_atom = bond.GetEndAtom().GetIdx()adj_matrix[begin_atom, end_atom] = adj_matrix[end_atom, begin_atom] = 1conf = mol.GetConformer()pos_matrix = np.array([[conf.GetAtomPosition(k).x, conf.GetAtomPosition(k).y, conf.GetAtomPosition(k).z]for k in range(mol.GetNumAtoms())])dist_matrix = pairwise_distances(pos_matrix)if add_dummy_node:m = np.zeros((node_features.shape[0] + 1, node_features.shape[1] + 1))m[1:, 1:] = node_featuresm[0, 0] = 1.node_features = mm = np.zeros((adj_matrix.shape[0] + 1, adj_matrix.shape[1] + 1))m[1:, 1:] = adj_matrixadj_matrix = mm = np.full((dist_matrix.shape[0] + 1, dist_matrix.shape[1] + 1), 1e6)m[1:, 1:] = dist_matrixdist_matrix = mreturn node_features, adj_matrix, dist_matrix
  • node_features 主要用 get_atom_features 得到
  • 邻接矩阵 adj_matrix 原子相连为1,不连为0
  • 距离矩阵 dist_matrix 主要用 pairwise_distances 得到,mol 传入前已处理,GetConformer获取原子坐标信息,pos_matrix 是 n × 3 n \times 3 n×3维矩阵,dist_matrix 得到的是 n × n n \times n n×n维对称矩阵
  • add_dummy_node 默认是True,dummy_node 不与分子中的任何原子相连,它与其他原子的距离设为了 1 0 6 10^6 106,这样模型可以在什么 pattern 都找不到的时候跳过搜索,类似 BERT 中的 [SEP] 词元(原文中提到)。添加 dummy_node 后,node_feature 在第一个编码,邻接矩阵对应为0,距离矩阵对应设为1e6
pos_matrix=np.array([[1,1,1],[1,2,3]
])
print(pairwise_distances(pos_matrix))
"""
[[0.         2.23606798][2.23606798 0.        ]]
"""
print(np.sqrt((1-1)**2+(1-2)**2+(1-3)**2)) #2.23606797749979

1.1.3.get_atom_features

def get_atom_features(atom, one_hot_formal_charge=True):"""Calculate atom features.Args:atom (rdchem.Atom): An RDKit Atom object.one_hot_formal_charge (bool): If True, formal charges on atoms are one-hot encoded.Returns:A 1-dimensional array (ndarray) of atom features."""attributes = []attributes += one_hot_vector(atom.GetAtomicNum(),[5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999])attributes += one_hot_vector(len(atom.GetNeighbors()),[0, 1, 2, 3, 4, 5])attributes += one_hot_vector(atom.GetTotalNumHs(),[0, 1, 2, 3, 4])if one_hot_formal_charge:attributes += one_hot_vector(atom.GetFormalCharge(),[-1, 0, 1])else:attributes.append(atom.GetFormalCharge())attributes.append(atom.IsInRing())attributes.append(atom.GetIsAromatic())return np.array(attributes, dtype=np.float32)
  • 每个原子的特征以28维 one-hot 表示,可以从原文中了解,0-11依据原子序数编码有机物常见原子(包括了dummy_node),12-17编码邻位原子的数目,18-22编码氢原子数,23-25编码原子电荷(one_hot_formal_charge为True),26编码原子是否位于环上,27编码是否是芳香性原子,此函数返回的实际是 27 维 one-hot,因为第一步实际没有编码 dummy_node,真正编码在 featurize_mol,下面的图可能有误导,实际 dummy_node 编码在第一列,而不是倒数第二列

在这里插入图片描述

1.1.4.one_hot_vector

def one_hot_vector(val, lst):"""Converts a value to a one-hot vector based on options in lst"""if val not in lst:val = lst[-1]return map(lambda x: x == val, lst)
  • 依据 lst 中的内容进行编码,如果不存在就以最后一位编码,例如不是有机物常见原子以[0,0,…1]表示

1.2.construct_loader

def construct_loader(x, y, batch_size, shuffle=True):"""Construct a data loader for the provided data.Args:x (list): A list of molecule features.y (list): A list of the corresponding labels.batch_size (int): The batch size.shuffle (bool): If True the data will be loaded in a random order. Defaults to True.Returns:A DataLoader object that yields batches of padded molecule features."""data_set = construct_dataset(x, y)loader = torch.utils.data.DataLoader(dataset=data_set,batch_size=batch_size,collate_fn=mol_collate_func,shuffle=shuffle)return loader
  • 先构造 dataset,定义处理函数 mol_collate_func,传入 DataLoader 返回 loader 对象

1.2.1.construct_dataset

def construct_dataset(x_all, y_all):"""Construct a MolDataset object from the provided data.Args:x_all (list): A list of molecule features.y_all (list): A list of the corresponding labels.Returns:A MolDataset object filled with the provided data."""output = [Molecule(data[0], data[1], i)for i, data in enumerate(zip(x_all, y_all))]return MolDataset(output)
  • 构造Molecule 对象列表,再构造 MolDataset 类,Molecule 对象接收 数据索引,x,y

1.2.2.Molecule

class Molecule:"""Class that represents a train/validation/test datum- self.label: 0 neg, 1 pos -1 missing for different target."""def __init__(self, x, y, index):self.node_features = x[0]self.adjacency_matrix = x[1]self.distance_matrix = x[2]self.y = yself.index = index
  • 将 x,y,index 数据整合在一起的类

1.2.3.MolDataset

class MolDataset(Dataset):"""Class that represents a train/validation/test dataset that's readable for PyTorchNote that this class inherits torch.utils.data.Dataset"""def __init__(self, data_list):"""@param data_list: list of Molecule objects"""self.data_list = data_listdef __len__(self):return len(self.data_list)def __getitem__(self, key):if type(key) == slice:return MolDataset(self.data_list[key])return self.data_list[key]
  • 继承 torch.utils.data.Dataset,需要实现这里列出的三个方法

1.2.4.mol_collate_func

def mol_collate_func(batch):"""Create a padded batch of molecule features.Args:batch (list[Molecule]): A batch of raw molecules.Returns:A list of FloatTensors with padded molecule features:adjacency matrices, node features, distance matrices, and labels."""adjacency_list, distance_list, features_list = [], [], []labels = []max_size = 0for molecule in batch:if type(molecule.y[0]) == np.ndarray:labels.append(molecule.y[0])else:labels.append(molecule.y)if molecule.adjacency_matrix.shape[0] > max_size:max_size = molecule.adjacency_matrix.shape[0]for molecule in batch:adjacency_list.append(pad_array(molecule.adjacency_matrix, (max_size, max_size)))distance_list.append(pad_array(molecule.distance_matrix, (max_size, max_size)))features_list.append(pad_array(molecule.node_features, (max_size, molecule.node_features.shape[1])))return [FloatTensor(features) for features in (adjacency_list, features_list, distance_list, labels)]
  • 第一个 for 循环得到 batch 中分子最多原子数和 labels 的列表,以 max_size 为基准 padding
  • 第二个 for 循环对 x 的三个数据矩阵 padding,pad_array 参数是数据矩阵和矩阵维度限定

1.2.5.pad_array

def pad_array(array, shape, dtype=np.float32):"""Pad a 2-dimensional array with zeros.Args:array (ndarray): A 2-dimensional array to be padded.shape (tuple[int]): The desired shape of the padded array.dtype (data-type): The desired data-type for the array.Returns:A 2-dimensional array of the given shape padded with zeros."""padded_array = np.zeros(shape, dtype=dtype)padded_array[:array.shape[0], :array.shape[1]] = arrayreturn padded_array
  • 在规定维度之外的部分补0

1.3.summary

  • 数据准备阶段输出 dataloader 对象,每次迭代一个 batch。经过了 mol_collate_func 的处理,不同 batch 原子数并没有统一,只有一个 batch 内原子数才恒定。data[0] 是邻接矩阵,data[1] 是 node_features,data[2] 是 距离矩阵
batch_size=2
cnt=1
for data in data_loader:print(data[0].shape)print(data[1].shape)print(data[2].shape)print(data[3].shape)cnt+=1if (cnt==3):break
"""
torch.Size([2, 13, 13])
torch.Size([2, 13, 28])
torch.Size([2, 13, 13])
torch.Size([2, 1])
torch.Size([2, 9, 9])
torch.Size([2, 9, 28])
torch.Size([2, 9, 9])
torch.Size([2, 1])
"""

这篇关于Molecule Attention Transformer(一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/602886

相关文章

什么是 Flash Attention

Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的, 论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。 下面我

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

Transformer从零详细解读

Transformer从零详细解读 一、从全局角度概况Transformer ​ 我们把TRM想象为一个黑盒,我们的任务是一个翻译任务,那么我们的输入是中文的“我爱你”,输入经过TRM得到的结果为英文的“I LOVE YOU” ​ 接下来我们对TRM进行细化,我们将TRM分为两个部分,分别为Encoders(编码器)和Decoders(解码器) ​ 在此基础上我们再进一步细化TRM的

LLM模型:代码讲解Transformer运行原理

视频讲解、获取源码:LLM模型:代码讲解Transformer运行原理(1)_哔哩哔哩_bilibili 1 训练保存模型文件 2 模型推理 3 推理代码 import torchimport tiktokenfrom wutenglan_model import WutenglanModelimport pyttsx3# 设置设备为CUDA(如果可用),否则使用CPU#

逐行讲解Transformer的代码实现和原理讲解:计算交叉熵损失

LLM模型:Transformer代码实现和原理讲解:前馈神经网络_哔哩哔哩_bilibili 1 计算交叉熵目的 计算 loss = F.cross_entropy(input=linear_predictions_reshaped, target=targets_reshaped) 的目的是为了评估模型预测结果与实际标签之间的差距,并提供一个量化指标,用于指导模型的训练过程。具体来说,交叉

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention

时序预测|变分模态分解-双向时域卷积-双向门控单元-注意力机制多变量时间序列预测VMD-BiTCN-BiGRU-Attention 文章目录 一、基本原理1. 变分模态分解(VMD)2. 双向时域卷积(BiTCN)3. 双向门控单元(BiGRU)4. 注意力机制(Attention)总结流程 二、实验结果三、核心代码四、代码获取五、总结 时序预测|变分模态分解-双向时域卷积

深度学习每周学习总结N9:transformer复现

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 目录 多头注意力机制前馈传播位置编码编码层解码层Transformer模型构建使用示例 本文为TR3学习打卡,为了保证记录顺序我这里写为N9 总结: 之前有学习过文本预处理的环节,对文本处理的主要方式有以下三种: 1:词袋模型(one-hot编码) 2:TF-I

RNN发展(RNN/LSTM/GRU/GNMT/transformer/RWKV)

RNN到GRU参考: https://blog.csdn.net/weixin_36378508/article/details/115101779 tRANSFORMERS参考: seq2seq到attention到transformer理解 GNMT 2016年9月 谷歌,基于神经网络的翻译系统(GNMT),并宣称GNMT在多个主要语言对的翻译中将翻译误差降低了55%-85%以上, G

ModuleNotFoundError: No module named ‘diffusers.models.dual_transformer_2d‘解决方法

Python应用运行报错,部分错误信息如下: Traceback (most recent call last): File “\pipelines_ootd\unet_vton_2d_blocks.py”, line 29, in from diffusers.models.dual_transformer_2d import DualTransformer2DModel ModuleNotF

阅读笔记--Guiding Attention in End-to-End Driving Models

作者:Diego Porres1, Yi Xiao1, Gabriel Villalonga1, Alexandre Levy1, Antonio M. L ́ opez1,2 出版时间:arXiv:2405.00242v1 [cs.CV] 30 Apr 2024 这篇论文研究了如何引导基于视觉的端到端自动驾驶模型的注意力,以提高它们的驾驶质量和获得更直观的激活图。 摘 要   介绍