Follow me on CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/134293267
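In the AlphaFold Multimer framework for protein complex structure prediction, MSA features and template features together initialize the inputs to the Evoformer and guide the structure prediction that follows. Just as MSAs carry MSA pairing information, templates also involve pairing, so building cross-chain template features is equally important.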
Call order for the template features:
- Predict from a fasta sequence:
run_pretrained_openfold.py#main()
- Process multi-chain features:
openfold/data/data_pipeline.py#multimer_process_single_chain()
- Process single-chain features:
openfold/data/data_pipeline.py#_process_single_chain()
- Run the template search:
openfold/data/data_pipeline.py#get_template_features()
Source code:
# ...
feature_dict, is_multimer = data_processor.process_fasta(
    fasta_path=fasta_path,
    msa_output_path=fasta_output_dir,
    tmpl_pdb_list=tmpl_pdb_list,
    specific_tmpl_mode=args.specific_tmpl_mode,
    is_multimer=global_is_multimer,
    use_complex_template=args.use_complex_template,
    use_extra_msa=args.use_extra_msa,
)
# ...
chain_features = self._process_single_chain(
    sequence=fasta_chain.sequence,
    description=fasta_chain.description,
    is_multimer=is_multimer,
    tmpl_pdb_list=tmpl_pdb_list,
    specific_tmpl_mode=specific_tmpl_mode,
    chain_msa_folder=chain_msa_folder,
)
# ...
template_features = self.get_template_features(
    input_sequence,
    msas,
    self.method_config,
    a3m_lines,
    is_multimer=is_multimer,
    tmpl_pdb_list=tmpl_pdb_list,
    specific_tmpl_mode=specific_tmpl_mode,
)
Log:
template param, min_score: 0.1, max_score: 0.95, release_date: 2023-01-01
# Sequence length: 143
search_tmpl: MSKVETGDQGYTVVQSKYKKAVEQLQKGLLDGEIKIFFEGTLASTIYCLHKVDNKLDNLGDGDYVDFLIITKLRILNAKEETIDIDASSSKTAQDLAKKYVFNKTDLNTLYRVLNGDEADTNRLVEEVSGKYQVVLYPEGKRV, search_tmpl_mode: all, is_multimer: True, tmpl_list: 20, tmpl_score: [0.4195804195804196, 0.4195804195804196, 0.4125874125874126, 0.4125874125874126, 0.4125874125874126, 0.4125874125874126, 0.4125874125874126, 0.4125874125874126, 0.410958904109589, 0.4105960264900662, 0.4097222222222222, 0.4097222222222222, 0.4097222222222222, 0.40816326530612246, 0.4068965517241379, 0.4068965517241379, 0.4068965517241379, 0.4068965517241379, 0.4068965517241379, 0.4066666666666667]
# Template features
template_features: dict_keys(['template_aatype', 'template_all_atom_masks', 'template_all_atom_positions', 'template_domain_names', 'template_sequence', 'template_sum_probs'])
key: template_aatype, shape: (20, 143, 22)
key: template_all_atom_masks, shape: (20, 143, 37)
key: template_all_atom_positions, shape: (20, 143, 37, 3)
key: template_domain_names, shape: (20,)
key: template_sequence, shape: (20,)
key: template_sum_probs, shape: (20, 1)
The 37 atom types (the last-but-one dimension of the atom arrays) are indexed as follows:
{
'N': 0, 'CA': 1, 'C': 2, 'CB': 3, 'O': 4, 'CG': 5, 'CG1': 6, 'CG2': 7, 'OG': 8, 'OG1': 9, 'SG': 10,
'CD': 11, 'CD1': 12, 'CD2': 13, 'ND1': 14, 'ND2': 15, 'OD1': 16, 'OD2': 17, 'SD': 18, 'CE': 19, 'CE1': 20,
'CE2': 21, 'CE3': 22, 'NE': 23, 'NE1': 24, 'NE2': 25, 'OE1': 26, 'OE2': 27, 'CH2': 28, 'NH1': 29, 'NH2': 30,
'OH': 31, 'CZ': 32, 'CZ2': 33, 'CZ3': 34, 'NZ': 35, 'OXT': 36}
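As a minimal sketch of how this index is used, the snippet below rebuilds the mapping from an ordered list of atom names and selects the CA coordinates from an atom37-shaped array. The names `ATOM_TYPES`, `ATOM_ORDER`, and `ca_coords` are illustrative, and the coordinates are random placeholder data, not a real template.

```python
import numpy as np

# The 37 atom names, in the same order as the mapping above.
ATOM_TYPES = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG",
    "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1",
    "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2",
    "OH", "CZ", "CZ2", "CZ3", "NZ", "OXT",
]
ATOM_ORDER = {name: i for i, name in enumerate(ATOM_TYPES)}


def ca_coords(all_atom_positions):
    """Select CA coordinates from an array of shape (..., n_res, 37, 3)."""
    return all_atom_positions[..., ATOM_ORDER["CA"], :]


# Placeholder data with the same shape as template_all_atom_positions in the log.
positions = np.random.default_rng(0).normal(size=(20, 143, 37, 3))
ca = ca_coords(positions)
print(len(ATOM_ORDER), ATOM_ORDER["CA"], ca.shape)  # 37 1 (20, 143, 3)
```

Indexing the atom axis by name in this way avoids hard-coding positions such as 1 for CA or 3 for CB.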
The 22 amino-acid classes (AA) are:
ID_TO_HHBLITS_AA = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}
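To illustrate where the trailing dimension of 22 in template_aatype comes from, here is a hedged sketch that one-hot encodes a sequence with the inverse of the table above. The helper name `one_hot_hhblits` is made up for this example, and the extra letters (U, B, Z, J, O) are placed according to the comments in the table.

```python
import numpy as np

# Inverse of the table above, including the aliases named in its comments.
HHBLITS_AA_TO_ID = {
    "A": 0, "C": 1, "U": 1, "D": 2, "B": 2, "E": 3, "Z": 3, "F": 4,
    "G": 5, "H": 6, "I": 7, "K": 8, "L": 9, "M": 10, "N": 11, "P": 12,
    "Q": 13, "R": 14, "S": 15, "T": 16, "V": 17, "W": 18, "Y": 19,
    "X": 20, "J": 20, "O": 20, "-": 21,
}


def one_hot_hhblits(sequence, num_classes=22):
    """Return an array of shape (len(sequence), 22) with a single 1 per row."""
    one_hot = np.zeros((len(sequence), num_classes), dtype=np.int32)
    for pos, aa in enumerate(sequence):
        one_hot[pos, HHBLITS_AA_TO_ID[aa]] = 1
    return one_hot


encoded = one_hot_hhblits("MSKVETG-")
print(encoded.shape)            # (8, 22)
print(encoded.argmax(axis=-1))  # [10 15  8 17  3 16  5 21]
```

Stacking 20 such arrays for a 143-residue query gives the (20, 143, 22) shape shown in the log.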
openfold/data/msa_pairing.py#merge_chain_features() merges the per-chain features and outputs the template features of the whole complex:
[CL] template features, template_aatype : (4, 1102)
[CL] template features, template_all_atom_positions : (4, 1102, 37, 3)
[CL] template features, template_all_atom_mask : (4, 1102, 37)
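The shapes above are consistent with the per-chain template arrays being joined along the residue axis. The sketch below shows only that idea on dummy data; `merge_template_features` and `dummy_chain` are hypothetical helpers, the chain lengths 143 and 200 are made up, and the real merge_chain_features() does considerably more than this.

```python
import numpy as np


def dummy_chain(n_templ, n_res):
    """Build placeholder template features for one chain (all zeros)."""
    return {
        "template_aatype": np.zeros((n_templ, n_res), dtype=np.int64),
        "template_all_atom_positions": np.zeros((n_templ, n_res, 37, 3)),
        "template_all_atom_mask": np.zeros((n_templ, n_res, 37)),
    }


def merge_template_features(chains):
    """Concatenate every feature along axis 1, the residue axis.

    Assumes all chains already hold the same number of templates.
    """
    return {
        key: np.concatenate([chain[key] for chain in chains], axis=1)
        for key in chains[0]
    }


merged = merge_template_features([dummy_chain(4, 143), dummy_chain(4, 200)])
for key, value in merged.items():
    print(key, value.shape)
```

With these made-up lengths the merged residue axis has 143 + 200 = 343 entries, in the same way that the chains of the logged complex add up to 1102.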
The template features fed into the network are:
[CL] template_feats, template_aatype: torch.Size([4, 1102])
[CL] template_feats, template_all_atom_mask: torch.Size([4, 1102, 37])
[CL] template_feats, template_all_atom_positions: torch.Size([4, 1102, 37, 3])
[CL] template_feats, template_mask: torch.Size([4])
[CL] template_feats, template_pseudo_beta: torch.Size([4, 1102, 3])
[CL] template_feats, template_pseudo_beta_mask: torch.Size([4, 1102])
[CL] template_feats, template_torsion_angles_sin_cos: torch.Size([4, 1102, 7, 2])
[CL] template_feats, template_alt_torsion_angles_sin_cos: torch.Size([4, 1102, 7, 2])
[CL] template_feats, template_torsion_angles_mask: torch.Size([4, 1102, 7])
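Among these, template_pseudo_beta holds one 3D point per residue: the CB atom, or CA for glycine, which has no CB. The following is a minimal NumPy sketch of that selection. The function name and the `gly_index` parameter are assumptions for illustration; the value 7 used here presumes the AlphaFold residue ordering and should be checked against the actual encoding of template_aatype.

```python
import numpy as np

CA_IDX, CB_IDX = 1, 3  # atom37 indices of CA and CB


def pseudo_beta(aatype, all_atom_positions, all_atom_mask, gly_index):
    """Pick CB for every residue, falling back to CA where aatype is glycine.

    aatype: (..., n_res); positions: (..., n_res, 37, 3); mask: (..., n_res, 37).
    """
    is_gly = aatype == gly_index
    pos = np.where(
        is_gly[..., None],
        all_atom_positions[..., CA_IDX, :],
        all_atom_positions[..., CB_IDX, :],
    )
    mask = np.where(is_gly, all_atom_mask[..., CA_IDX], all_atom_mask[..., CB_IDX])
    return pos, mask


# One template, three residues; index 7 is assumed to mean glycine.
aatype = np.array([[0, 7, 3]])
positions = np.arange(1 * 3 * 37 * 3, dtype=float).reshape(1, 3, 37, 3)
mask = np.ones((1, 3, 37))
pb, pb_mask = pseudo_beta(aatype, positions, mask, gly_index=7)
print(pb.shape, pb_mask.shape)  # (1, 3, 3) (1, 3)
```

The atom axis disappears in the result, which matches the logged shapes (4, 1102, 3) and (4, 1102).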
The network sets up its template embedding in openfold/model/model.py#AlphaFold():
self.template_embedder = TemplateEmbedderMultimer(
    template_config,
    model_version=model_version,
)
The dimensions of the intermediate template tensors are listed below; the core output is pair_act:
[CL] TemplateEmbedderMultimer - n_templ: 4
[CL] TemplateEmbedderMultimer - template_positions: torch.Size([1, 1102, 3])
[CL] TemplateEmbedderMultimer - template_dgram: torch.Size([1, 1102, 1102, 39])
[CL] TemplateEmbedderMultimer - raw_atom_pos: torch.Size([1, 1102, 37, 3])
[CL] TemplateEmbedderMultimer - atom_pos: torch.Size([1, 1102, 37])
[CL] TemplateEmbedderMultimer - unit_vector: torch.Size([1, 1102, 1102])
[CL] TemplateEmbedderMultimer - pair_act: torch.Size([1, 1102, 1102, 64])
The template distogram uses no_bins = 39 bin edges spaced 1.25 Å apart, since (50.75 - 3.25) / 38 = 1.25, so the last dimension of template_dgram has length 39.
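The arithmetic can be checked with a few lines of NumPy. This is only a sanity check of the numbers quoted above, using `np.linspace` in place of `torch.linspace`.

```python
import numpy as np

min_bin, max_bin, no_bins = 3.25, 50.75, 39

# 39 edges give 38 intervals between the first and the last edge.
width = (max_bin - min_bin) / (no_bins - 1)
edges = np.linspace(min_bin, max_bin, no_bins)

print(width)                # 1.25
print(edges[:3])            # [3.25 4.5  5.75]
print(edges[-1], len(edges))  # 50.75 39
```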
Source of some of the inputs:
template_positions, pseudo_beta_mask = (
    single_template_feats["template_pseudo_beta"],
    single_template_feats["template_pseudo_beta_mask"],
)
Source code:
idx = batch["template_aatype"].new_tensor(i)
# Select the i-th template from every tensor in the batch.
single_template_feats = tensor_tree_map(
    lambda t: torch.index_select(t, templ_dim, idx),
    batch,
)
single_template_embeds = {}
act = 0.0
template_positions, pseudo_beta_mask = (
    single_template_feats["template_pseudo_beta"],
    single_template_feats["template_pseudo_beta_mask"],
)
# Bin the pairwise pseudo-beta distances into a one-hot distogram.
template_dgram = dgram_from_positions(
    template_positions,
    inf=self.config.inf,
    **self.config.distogram,
)
logger.info(f"[CL] TemplateEmbedderMultimer - template_dgram: {template_dgram.shape}")
aatype_one_hot = torch.nn.functional.one_hot(
    single_template_feats["template_aatype"],
    22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
logger.info(f"[CL] TemplateEmbedderMultimer - raw_atom_pos: {raw_atom_pos.shape}")
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
logger.info(f"[CL] TemplateEmbedderMultimer - atom_pos: {atom_pos.shape}")
# Build one backbone frame per residue.
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
    atom_pos,
    single_template_feats["template_all_atom_mask"],
    single_template_feats["template_aatype"],
)
points = rigid.translation
# Express every frame origin in the local frame of every other residue.
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
logger.info(f"[CL] TemplateEmbedderMultimer - unit_vector: {unit_vector.shape}")
pair_act = self.template_pair_embedder(
    template_dgram,
    aatype_one_hot,
    z,
    pseudo_beta_mask,
    backbone_mask,
    multichain_mask_2d,
    unit_vector,
)
logger.info(f"[CL] TemplateEmbedderMultimer - pair_act: {pair_act.shape}")
single_template_embeds["template_pair_embedding"] = pair_act
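The unit_vector step can be pictured with a small NumPy sketch: build one orthonormal frame per residue from its N, CA, and C atoms, express every other residue's frame origin in that local frame, and normalize. The helper names and the Gram-Schmidt axis convention below are assumptions chosen for illustration and may differ from what make_backbone_affine() actually uses. The sketch keeps an explicit trailing xyz axis, whereas the logged shapes for atom_pos and unit_vector have no trailing 3, presumably because Vec3Array holds x, y, and z as separate fields.

```python
import numpy as np


def backbone_frames(n, ca, c, eps=1e-8):
    """Build a rotation matrix and origin per residue from (n_res, 3) coordinates."""
    e0 = c - ca
    e0 = e0 / (np.linalg.norm(e0, axis=-1, keepdims=True) + eps)
    v = n - ca
    # Remove the component of v along e0 (Gram-Schmidt), then normalize.
    e1 = v - np.sum(v * e0, axis=-1, keepdims=True) * e0
    e1 = e1 / (np.linalg.norm(e1, axis=-1, keepdims=True) + eps)
    e2 = np.cross(e0, e1)
    rot = np.stack([e0, e1, e2], axis=-1)  # (n_res, 3, 3), columns are the axes
    return rot, ca


def pairwise_unit_vectors(rot, trans, eps=1e-8):
    """Return unit[i, j]: direction of origin j as seen from the frame of residue i."""
    diff = trans[None, :, :] - trans[:, None, :]   # diff[i, j] = t_j - t_i
    local = np.einsum("iab,ija->ijb", rot, diff)   # R_i^T (t_j - t_i)
    return local / (np.linalg.norm(local, axis=-1, keepdims=True) + eps)


rng = np.random.default_rng(0)
n_xyz, ca_xyz, c_xyz = (rng.normal(size=(5, 3)) for _ in range(3))
rot, trans = backbone_frames(n_xyz, ca_xyz, c_xyz)
unit = pairwise_unit_vectors(rot, trans)
print(unit.shape)  # (5, 5, 3)
```

Because each direction is measured in a residue's own frame, the feature does not change when the whole template is rotated or translated.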
The most important feature is template_dgram, the binned (quantized) pairwise distances between template residues:
import torch


def dgram_from_positions(
    pos: torch.Tensor,
    min_bin: float = 3.25,
    max_bin: float = 50.75,
    no_bins: int = 39,
    inf: float = 1e8,
) -> torch.Tensor:
    # Squared pairwise distances, shape (..., N, N, 1).
    dgram = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # Squared lower and upper bin edges; the last bin is open-ended (up to inf).
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    # One-hot bin membership, shape (..., N, N, no_bins).
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
    return dgram
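To see what this binning produces, here is a NumPy port of the same logic applied to three points on a line. The port `dgram_from_positions_np` is an illustrative rewrite and not part of OpenFold; the toy coordinates are made up.

```python
import numpy as np


def dgram_from_positions_np(pos, min_bin=3.25, max_bin=50.75, no_bins=39, inf=1e8):
    """NumPy port of the binning above: one-hot distogram of shape (..., N, N, no_bins)."""
    sq_dist = np.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, axis=-1, keepdims=True
    )
    lower = np.linspace(min_bin, max_bin, no_bins) ** 2
    upper = np.concatenate([lower[1:], [inf]])
    return ((sq_dist > lower) & (sq_dist < upper)).astype(pos.dtype)


# Three points on the x-axis at 0, 8 and 60 angstroms.
pos = np.array([[0.0, 0.0, 0.0], [8.0, 0.0, 0.0], [60.0, 0.0, 0.0]])
dgram = dgram_from_positions_np(pos)

print(dgram.shape)         # (3, 3, 39)
print(dgram[0, 1].argmax())  # 3: 8.0 lies between the edges 7.0 and 8.25
print(dgram[0, 2].argmax())  # 38: 60.0 is beyond the last edge, 50.75
print(dgram[0, 0].sum())     # 0.0: a distance of 0 is below the first edge
```

Pairs closer than 3.25 Å, including each residue with itself, get an all-zero row, while everything beyond 50.75 Å lands in the last bin.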
The resulting distogram looks like this: