PSP - 蛋白质结构预测 OpenFold Multimer 重构训练模型的数据加载

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132602155

PDB

OpenFold Multimer 在训练过程中加载数据时,每个样本都需要实时将 MSA 与 Template 信息转换成 Feature,再送入模型训练,因此数据加载速度较慢。通过修改数据集类 OpenFoldSingleMultimerDataset 的 __getitem__ 方法,直接读取预先计算好的 Feature,可以加速训练过程。


1. 准备训练数据

在训练过程中,需要读取 mmcif_cache.json 文件,数据结构如下:

{
    "4ewn": {
        "release_date": "2012-12-05",
        "chain_ids": [
            "D"
        ],
        "seqs": [
            "MLAKRI..."
        ],
        "no_chains": 1,
        "resolution": 1.9
    },
    "5m9r": {
        "release_date": "2017-02-22",
        "chain_ids": [
            "A",
            "B"
        ],
        "seqs": [
            "MQDNS...",
            "MQDNS..."
        ],
        "no_chains": 2,
        "resolution": 1.44
    },
#...
}

当前的训练数据格式,例如 train_200_mini.csv,如下:

pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
7m5z,"A,B",3.06,2021-10-06,"LEDVV...,QNKLE...","263,264","protein,protein",[pdb_path]/structures/m5/pdb7m5z.ent.gz
7k05,"A,B",1.85,2021-10-06,"MSFPP...,MSFPP...","200,200","protein,protein",[pdb_path]/structures/k0/pdb7k05.ent.gz
# ...

同时,需要将每个样本预处理后的特征文件夹路径(feature_folder)写入训练所用的 mmcif_cache.json 中,训练时直接读取预先抽取好的特征。特征根目录如下:

[your folder]/multimer_train/features

每个样本对应特征根目录下的一个子文件夹,路径为 <特征根目录>/<pdb_id[1:3]>/pdb<pdb_id>_<链ID拼接>(例如 7m5z 的 A、B 链对应 m5/pdb7m5z_AB)。训练时直接使用其中已经预处理好的 features.pkl 即可:

# 单个文件夹内容
chain_id_map.json
features.pkl
sequences.fasta

训练文件的转换命令,如下:

# Convert the training CSV into an mmcif_cache JSON that also records a
# feature_folder per PDB id (-i: training CSV, -f: root of the pre-computed
# features, -o: output JSON path).
python openfold_scripts/main_mmcif_cache_transfer.py -i data/train_200_mini.csv -f [your folder]/multimer_train/features -o mydata/openfold/mmcif_cache_mini.json

