PSP - 蛋白质结构预测 OpenFold Multimer 训练模型的数据加载

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132597659

OpenFold

OpenFold Multimer 是基于深度学习的方法,预测蛋白质的多聚体结构和相互作用。利用大规模的蛋白质序列和结构数据,以及先进的神经网络架构,来学习蛋白质的表示和特征。可以处理不同类型的多聚体,包括同源和异源多聚体,以及复杂的蛋白质-蛋白质相互作用网络。OpenFold Multimer 的目标是为生物学家提供一个快速、准确和易用的工具,来探索蛋白质的多聚体功能和机制。

训练参数:

python3 train_openfold.py \
  --train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
  --train_alignment_dir mydata/alignment_dir/ \
  --train_mmcif_data_cache_path mmcif_cache.json \
  --template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
  --output_dir mydata/output_dir/ \
  --max_template_date "2021-10-10" \
  --config_preset "model_1_multimer_v3" \
  --template_release_dates_cache_path mmcif_cache.json \
  --precision bf16 \
  --gpus 1 \
  --replace_sampler_ddp=True \
  --seed 42 \
  --deepspeed_config_path deepspeed_config.json \
  --checkpoint_every_epoch \
  --train_chain_data_cache_path chain_data_cache.json \
  --obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat


1. train_alignment_dir

核心关注 train_alignment_dir,这部分是缓存的预处理特征,调用路径如下:

  • train_openfold.pyargs 参数,传入 OpenFoldMultimerDataModule
  • 再由 dataset_gen() 方法,也就是 OpenFoldSingleMultimerDataset 类,接收
  • 参数由 alignment_dir=self.train_alignment_dir,转换成 alignment_dir
  • 再由 OpenFoldMultimerDataModule 类,调用 OpenFoldSingleMultimerDataset

# train_openfold.py
# ...
if "multimer" in args.config_preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args))
# ...

# openfold/data/data_modules.py#OpenFoldMultimerDataModule
# ...
if self.training_mode:
    train_dataset = dataset_gen(
        data_dir=self.train_data_dir,
        mmcif_data_cache_path=self.train_mmcif_data_cache_path,
        alignment_dir=self.train_alignment_dir,
        filter_path=self.train_filter_path,
        max_template_hits=self.config.train.max_template_hits,
        shuffle_top_k_prefiltered=
            self.config.train.shuffle_top_k_prefiltered,
        treat_pdb_as_distillation=False,
        mode="train",
        alignment_index=self.alignment_index,)
# ...

OpenFoldSingleMultimerDataset 类中,alignment_dir 用于 _chain_ids 的赋值,即

if alignment_index is not None:
    self._chain_ids = list(alignment_index.keys())
else:
    self._chain_ids = list(os.listdir(alignment_dir))

alignment_index_path 支持作为参数,传入,默认是空,相关描述如下,核心是先编译成单个文件,再读入,可以提升效率:

In cases where it may be burdensome to create separate files for each chain’s alignments, alignment directories can be consolidated using the scripts in scripts/alignment_db_scripts/. First, run create_alignment_db.py to consolidate an alignment directory into a pair of database and index files. Once all alignment directories (or shards of a single alignment directory) have been compiled, unify the indices with unify_alignment_db_indices.py. The resulting index, super.index, can be passed to the training script flags containing the phrase alignment_index. In this scenario, the alignment_dir flags instead represent the directory containing the compiled alignment databases. Both the training and distillation datasets can be compiled in this way. Anecdotally, this can speed up training in I/O-bottlenecked environments.

其中,self._chain_ids 是全部的训练集:

def __len__(self):
    return len(self._chain_ids) 

设置 logger 日志:

import logging
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

训练数据的遍历参数:

def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']

根据输出,组织训练数据:

mmcif_id is: 5ykn, idx: 8580 and has 1 chains
mmcif_id is: 2lna, idx: 3848 and has 1 chains
mmcif_id is: 7rrp, idx: 8447 and has 24 chains
mmcif_id is: 6k8h, idx: 7870 and has 2 chains
...

2. OpenFoldSingleMultimerDataset

具体分析 OpenFoldSingleMultimerDataset 类。在 __getitem__ 方法中,遍历训练样本,核心关注:

  • self.idx_to_mmcif_id() 函数调用 self._mmcifs[idx]
  • 2个关键变量,self._mmcifsself.mmcif_data_cache,而且两者的 keys 要保持一致。

即:

def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    print(f"mmcif_id is: {
      
      mmcif_id}, idx: {
      
      idx} and has {
      
      len(chains)} chains")

关于 self._mmcifs 数据,调用 mmcif_data_cache_path -> self.mmcif_data_cache -> self._mmcifs

  • mmcif_data_cache_path 来源于预处理的过程

即:

# ...
logger.info(f"[CL] mmcif_data_cache_path: {
      
      mmcif_data_cache_path}")
if mmcif_data_cache_path is not None:
    with open(mmcif_data_cache_path, "r") as infile:
        self.mmcif_data_cache = json.load(infile)
    assert isinstance(self.mmcif_data_cache, dict)
# ...
if self.mmcif_data_cache is not None:
    self._mmcifs = list(self.mmcif_data_cache.keys())
    self._mmcif_id_to_idx_dict = {
    
    mmcif: i for i, mmcif in enumerate(self._mmcifs)}

其中 mmcif_cache.json 的文件数据,包括PDB信息,即:

{
    
    
    "4ewn": {
    
    
        "release_date": "2012-12-05",
        "chain_ids": ["D"],
        "seqs": [
            "MLAKRI..."
        ],
        "no_chains": 1,
        "resolution": 1.9
    },
    "5m9r": {
    
    
        "release_date": "2017-02-22",
        "chain_ids": ["A", "B"],
        "seqs": [
            "MQDNS...",
            "MQDNS..."
        ],
        "no_chains": 2,
        "resolution": 1.44
    },
# ...

BugFix: 增加 train_mmcif_data_cache_path 参数

--train_mmcif_data_cache_path mmcif_cache.json

猜你喜欢

转载自blog.csdn.net/u012515223/article/details/132597659
psp
今日推荐