PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题

2023-11-30 13:36

本文主要是介绍PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134709211

IMG

使用 ESMFold 推理长序列 (Seq. Len. > 1500) 时,导致显存不足,需要设置 chunk_size 参数,实现长序列蛋白质的结构预测,避免显存溢出。

ESMFold:https://github.com/facebookresearch/esm

测试 ESM 单条 Case,序列长度 1543 较长,即:

python -u myscripts/esmfold_infer.py \
-f fasta_446/7WY5_R1543.fasta \
-o mydata/test_gpcr/

A100 显存溢出:

Tried to allocate 54.74 GiB (GPU 0; 79.32 GiB total capacity; 73.53 GiB already allocated; 3.94 GiB free; 74.24 GiB reserved in total by PyTorch)

解决显存问题,参考:Out of memory - upper limit on sequence length?

关键参数:chunk-size

Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Equivalent to running a for loop over chunks of of each dimension. Lower values will result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. Default: None.将轴向注意力计算分块 (Chunks) ,将内存使用量从 O(L^2) 减少到 O(L)。 相当于在每个维度的块上运行 for 循环。 较低的值将导致内存使用量降低,但代价是速度。 建议值:128、64、32。默认值:无。

关键参数:max-tokens-per-batch,即 max_tokens_per_batch

Maximum number of tokens per gpu forward-pass. This will group shorter sequences together for batched prediction. Lowering this can help with out of memory issues, if these occur on short sequences.每个 GPU 前向传递的最大令牌数。 这会将较短的序列分组在一起以进行批量预测。 如果内存不足问题发生在短序列上,降低此值可以帮助解决这些问题。

chunk-size 设置成 128,问题解决,即:

max_len = 1200
# A100 最多支持 1200 长度的序列
if len(seq) > max_len:chunk_size = 128print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")self.model.set_chunk_size(chunk_size)
else:self.model.set_chunk_size(None)with torch.no_grad():output = self.model.infer_pdb(seq)

