PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题

2023-11-30 13:36

本文主要是介绍PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134709211

IMG

使用 ESMFold 推理长序列 (Seq. Len. > 1500) 时,导致显存不足,需要设置 chunk_size 参数,实现长序列蛋白质的结构预测,避免显存溢出。

ESMFold:https://github.com/facebookresearch/esm

测试 ESM 单条 Case,序列长度 1543 较长,即:

python -u myscripts/esmfold_infer.py \
-f fasta_446/7WY5_R1543.fasta \
-o mydata/test_gpcr/

A100 显存溢出:

Tried to allocate 54.74 GiB (GPU 0; 79.32 GiB total capacity; 73.53 GiB already allocated; 3.94 GiB free; 74.24 GiB reserved in total by PyTorch)

解决显存问题,参考:Out of memory - upper limit on sequence length?

关键参数:chunk-size

Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Equivalent to running a for loop over chunks of of each dimension. Lower values will result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. Default: None.将轴向注意力计算分块 (Chunks) ,将内存使用量从 O(L^2) 减少到 O(L)。 相当于在每个维度的块上运行 for 循环。 较低的值将导致内存使用量降低,但代价是速度。 建议值:128、64、32。默认值:无。

关键参数:max-tokens-per-batch,即 max_tokens_per_batch

Maximum number of tokens per gpu forward-pass. This will group shorter sequences together for batched prediction. Lowering this can help with out of memory issues, if these occur on short sequences.每个 GPU 前向传递的最大令牌数。 这会将较短的序列分组在一起以进行批量预测。 如果内存不足问题发生在短序列上,降低此值可以帮助解决这些问题。

chunk-size 设置成 128,问题解决,即:

max_len = 1200
# A100 最多支持 1200 长度的序列
if len(seq) > max_len:chunk_size = 128print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")self.model.set_chunk_size(chunk_size)
else:self.model.set_chunk_size(None)with torch.no_grad():output = self.model.infer_pdb(seq)

