PSP - タンパク質配列抽出トランスフォーマータンパク質言語モデル ESM2 の機能

私の CSDN をフォローしてください: https://spike.blog.csdn.net/
この記事のアドレス: https://spike.blog.csdn.net/article/details/132888139

契約

タンパク質言語モデル ESM (Eevolutionary Scale Modeling) は、深層学習技術を使用してタンパク質の構造と機能を予測する手法です。ESM は、大規模なタンパク質配列データベース上で自己回帰ニューラル ネットワークをトレーニングすることにより、タンパク質の進化規則と配列-構造-機能の関係を学習します。ESM は、特定のタンパク質配列に基づいて、その構造と機能の特性を表す対応する潜在ベクトルを生成でき、また、潜在ベクトルを使用して、構造予測、機能アノテーション、相互作用解析などのさまざまな下流タスクを実行することもできます。ESM は、タンパク質科学に新しい視点とツールを提供する、強力で多用途なタンパク質言語モデルです。

ESM (進化的スケール モデリング)、つまり ESM-2、ESMFold、ESM-MSA-1b、ESM-1v、ESM-IF1 (逆折り畳み) を含む進化的スケール モデル、つまり

  • ESM-2、2022.8、SOTA 汎用プロテイン言語モデル v2 バージョン、ここで はESM-1vv1 バージョンです。
  • ESMFold、2022.11、エンドツーエンドの単一シーケンス 3D 構造予測
  • ESM-MSA-1b、2021.6、MSA トランスフォーマー言語モデル
  • ESM-IF1、2022.4、逆折モデル

具体的な参照先: ESM GitHub

1.Docker環境の構成

構成TORCH_HOMEBOS環境、つまり:

vim ~/.bashrc

export TORCH_HOME=[your folder]/torch_home/
alias bos='bcecmd/bcecmd --conf-path bcecmd/bceconf/ bos'

TORCH_HOME を構成し、PyTorch モデルのキャッシュ アドレス (つまり ) を修正することをお勧めしますtorch_home/hub/checkpoints

Docker イメージで、ESM 環境をインポートします。

conda create -n esmfold --clone miniconda3/envs/esmfold

インストールする必要があるトーチ関連パッケージ:

pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
pip install -q torch-cluster -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
pip install -q torch-geometric

Docker 環境をエクスポートします。

# 提交 Tag
docker ps -a
docker commit [container id] esmfold:v1.0

# 准备远程 Tag
docker tag esmfold:v1.0 [your ip]/esmfold:v1.0

# 推送至远程
docker push [your ip]/esmfold:v1.0
# 从远程拉取
# docker pull [your ip]/esmfold:v1.0

2. ESM2モデルのバッチ推論

ESM 推論スクリプトを構成します。

set -xe
PROJECT_DIR="$(cd "$(dirname $0)" && pwd)/.."

source activate esmfold
export PATH="/usr/local/cuda-11.6/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-11.6/lib64:$LD_LIBRARY_PATH"
export TORCH_HOME=[your folder]/torch_home/

echo "${PROJECT_DIR}"

python "${PROJECT_DIR}/scripts/extract.py" esm2_t36_3B_UR50D \
  "${PROJECT_DIR}/mydata/all-1.fasta" \
  [your folder]/esm2_3B_feat/ \
  --toks_per_batch 1536 \
  --repr_layers -1 \
  --include per_tok contacts \
  --truncation_seq_length 1536 \
  --num_workers 8

テスト済みの A100 グラフィックス カード 80G は、最大シーケンス長 1536 をサポートします。

scripts/extract.pyスクリプトを最適化します。出力結果は、長すぎるシーケンスや名前の繰り返しを避けるためのシーケンス MD5 エンコードの機能です。

  1. 増やすnum_workersと推論速度が向上します。
  2. labelタンパク質配列に置き換えます。
  3. 繰り返しの検索を避けるためにブレークポイント処理を追加する

今すぐ

# ...
data_loader = DataLoader(
    dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
    batch_sampler=batches, num_workers=args.num_workers,
)
# ...
# result = {"label": label}
result = {
    
    "label": strs[i]}  # label 修改成序列
# ...
for i, label in enumerate(labels):
    args.output_file = args.output_dir / f"{
      
      label}.pt"
    if os.path.isfile(args.output_file):
        warnings.warn(f"The feat has processed. {
      
      args.output_file}")
        continue
# ...

使用できないことに注意してくださいnum_workers。使用しないとプログラムが実行されません。

ログ:

python workspace/esm-by-chenlong/run/../scripts/extract.py esm2_t36_3B_UR50D workspace/esm-by-chenlong/mydata/all-1.fasta pdb_dataset/esm2_6b_feat/ --toks_per_batch 1536 --repr_layers -1 --include per_tok contacts --truncation_seq_length 1536 --num_workers 32
Transferred model to GPU
Read /nfs_beijing_ai/chenlong/workspace/esm-by-chenlong/run/../mydata/all-1.fasta with 27115 sequences
Processing 1 of 6668 batches (66 sequences)
Processing 2 of 6668 batches (61 sequences)
Processing 3 of 6668 batches (56 sequences)
Processing 4 of 6668 batches (52 sequences)
Processing 5 of 6668 batches (51 sequences)

シーケンス サイズ 2048 ではビデオ メモリがオーバーフローすることに注意してください。

3. FASTA データを入力するために ESM2 を準備する

FASTA フォルダー内のすべての FASTA ファイルは 1 つのファイルに結合され、同じシーケンスに対する特徴の繰り返し生成を避けるために、シーケンスの説明がハッシュ エンコードに変換されます。

  • seq_encoder: ハッシュエンコード機能。検索にも使用されます。
  • load_feat: フィーチャーの特性を読み取り、データの表示をサポートし、イメージを描画します。
  • merge_fasta_folder: FASTA フォルダーを結合します。

今すぐ:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/9/13
"""
import argparse
import os
import sys
import warnings
from pathlib import Path

from tqdm import tqdm

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)

from myutils.project_utils import traverse_dir_files_for_large, read_file, write_list_to_file


class Esm2FastaGenerator(object):
    """
    ESM2 工具类
    """
    def __init__(self):
        pass

    @staticmethod
    def seq_encoder(sequence):
        """
        将 seq 使用 hash 编码,避免重复生成
        """
        import hashlib
        return hashlib.md5(sequence.encode(encoding="utf-8")).hexdigest()

    @staticmethod
    def load_feat(path, is_print=False):
        """
        加载 ESM 特征文件,以及打印特征
        """
        import torch
        from torch import Tensor
        rep = torch.load(path)
        if is_print:
            print(f"[Info] rep: {
      
      rep.keys()}")
            for key in rep.keys():
                val = rep[key]
                if isinstance(val, str):
                    print(f"[Info] {
      
      key}: {
      
      val}")
                elif isinstance(val, dict):
                    for sub_key in val.keys():
                        print(f"[Info] {
      
      key}: {
      
      sub_key}: {
      
      val[sub_key].shape}")
                elif isinstance(val, Tensor):
                    print(f"[Info] {
      
      key}: {
      
      val.shape}")
                else:
                    print(f"[Info] {
      
      key}: {
      
      val}")

            # 绘制接触矩阵
            import matplotlib.pyplot as plt
            contacts_map = rep["contacts"]
            plt.matshow(contacts_map)
            plt.title("contacts_map")
            save_name = "contacts_map.png"
            plt.savefig(save_name, bbox_inches='tight', format='png')
            plt.show()
        return rep

    @classmethod
    def merge_fasta_folder(cls, folder_path, output_path):
        """
        合并 fasta 文件,用于 esm 推理
        """
        print(f"[Info] folder_path: {
      
      folder_path}")
        print(f"[Info] output_path: {
      
      output_path}")
        assert os.path.isdir(folder_path)
        path_list = traverse_dir_files_for_large(folder_path, ext="fasta")
        print(f"[Info] fasta: {
      
      len(path_list)}")
        seq_set = set()
        for path in tqdm(path_list, "[Info] fasta"):
            data_lines = read_file(path)
            n = len(data_lines)
            for i in range(1, n, 2):
                seq = data_lines[i]
                if seq:
                    seq_set.add(seq)
        seq_list = list(seq_set)
        print(f"[Info] seq unique: {
      
      len(seq_list)}")
        # create_empty_file(output_path)
        seq_lines = []
        header_set = set()
        for seq in tqdm(seq_list, "[Info] seq"):
            header = cls.seq_encoder(seq)
            header_set.add(header)
            seq_lines.append(f">{
      
      header}")
            seq_lines.append(seq)
        assert len(seq_lines) // 2 == len(header_set)
        write_list_to_file(output_path, seq_lines)
        print(f"[Info] over! {
      
      output_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--folder-path",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=Path,
        required=True
    )
    args = parser.parse_args()

    folder_path = str(args.folder_path)
    output_path = str(args.output_path)
    if os.path.isfile(output_path):
        warnings.warn(f"The output file exists, append lines to it! {
      
      output_path}")
    # from root_dir import DATA_DIR
    # folder_path = os.path.join(DATA_DIR, "fasta")
    # output_path = os.path.join(DATA_DIR, "all.fasta")
    Esm2FastaGenerator.merge_fasta_folder(folder_path, output_path)


def main2():
    from root_dir import DATA_DIR
    feat_path = os.path.join(DATA_DIR, "fffd26f4307d76eec938ac9c2c93a698.pt")
    Esm2FastaGenerator.load_feat(feat_path, is_print=True)


if __name__ == '__main__':
    main()
    # main2()

出力シーケンス ESM2 の機能は次のとおりです。

  • labelシーケンスの説明
  • representations配列特性評価 LxH
  • mean_representations平均表現 H
  • bos_representations開始トークンは H を表します
  • contactsシーケンス接点の特性評価 LxL

たとえば、シーケンスの長さは 65、ESM2 650M のエンベディング数は 1280、ESM2 3B のエンベディング数は 2560 です。つまり、次のようになります。

[Info] rep: dict_keys(['label', 'representations', 'contacts'])
[Info] label: MAKDSKAPVVEIFDERDGCTSAGSTGKASDAGEKGLLVKVSMQKVGYNAIMAKSVAASYMNK
[Info] representations: 36: torch.Size([62, 2560])
[Info] contacts: torch.Size([62, 62])

このうち、シーケンス長 235 の ESM2 3B フィーチャは約 2.6M、シーケンス長 65 の ESM2 650M フィーチャは約 361 KB です。

4. ESM2 推論スクリプトをテストする

推理脚本:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/9/11
"""
import math
import os
import sys
import time

