PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134709211

IMG

使用 ESMFold 推理长序列 (Seq. Len. > 1500) 时,导致显存不足,需要设置 chunk_size 参数,实现长序列蛋白质的结构预测,避免显存溢出。

ESMFold:https://github.com/facebookresearch/esm

测试 ESM 单条 Case,序列长度 1543 较长,即:

python -u myscripts/esmfold_infer.py \
-f fasta_446/7WY5_R1543.fasta \
-o mydata/test_gpcr/

A100 显存溢出:

Tried to allocate 54.74 GiB (GPU 0; 79.32 GiB total capacity; 73.53 GiB already allocated; 3.94 GiB free; 74.24 GiB reserved in total by PyTorch)

解决显存问题,参考:Out of memory - upper limit on sequence length?

关键参数:chunk-size

Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Equivalent to running a for loop over chunks of of each dimension. Lower values will result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. Default: None.

将轴向注意力计算分块 (Chunks) ,将内存使用量从 O(L^2) 减少到 O(L)。 相当于在每个维度的块上运行 for 循环。 较低的值将导致内存使用量降低,但代价是速度。 建议值:128、64、32。默认值:无。

关键参数:max-tokens-per-batch,即 max_tokens_per_batch

Maximum number of tokens per gpu forward-pass. This will group shorter sequences together for batched prediction. Lowering this can help with out of memory issues, if these occur on short sequences.

每个 GPU 前向传递的最大令牌数。 这会将较短的序列分组在一起以进行批量预测。 如果内存不足问题发生在短序列上,降低此值可以帮助解决这些问题。

chunk-size 设置成 128,问题解决,即:

max_len = 1200
# A100 最多支持 1200 长度的序列
if len(seq) > max_len:
    chunk_size = 128
    print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")
    self.model.set_chunk_size(chunk_size)
else:
    self.model.set_chunk_size(None)
    
with torch.no_grad():
    output = self.model.infer_pdb(seq)

推理脚本:

扫描二维码关注公众号,回复: 17237949 查看本文章
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/7/5
"""
import argparse
import os
import sys
import time
from pathlib import Path

import torch
from tqdm import tqdm

import esm

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)


from myutils.protein_utils import get_seq_from_fasta
from myutils.project_utils import time_elapsed, mkdir_if_not_exist, traverse_dir_files


class EsmfoldInfer(object):
    """
    ESMFold的推理类
    """
    def __init__(self):
        print("[Info] 开始加载 ESMFold 模型!")
        s_time = time.time()
        model = esm.pretrained.esmfold_v1()
        self.model = model.eval().cuda()
        print(f"[Info] vocab: {self.model.esm_dict.to_dict()}")
        # 耗时: 00:01:13.264272
        print(f"[Info] 完成加载 ESMFold 模型! 耗时: {time_elapsed(s_time, time.time())}")

    def predict_seq(self, seq, out_path, is_log=True):
        """
        预测序列
        """
        print(f"[Info] seq_len: {len(seq)}")
        max_len = 1200
        # A100 最多支持 1200 长度的序列
        if len(seq) > max_len:
            chunk_size = 128
            print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")
            self.model.set_chunk_size(chunk_size)
        else:
    		self.model.set_chunk_size(None)

        s_time = time.time()
        with torch.no_grad():
            output = self.model.infer_pdb(seq)
        seq_len = len(seq)
        if is_log:
            print(f"[Info] 完成推理,链长 {seq_len}, 耗时: {time_elapsed(s_time, time.time())}, "
                  f"平均序列耗时: {(time.time() - s_time) / seq_len}")
        with open(out_path, "w") as f:
            f.write(output)
        if is_log:
            print(f"[Info] 输出: {output}")

    def predict_fasta_dir(self, input_path, output_dir):
        """
        预测 FASTA 文件夹
        """
        print(f"[Info] input_path: {input_path}")
        print(f"[Info] output_dir: {output_dir}")
        assert os.path.isfile(input_path) or os.path.isdir(input_path)
        mkdir_if_not_exist(output_dir)

        if os.path.isdir(input_path):
            path_list = traverse_dir_files(input_path, ext="fasta")
        elif os.path.isfile(input_path):
            path_list = [input_path]
        else:
            raise Exception(f"Error input: {input_path}")

        print(f"[Info] Fasta 数量: {len(path_list)}")
        s_time = time.time()
        for path in tqdm(path_list, desc="[Info] fasta"):
            fasta_name = os.path.basename(path).split(".")[0]
            output_fasta_dir = os.path.join(output_dir, fasta_name)
            mkdir_if_not_exist(output_fasta_dir)

            pdb_name = os.path.basename(path).replace("fasta", "pdb")
            output_pdb_path = os.path.join(output_fasta_dir, pdb_name)

            if os.path.exists(output_pdb_path):
                print(f"[Info] 已预测完成: {output_pdb_path}")
                continue
            seqs, _ = get_seq_from_fasta(path)
            seq = seqs[0]
            self.predict_seq(seq, output_pdb_path, is_log=False)
        print(f"[Info] 全部运行完成: {output_dir}, 耗时: {time_elapsed(s_time, time.time())}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--fasta-input",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=Path,
        required=True
    )
    args = parser.parse_args()

    fasta_input = str(args.fasta_input)
    output_dir = str(args.output_dir)
    mkdir_if_not_exist(output_dir)

    ei = EsmfoldInfer()
    ei.predict_fasta_dir(fasta_input, output_dir)


if __name__ == '__main__':
    main()

猜你喜欢

转载自blog.csdn.net/u012515223/article/details/134709211
psp
今日推荐