推理脚本:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/7/5
"""
import argparse
import os
import sys
import time
from pathlib import Pathimport torch
from tqdm import tqdmimport esmp = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:sys.path.append(p)from myutils.protein_utils import get_seq_from_fasta
from myutils.project_utils import time_elapsed, mkdir_if_not_exist, traverse_dir_filesclass EsmfoldInfer(object):"""ESMFold的推理类"""def __init__(self):print("[Info] 开始加载 ESMFold 模型!")s_time = time.time()model = esm.pretrained.esmfold_v1()self.model = model.eval().cuda()print(f"[Info] vocab: {self.model.esm_dict.to_dict()}")# 耗时: 00:01:13.264272print(f"[Info] 完成加载 ESMFold 模型! 耗时: {time_elapsed(s_time, time.time())}")def predict_seq(self, seq, out_path, is_log=True):"""预测序列"""print(f"[Info] seq_len: {len(seq)}")max_len = 1200# A100 最多支持 1200 长度的序列if len(seq) > max_len:chunk_size = 128print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")self.model.set_chunk_size(chunk_size)else:self.model.set_chunk_size(None)s_time = time.time()with torch.no_grad():output = self.model.infer_pdb(seq)seq_len = len(seq)if is_log:print(f"[Info] 完成推理,链长 {seq_len}, 耗时: {time_elapsed(s_time, time.time())}, "f"平均序列耗时: {(time.time() - s_time) / seq_len}")with open(out_path, "w") as f:f.write(output)if is_log:print(f"[Info] 输出: {output}")def predict_fasta_dir(self, input_path, output_dir):"""预测 FASTA 文件夹"""print(f"[Info] input_path: {input_path}")print(f"[Info] output_dir: {output_dir}")assert os.path.isfile(input_path) or os.path.isdir(input_path)mkdir_if_not_exist(output_dir)if os.path.isdir(input_path):path_list = traverse_dir_files(input_path, ext="fasta")elif os.path.isfile(input_path):path_list = [input_path]else:raise Exception(f"Error input: {input_path}")print(f"[Info] Fasta 数量: {len(path_list)}")s_time = time.time()for path in tqdm(path_list, desc="[Info] fasta"):fasta_name = os.path.basename(path).split(".")[0]output_fasta_dir = os.path.join(output_dir, fasta_name)mkdir_if_not_exist(output_fasta_dir)pdb_name = os.path.basename(path).replace("fasta", "pdb")output_pdb_path = os.path.join(output_fasta_dir, pdb_name)if os.path.exists(output_pdb_path):print(f"[Info] 已预测完成: {output_pdb_path}")continueseqs, _ = get_seq_from_fasta(path)seq = seqs[0]self.predict_seq(seq, output_pdb_path, is_log=False)print(f"[Info] 全部运行完成: {output_dir}, 耗时: {time_elapsed(s_time, time.time())}")def main():parser = argparse.ArgumentParser()parser.add_argument("-f","--fasta-input",type=Path,required=True,)parser.add_argument("-o","--output-dir",type=Path,required=True)args = parser.parse_args()fasta_input = str(args.fasta_input)output_dir = str(args.output_dir)mkdir_if_not_exist(output_dir)ei = EsmfoldInfer()ei.predict_fasta_dir(fasta_input, output_dir)if __name__ == '__main__':main()

这篇关于PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/437222

相关文章

好题——hdu2522(小数问题:求1/n的第一个循环节)

好喜欢这题,第一次做小数问题,一开始真心没思路,然后参考了网上的一些资料。 知识点***********************************无限不循环小数即无理数,不能写作两整数之比*****************************(一开始没想到,小学没学好) 此题1/n肯定是一个有限循环小数,了解这些后就能做此题了。 按照除法的机制,用一个函数表示出来就可以了,代码如下

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

usaco 1.3 Mixing Milk (结构体排序 qsort) and hdu 2020(sort)

到了这题学会了结构体排序 于是回去修改了 1.2 milking cows 的算法~ 结构体排序核心: 1.结构体定义 struct Milk{int price;int milks;}milk[5000]; 2.自定义的比较函数,若返回值为正,qsort 函数判定a>b ;为负,a<b;为0,a==b; int milkcmp(const void *va,c

如何解决线上平台抽佣高 线下门店客流少的痛点!

目前,许多传统零售店铺正遭遇客源下降的难题。尽管广告推广能带来一定的客流,但其费用昂贵。鉴于此,众多零售商纷纷选择加入像美团、饿了么和抖音这样的大型在线平台,但这些平台的高佣金率导致了利润的大幅缩水。在这样的市场环境下,商家之间的合作网络逐渐成为一种有效的解决方案,通过资源和客户基础的共享,实现共同的利益增长。 以最近在上海兴起的一个跨行业合作平台为例,该平台融合了环保消费积分系统,在短

购买磨轮平衡机时应该注意什么问题和技巧

在购买磨轮平衡机时,您应该注意以下几个关键点: 平衡精度 平衡精度是衡量平衡机性能的核心指标,直接影响到不平衡量的检测与校准的准确性,从而决定磨轮的振动和噪声水平。高精度的平衡机能显著减少振动和噪声,提高磨削加工的精度。 转速范围 宽广的转速范围意味着平衡机能够处理更多种类的磨轮,适应不同的工作条件和规格要求。 振动监测能力 振动监测能力是评估平衡机性能的重要因素。通过传感器实时监

缓存雪崩问题

缓存雪崩是缓存中大量key失效后当高并发到来时导致大量请求到数据库,瞬间耗尽数据库资源,导致数据库无法使用。 解决方案: 1、使用锁进行控制 2、对同一类型信息的key设置不同的过期时间 3、缓存预热 1. 什么是缓存雪崩 缓存雪崩是指在短时间内,大量缓存数据同时失效,导致所有请求直接涌向数据库,瞬间增加数据库的负载压力,可能导致数据库性能下降甚至崩溃。这种情况往往发生在缓存中大量 k

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P

uva 10131 最长子序列

题意: 给大象的体重和智商,求体重按从大到小,智商从高到低的最长子序列,并输出路径。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#include <stack>#include <vect

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

自定义类型:结构体(续)

目录 一. 结构体的内存对齐 1.1 为什么存在内存对齐? 1.2 修改默认对齐数 二. 结构体传参 三. 结构体实现位段 一. 结构体的内存对齐 在前面的文章里我们已经讲过一部分的内存对齐的知识,并举出了两个例子,我们再举出两个例子继续说明: struct S3{double a;int b;char c;};int mian(){printf("%zd\n",s