import torch

import esm

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)

from myutils.project_utils import time_elapsed


class Esm2Infer(object):
    """
    推理 ESM2 特征
    """
    def __init__(self):
        print("[Info] 加载模型开始! ")
        s_time = time.time()
        self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        print(f"[Info] vocab: {
      
      self.alphabet.to_dict()}")
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.eval()  # disables dropout for deterministic results
        print(f"[Info] 加载模型完成! 耗时: {
      
      time_elapsed(s_time, time.time())}")

    def predict(self, data_list):
        """
        数据示例:
        data_list = [
            ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
            ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
            ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
            ("protein3",  "K A <mask> I S Q"),
        ]
        """
        print(f"[Info] data_list: {
      
      len(data_list)}")
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data_list)
        print(f"[Info] batch_labels: {
      
      batch_labels}")
        print(f"[Info] batch_tokens: {
      
      batch_tokens}")
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)
        print(f"[Info] batch_lens: {
      
      batch_lens}")  # 有效维数

        # Extract per-residue representations (on CPU)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[33], return_contacts=True)
        token_representations = results["representations"][33]

        # Generate per-sequence representations via averaging
        # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
        sequence_representations = []
        for i, tokens_len in enumerate(batch_lens):
            feat = token_representations[i, 1: tokens_len - 1]
            # embeddings = feat.mean(0)
            # print(f"[Info] idx: {i}, feat: {feat.shape}, embeddings: {embeddings.shape}")
            # sequence_representations.append(embeddings)
            sequence_representations.append(feat)
        return sequence_representations


def main():
    data_list = [
        ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
        ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
        ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
        ("protein3", "K A <mask> I S Q"),
    ]
    ei = Esm2Infer()
    ei.predict(data_list)


if __name__ == '__main__':
    main()

おすすめ

転載: blog.csdn.net/u012515223/article/details/132888139