源码如下:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/8/31
"""
import argparse
import json
import os
import sys
from pathlib import Path

import pandas as pd
from tqdm import tqdm

# Append the parent of this script's directory (presumably the project root)
# to sys.path, presumably so project-local modules can be imported when the
# script is run directly.
p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)


class MmcifCacheTransfer(object):
    """
    Convert a training CSV into OpenFold's mmcif_cache.json format.

    Besides the standard mmcif_cache fields, every entry records a
    ``feature_folder`` that points at the pre-computed features of that
    PDB entry, so the dataset can load ``features.pkl`` directly.
    """

    def __init__(self):
        pass

    @staticmethod
    def _build_feature_folder(feature_dir, pdb_id, chain_ids):
        """Return ``<feature_dir>/<pdb_id[1:3]>/pdb<pdb_id>_<joined chain ids>``."""
        return os.path.join(
            feature_dir, pdb_id[1:3], f"pdb{pdb_id}_{''.join(chain_ids)}"
        )

    @staticmethod
    def process(input_path, feature_dir, output_path):
        """
        Build the mmcif cache from the CSV at ``input_path`` and write it as JSON.

        Every CSV row becomes one entry keyed by ``pdb_id`` with the fields
        release_date, chain_ids, seqs, no_chains, resolution and
        feature_folder. ``chain_id`` and ``seq`` hold comma-separated
        per-chain values. A later row with the same ``pdb_id`` overwrites an
        earlier one.

        :param input_path: training CSV file.
        :param feature_dir: root directory of the pre-computed features.
        :param output_path: destination JSON file; missing parent directories
            are created.
        :raises FileNotFoundError: if ``input_path`` is not an existing file.
        """
        print(f"[Info] 输入文件: {input_path}")
        print(f"[Info] 特征文件夹: {feature_dir}")
        print(f"[Info] 输出文件: {output_path}")
        # Explicit check instead of `assert`, which `python -O` strips.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"[Error] input file not found: {input_path}")
        # Read the id/sequence columns as strings so pandas cannot coerce a
        # numeric-looking column (e.g. all-digit chain ids) into numbers,
        # which would break the str.split calls below.
        df = pd.read_csv(
            input_path, dtype={"pdb_id": str, "chain_id": str, "seq": str}
        )
        print(f"[Info] 输入样本: {len(df)}")

        mmcif_cache_dict = dict()
        # CSV columns: pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
        for _, row in tqdm(df.iterrows(), "[Info] pdb", total=len(df)):
            pdb_id = row["pdb_id"]
            chain_ids = row["chain_id"].split(",")
            seqs = row["seq"].split(",")
            if len(seqs) != len(chain_ids):
                # Keep the entry, but surface the chain/sequence mismatch.
                tqdm.write(
                    f"[Warning] {pdb_id}: {len(chain_ids)} chain ids "
                    f"but {len(seqs)} sequences"
                )
            mmcif_cache_dict[pdb_id] = {
                "release_date": str(row["release_date"]),
                "chain_ids": chain_ids,
                "seqs": seqs,
                "no_chains": len(chain_ids),
                "resolution": float(row["resolution"]),
                "feature_folder": MmcifCacheTransfer._build_feature_folder(
                    feature_dir, pdb_id, chain_ids
                ),
            }

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as fp:
            json.dump(mmcif_cache_dict, fp, indent=4)
        print(f"[Info] 全部处理完成: {output_path}")


def main():
    """
    Command-line entry point: convert a training CSV into an mmcif_cache JSON.

    Arguments:
        -i/--input-path   training CSV file (must exist)
        -f/--feature-dir  root directory of the pre-computed features
        -o/--output-path  destination JSON file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        help="the input file path.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-f",
        "--feature-dir",
        help="the preprocess feature dir.",
        type=Path,
        required=True
    )
    parser.add_argument(
        "-o",
        "--output-path",
        help="the output file path.",
        type=Path,
        required=True
    )

    args = parser.parse_args()

    input_path = str(args.input_path)
    feature_dir = str(args.feature_dir)
    output_path = str(args.output_path)
    # Report a missing input as a regular CLI usage error (exit code 2)
    # instead of an AssertionError that `python -O` would skip.
    if not os.path.isfile(input_path):
        parser.error(f"input file not found: {input_path}")

    mct = MmcifCacheTransfer()
    mct.process(input_path, feature_dir, output_path)


if __name__ == '__main__':
    main()

2. 加载训练数据

OpenFold Multimer 的特征读取逻辑位于 openfold/data/data_modules.py 中 OpenFoldSingleMultimerDataset 类的 __getitem__ 方法,原始代码如下:

# Upstream excerpt from OpenFoldSingleMultimerDataset.__getitem__.
# In train/eval mode, the mmCIF file under self.data_dir is located and parsed
# together with the alignments in self.alignment_dir on every access.
# In any other mode, a FASTA file is processed instead.
if self.mode == 'train' or self.mode == 'eval':
    path = os.path.join(self.data_dir, f"{mmcif_id}")
    ext = None
    # Use the first supported extension for which a file exists.
    for e in self.supported_exts:
        if os.path.exists(path + e):
            ext = e
            break
    if ext is None:
        raise ValueError("Invalid file type")
    # TODO: Add pdb and core exts to data_pipeline for multimer
    path += ext
    if ext == ".cif":
        data = self._parse_mmcif(
            path, mmcif_id, self.alignment_dir, alignment_index)
    else:
        raise ValueError("Extension branch missing")
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)

将其修改为直接加载预处理 Feature 的形式(训练与评估模式下读取 feature_folder 中的 features.pkl),即:

if self.mode == 'train' or self.mode == 'eval':
    # Train/eval: load the pre-computed features from the entry's
    # feature_folder instead of parsing the mmCIF file and alignments.
    cache_entry = self.mmcif_data_cache[mmcif_id]
    feat_folder = cache_entry.get("feature_folder")
    if feat_folder is None:
        raise ValueError(
            f"mmcif_data_cache entry for {mmcif_id} has no 'feature_folder'; "
            f"regenerate the cache with main_mmcif_cache_transfer.py"
        )
    feat_path = os.path.join(feat_folder, "features.pkl")
    # NOTE: requires `import pickle` in data_modules.py. pickle.load can
    # execute arbitrary code, so only load feature files you produced yourself.
    # NOTE(review): assumes features.pkl already holds every feature training
    # needs, including any ground-truth label features that _parse_mmcif
    # would otherwise produce. TODO confirm.
    with open(feat_path, "rb") as f:
        data = dict(pickle.load(f))
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)