推理脚本:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/7/5
"""
import argparse
import os
import sys
import time
from pathlib import Pathimport torch
from tqdm import tqdmimport esmp = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:sys.path.append(p)from myutils.protein_utils import get_seq_from_fasta
from myutils.project_utils import time_elapsed, mkdir_if_not_exist, traverse_dir_filesclass EsmfoldInfer(object):"""ESMFold的推理类"""def __init__(self):print("[Info] 开始加载 ESMFold 模型!")s_time = time.time()model = esm.pretrained.esmfold_v1()self.model = model.eval().cuda()print(f"[Info] vocab: {self.model.esm_dict.to_dict()}")# 耗时: 00:01:13.264272print(f"[Info] 完成加载 ESMFold 模型! 耗时: {time_elapsed(s_time, time.time())}")def predict_seq(self, seq, out_path, is_log=True):"""预测序列"""print(f"[Info] seq_len: {len(seq)}")max_len = 1200# A100 最多支持 1200 长度的序列if len(seq) > max_len:chunk_size = 128print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")self.model.set_chunk_size(chunk_size)else:self.model.set_chunk_size(None)s_time = time.time()with torch.no_grad():output = self.model.infer_pdb(seq)seq_len = len(seq)if is_log:print(f"[Info] 完成推理,链长 {seq_len}, 耗时: {time_elapsed(s_time, time.time())}, "f"平均序列耗时: {(time.time() - s_time) / seq_len}")with open(out_path, "w") as f:f.write(output)if is_log:print(f"[Info] 输出: {output}")def predict_fasta_dir(self, input_path, output_dir):"""预测 FASTA 文件夹"""print(f"[Info] input_path: {input_path}")print(f"[Info] output_dir: {output_dir}")assert os.path.isfile(input_path) or os.path.isdir(input_path)mkdir_if_not_exist(output_dir)if os.path.isdir(input_path):path_list = traverse_dir_files(input_path, ext="fasta")elif os.path.isfile(input_path):path_list = [input_path]else:raise Exception(f"Error input: {input_path}")print(f"[Info] Fasta 数量: {len(path_list)}")s_time = time.time()for path in tqdm(path_list, desc="[Info] fasta"):fasta_name = os.path.basename(path).split(".")[0]output_fasta_dir = os.path.join(output_dir, fasta_name)mkdir_if_not_exist(output_fasta_dir)pdb_name = os.path.basename(path).replace("fasta", "pdb")output_pdb_path = os.path.join(output_fasta_dir, pdb_name)if os.path.exists(output_pdb_path):print(f"[Info] 已预测完成: {output_pdb_path}")continueseqs, _ = get_seq_from_fasta(path)seq = seqs[0]self.predict_seq(seq, output_pdb_path, is_log=False)print(f"[Info] 全部运行完成: {output_dir}, 耗时: {time_elapsed(s_time, time.time())}")def main():parser = argparse.ArgumentParser()parser.add_argument("-f","--fasta-input",type=Path,required=True,)parser.add_argument("-o","--output-dir",type=Path,required=True)args = parser.parse_args()fasta_input = str(args.fasta_input)output_dir = str(args.output_dir)mkdir_if_not_exist(output_dir)ei = EsmfoldInfer()ei.predict_fasta_dir(fasta_input, output_dir)if __name__ == '__main__':main()

这篇关于PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/437222

相关文章

Java中读取YAML文件配置信息常见问题及解决方法

《Java中读取YAML文件配置信息常见问题及解决方法》:本文主要介绍Java中读取YAML文件配置信息常见问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 目录1 使用Spring Boot的@ConfigurationProperties2. 使用@Valu

SQL Server配置管理器无法打开的四种解决方法

《SQLServer配置管理器无法打开的四种解决方法》本文总结了SQLServer配置管理器无法打开的四种解决方法,文中通过图文示例介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录方法一:桌面图标进入方法二:运行窗口进入检查版本号对照表php方法三:查找文件路径方法四:检查 S

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.

Java 线程安全与 volatile与单例模式问题及解决方案

《Java线程安全与volatile与单例模式问题及解决方案》文章主要讲解线程安全问题的五个成因(调度随机、变量修改、非原子操作、内存可见性、指令重排序)及解决方案,强调使用volatile关键字... 目录什么是线程安全线程安全问题的产生与解决方案线程的调度是随机的多个线程对同一个变量进行修改线程的修改操

Redis出现中文乱码的问题及解决

《Redis出现中文乱码的问题及解决》:本文主要介绍Redis出现中文乱码的问题及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1. 问题的产生2China编程. 问题的解决redihttp://www.chinasem.cns数据进制问题的解决中文乱码问题解决总结

全面解析MySQL索引长度限制问题与解决方案

《全面解析MySQL索引长度限制问题与解决方案》MySQL对索引长度设限是为了保持高效的数据检索性能,这个限制不是MySQL的缺陷,而是数据库设计中的权衡结果,下面我们就来看看如何解决这一问题吧... 目录引言:为什么会有索引键长度问题?一、问题根源深度解析mysql索引长度限制原理实际场景示例二、五大解决

Springboot如何正确使用AOP问题

《Springboot如何正确使用AOP问题》:本文主要介绍Springboot如何正确使用AOP问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录​一、AOP概念二、切点表达式​execution表达式案例三、AOP通知四、springboot中使用AOP导出

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题

《解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题》:本文主要介绍解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4... 目录未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘打开pom.XM

XML重复查询一条Sql语句的解决方法

《XML重复查询一条Sql语句的解决方法》文章分析了XML重复查询与日志失效问题,指出因DTO缺少@Data注解导致日志无法格式化、空指针风险及参数穿透,进而引发性能灾难,解决方案为在Controll... 目录一、核心问题:从SQL重复执行到日志失效二、根因剖析:DTO断裂引发的级联故障三、解决方案:修复