同时,还需要修改 __len__ 方法,使训练样本总数等于 mmcif_data_cache 中的条目数:

def __len__(self):
    """
    Return the number of samples, i.e. the number of mmcif_data_cache entries.

    All samples come from ``self.mmcif_data_cache`` in this modified dataset.
    The upstream version returned ``len(self._chain_ids)`` instead.
    """
    # len() of the mapping counts its entries. The previous
    # `len(self.mmcif_data_cache.keys)` passed the unbound-call method object
    # itself and raised TypeError.
    # NOTE(review): assumes __getitem__ maps idx onto these same cache
    # entries. Confirm against the dataset's idx -> mmcif_id mapping.
    return len(self.mmcif_data_cache)

3. 配置模型训练

模型训练的参数,如下:

# Launch multimer training (model_1_multimer_v3 preset, bf16, 1 GPU, DeepSpeed).
# NOTE(review): --train_mmcif_data_cache_path must be the JSON written by
# main_mmcif_cache_transfer.py, since that JSON carries feature_folder. The
# conversion command in section 1 writes mydata/openfold/mmcif_cache_mini.json,
# which differs from the path below; make the two paths agree.
# NOTE(review): the modified train/eval branch of __getitem__ does not read
# --train_data_dir or --train_alignment_dir. They are presumably still
# required or validated by train_openfold.py. Confirm.
python3 train_openfold.py \
    --train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --train_alignment_dir mydata/alignment_dir/ \
    --train_mmcif_data_cache_path [your folder]/multimer_train/openfold_cache/mmcif_cache_mini.json \
    --template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
	--output_dir mydata/output_dir/ \
    --max_template_date "2021-10-10" \
    --config_preset "model_1_multimer_v3" \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision bf16 \
    --gpus 1 \
    --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat

模型训练占用显存较多,默认配置在 V100 上无法运行。因此调低 crop_size(降低显存占用)与 num_workers(降低数据加载进程的 CPU 与内存占用),配置位于 openfold/config.py 中,即:

# crop_size
elif "multimer" in name:
    c.update(multimer_config_update.copy_and_resolve_references())
    c.data.train.crop_size = 64  # TODO: 用于测试

# num_workers
"data_module": {
    
    
    "use_small_bfd": False,
    "data_loaders": {
    
    
        "batch_size": 1,
        # "num_workers": 16,
        "num_workers": 2,  # TODO: 用于测试
        "pin_memory": True,
    },
},

其中,crop_size = 64 时,显存占用约为 5141 MiB。

训练日志,如下:

Epoch 0:   0%|                                 | 0/199 [00:00<?, ?it/s]INFO:openfold/data/data_modules.py:mmcif_id is: 7poc, idx: 148 and has 4 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7u49, idx: 97 and has 3 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7z7h, idx: 114 and has 6 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7nup, idx: 111 and has 4 chains
cum_loss: tensor([84.1698], device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>) losses: {'distogram': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'experimentally_resolved': tensor(0.6914, device='cuda:0'), 'fape': tensor(1.6598, device='cuda:0', dtype=torch.float64), 'plddt_loss': tensor(3.9062, device='cuda:0', dtype=torch.float64), 'masked_msa': tensor(3.0938, device='cuda:0'), 'supervised_chi': tensor(0.7941, device='cuda:0', dtype=torch.float64), 'violation': tensor(3.6495, device='cuda:0'), 'tm': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'chain_center_of_mass': tensor([1.3754], device='cuda:0', dtype=torch.float64), 'unscaled_loss': tensor([10.5212], device='cuda:0', dtype=torch.float64), 'loss': tensor([84.1698], device='cuda:0', dtype=torch.float64)}
Epoch 0:   1%|| 1/199 [02:55<9:38:06, 175.18s/it, loss=84.2, v_num=]

猜你喜欢

转载自blog.csdn.net/u012515223/article/details/132602155
psp
今日推荐