Follow me on CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/134328447
When predicting the structure of protein complexes, templates play an important role: they provide prior information about the three-dimensional structure of the prediction. With multiple chains, template pairing (Template Pair) is required. The core function is template_pair_embedder. This article analyzes its specific input and output features with reference to the AlphaFold2 paper.
Core logic: template_pair_embedder(). Input and output feature dimensions:
[CL] TemplateEmbedderMultimer - template_dgram: torch.Size([1, 1102, 1102, 39])
[CL] TemplateEmbedderMultimer - z: torch.Size([1102, 1102, 128])
[CL] TemplateEmbedderMultimer - pseudo_beta_mask: torch.Size([1, 1102])
[CL] TemplateEmbedderMultimer - backbone_mask: torch.Size([1, 1102])
[CL] TemplateEmbedderMultimer - multichain_mask_2d: torch.Size([1102, 1102])
[CL] TemplateEmbedderMultimer - unit_vector: torch.Size([1, 1102, 1102])
[CL] TemplateEmbedderMultimer - pair_act: torch.Size([1, 1102, 1102, 64])
Function calls:
# openfold/model/embedders.py
pair_act = self.template_pair_embedder(
    template_dgram,
    aatype_one_hot,
    z,
    pseudo_beta_mask,
    backbone_mask,
    multichain_mask_2d,
    unit_vector,
)
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)  # project from c_t to c_z dimensions, used to update z
template_embeds["template_pair_embedding"] = t

# openfold/model/model.py
template_embeds = self.template_embedder(
    template_feats,
    z,
    pair_mask.to(dtype=z.dtype),
    no_batch_dims,
    chunk_size=self.globals.chunk_size,
    multichain_mask_2d=multichain_mask_2d,
    use_fa=self.globals.use_fa,
)
z = z + template_embeds["template_pair_embedding"]  # line 13 in Alg.2
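The last steps of the excerpt above (average over templates, ReLU, linear projection, residual update of z) can be sketched with plain Python lists. This is only an illustration: the helper names combine_templates and add_residual, the tiny shapes, and the weight matrix are hypothetical, and the bias of the linear layer is omitted.

```python
def combine_templates(per_template, weight):
    """Average per-template pair activations, apply ReLU, then project.

    per_template: nested lists shaped [n_templ][n_res][n_res][c_t].
    weight: hypothetical projection matrix shaped [c_z][c_t] (bias omitted).
    """
    n_templ = len(per_template)
    n_res = len(per_template[0])
    c_t = len(per_template[0][0][0])
    out = []
    for i in range(n_res):
        row = []
        for j in range(n_res):
            # Mean over the template axis, like torch.sum(t, dim=-4) / n_templ.
            mean = [
                sum(per_template[k][i][j][c] for k in range(n_templ)) / n_templ
                for c in range(c_t)
            ]
            act = [max(0.0, v) for v in mean]  # ReLU
            # Linear map from c_t to c_z channels.
            row.append([sum(w * a for w, a in zip(w_row, act)) for w_row in weight])
        out.append(row)
    return out


def add_residual(z, update):
    """Element-wise z + update for [n_res][n_res][c_z] nested lists."""
    return [
        [[a + b for a, b in zip(zc, uc)] for zc, uc in zip(zr, ur)]
        for zr, ur in zip(z, update)
    ]


templates = [[[[1.0, -4.0]]], [[[3.0, 2.0]]]]  # 2 templates, 1 residue, c_t = 2
weight = [[1.0, 1.0], [0.5, 0.0], [0.0, 2.0]]  # c_z = 3
update = combine_templates(templates, weight)
z_new = add_residual([[[1.0, 1.0, 1.0]]], update)
```

In the logged run, c_t is 64 (the last dimension of pair_act) and c_z is 128 (the last dimension of z).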
Logic diagram:
1. template_dgram feature
The template_dgram feature bins the distance between each residue and every other residue into 39 bins in total. The bin edges run from 3.25 to 50.75 with a spacing of 1.25, since (50.75 - 3.25) / 38 = 1.25. The comparison itself is done on squared distances, that is:
template_dgram = dgram_from_positions(
    template_positions,
    inf=self.config.inf,
    **self.config.distogram,
)

def dgram_from_positions(
    pos: torch.Tensor,
    min_bin: float = 3.25,
    max_bin: float = 50.75,
    no_bins: float = 39,
    inf: float = 1e8,
):
    dgram = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
    return dgram
The template bin width is 1.25, that is, (50.75 - 3.25) / 38 = 1.25, and the one-hot vector has length 39.
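To make the binning concrete, here is a minimal pure-Python sketch for a single distance. The function name dgram_one_hot is hypothetical. It mirrors the lower/upper edge logic of dgram_from_positions for one pair of points rather than a full tensor.

```python
def dgram_one_hot(distance, min_bin=3.25, max_bin=50.75, no_bins=39, inf=1e8):
    """Return the 39-dim one-hot distogram row for one distance."""
    step = (max_bin - min_bin) / (no_bins - 1)  # bin width
    # Squared lower edges, matching torch.linspace(...) ** 2.
    lower = [(min_bin + i * step) ** 2 for i in range(no_bins)]
    # Each upper edge is the next lower edge; the last bin is open-ended.
    upper = lower[1:] + [inf]
    d2 = distance ** 2
    return [1.0 if lo < d2 < hi else 0.0 for lo, hi in zip(lower, upper)]


bin_width = (50.75 - 3.25) / 38
row_10 = dgram_one_hot(10.0)    # falls between 9.5 and 10.75
row_2 = dgram_one_hot(2.0)      # below the first edge: no bin is set
row_100 = dgram_one_hot(100.0)  # beyond 50.75: last bin
```

Note that distances shorter than 3.25 produce an all-zero row, because no lower edge is below them.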
Next, the template_positions feature, as follows:
- Input: template_pseudo_beta and template_pseudo_beta_mask
- Output: template_positions, that is, the value of template_pseudo_beta
template_positions, pseudo_beta_mask = (
    single_template_feats["template_pseudo_beta"],
    single_template_feats["template_pseudo_beta_mask"],
)
The pseudo_beta features are computed in openfold/data/data_transforms_multimer.py, that is:
- Input features: template_aatype, template_all_atom_positions, template_all_atom_mask
- Principle: select the coordinates and mask of the CA atom (for glycine) or the CB atom (for all other residues)
The source code and call chain are as follows:
# Input features; the model predicts the structure
# run_pretrained_openfold.py
processed_feature_dict, _ = feature_processor.process_features(
    feature_dict, is_multimer, mode="predict"
)
output_dict = predict_structure_single_dev(
    args,
    model_name,
    current_model,
    fasta_path,
    processed_feature_dict,
    config,
)

# openfold/data/feature_pipeline.py
processed_feature, label = np_example_to_features_multimer(
    np_example=raw_features,
    config=self.config,
    mode=mode,
)

# openfold/data/feature_pipeline.py
features, label = input_pipeline_multimer.process_tensors_from_config(
    tensor_dict,
    cfg.common,
    cfg[mode],
    cfg.data_module,
)

# openfold/data/input_pipeline_multimer.py
nonensembled = nonensembled_transform_fns(
    common_cfg,
    mode_cfg,
)
tensors = compose(nonensembled)(tensors)

# openfold/data/input_pipeline_multimer.py
operators.extend(
    [
        data_transforms_multimer.make_atom14_positions,
        data_transforms_multimer.atom37_to_frames,
        data_transforms_multimer.atom37_to_torsion_angles(""),
        data_transforms_multimer.make_pseudo_beta(""),
        data_transforms_multimer.get_backbone_frames,
        data_transforms_multimer.get_chi_angles,
    ]
)
# openfold/data/data_transforms_multimer.py
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ["", "template_"]
    (
        protein[prefix + "pseudo_beta"],
        protein[prefix + "pseudo_beta_mask"],
    ) = pseudo_beta_fn(
        protein["template_aatype" if prefix else "aatype"],
        protein[prefix + "all_atom_positions"],
        protein["template_all_atom_mask" if prefix else "all_atom_mask"],
    )
    return protein

# openfold/data/data_transforms_multimer.py
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features."""
    if aatype.shape[0] > 0:
        is_gly = torch.eq(aatype, rc.restype_order["G"])
        ca_idx = rc.atom_order["CA"]
        cb_idx = rc.atom_order["CB"]
        pseudo_beta = torch.where(
            torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
            all_atom_positions[..., ca_idx, :],
            all_atom_positions[..., cb_idx, :],
        )
    else:
        pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3)
    if all_atom_mask is not None:
        if aatype.shape[0] > 0:
            pseudo_beta_mask = torch.where(
                is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
            )
        else:
            pseudo_beta_mask = torch.zeros_like(aatype).float()
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
template_pseudo_beta_mask: mask indicating whether the beta carbon (alpha carbon for glycine) atom has coordinates for the template at this residue.
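The CA-or-CB selection rule can be illustrated for a single residue without tensors. In this sketch the function name select_pseudo_beta and the truncated four-atom layout are hypothetical. The atom indices N=0, CA=1, C=2, CB=3 follow the 37-atom order used by the code above.

```python
ATOM_ORDER = {"N": 0, "CA": 1, "C": 2, "CB": 3}  # first four of the 37 atom slots


def select_pseudo_beta(residue_letter, atom_positions, atom_mask):
    """Pick the pseudo-beta atom: CA for glycine ("G"), CB otherwise."""
    key = "CA" if residue_letter == "G" else "CB"
    idx = ATOM_ORDER[key]
    return atom_positions[idx], atom_mask[idx]


positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (2.0, 1.4, 0.0), (2.1, -0.8, 1.2)]

# Alanine with all four atoms resolved: the CB slot is used.
ala_pos, ala_mask = select_pseudo_beta("A", positions, [1.0, 1.0, 1.0, 1.0])

# Glycine has no CB, so its CB mask is 0, but the CA slot is used instead.
gly_pos, gly_mask = select_pseudo_beta("G", positions, [1.0, 1.0, 1.0, 0.0])
```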
The template_dgram feature has shape [1, 1102, 1102, 39].
2. z feature
The z feature is passed in directly as input. Its sources are protein["target_feat"] and protein["between_segment_residues"], that is:
- The logged target_feat is torch.Size([1102, 21]). It excludes the gap "-" and covers only 21 = 20 + 1 amino acid types, including X.
- A linear layer converts [1102, 21] into c_z dimensions, that is, 128 dimensions.
- An outer sum then expands the result to L x L x C. In fact, z is the Pair Representation, with shape [1102, 1102, 128].
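The outer sum step can be shown on a toy example. This is a sketch with plain lists. The function name outer_sum and the two-residue, two-channel inputs are made up for illustration. They stand in for the two linear projections of target_feat.

```python
def outer_sum(emb_i, emb_j):
    """Build pair[i][j][c] = emb_i[i][c] + emb_j[j][c].

    emb_i and emb_j are [n_res][c_z] lists; the result is [n_res][n_res][c_z].
    """
    return [
        [[a + b for a, b in zip(row_i, row_j)] for row_j in emb_j]
        for row_i in emb_i
    ]


emb_i = [[1.0, 2.0], [3.0, 4.0]]      # stand-in for the "i" projection
emb_j = [[10.0, 20.0], [30.0, 40.0]]  # stand-in for the "j" projection
pair = outer_sum(emb_i, emb_j)
```

With 1102 residues and c_z = 128, the same operation yields the [1102, 1102, 128] pair representation.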
# openfold/model/embedders.py
def forward(
    self,
    batch,
    z,
    padding_mask_2d,
    templ_dim,
    chunk_size,
    multichain_mask_2d,
    use_fa=False,
):
    ...

# openfold/model/model.py
template_embeds = self.template_embedder(
    template_feats,
    z,
    pair_mask.to(dtype=z.dtype),
    no_batch_dims,
    chunk_size=self.globals.chunk_size,
    multichain_mask_2d=multichain_mask_2d,
    use_fa=self.globals.use_fa,
)

# openfold/model/model.py
m, z = self.input_embedder(feats)
# openfold/model/embedders.py#InputEmbedderMultimer
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    # ...
    Returns:
        msa_emb:
            [*, N_clust, N_res, C_m] MSA embedding
        pair_emb:
            [*, N_res, N_res, C_z] pair embedding
    """
    tf = batch["target_feat"]
    msa = batch["msa_feat"]
    # [*, N_res, c_z]
    tf_emb_i = self.linear_tf_z_i(tf)
    tf_emb_j = self.linear_tf_z_j(tf)
    # [*, N_res, N_res, c_z]
    pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
    pair_emb = pair_emb + self.relpos(batch)  # add the relative position encoding
    # [*, N_clust, N_res, c_m]
    n_clust = msa.shape[-3]
    tf_m = (
        self.linear_tf_m(tf)
        .unsqueeze(-3)
        .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
    )
    msa_emb = self.linear_msa_m(msa) + tf_m
    return msa_emb, pair_emb
# openfold/data/data_transforms_multimer.py
def create_target_feat(batch):
    """Create the target features"""
    batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(
        torch.float32
    )
    return batch

# openfold/data/input_pipeline_multimer.py
operators.extend(
    [
        data_transforms_multimer.cast_to_64bit_ints,
        # todo: randomly_replace_msa_with_unknown may need to be confirmed and tried in training.
        # data_transforms_multimer.randomly_replace_msa_with_unknown(0.0),
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
        data_transforms_multimer.make_atom14_masks,
    ]
)
Frame diagram of InputEmbedderMultimer:
The 21 amino acid types are entries 0 through 20 of the following table (entry 21, "-", is excluded):
ID_TO_HHBLITS_AA = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}
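As a small illustration of how a 21-class one-hot target_feat arises from integer residue indices, consider the sketch below. The helper names one_hot and make_target_feat are hypothetical. The sketch only relies on indices 0 through 20 being valid and does not assume any particular letter-to-index mapping.

```python
def one_hot(index, num_classes=21):
    """Return a one-hot row; indices outside 0..num_classes-1 are rejected."""
    if not 0 <= index < num_classes:
        raise ValueError(f"index {index} is outside 0..{num_classes - 1}")
    return [1.0 if k == index else 0.0 for k in range(num_classes)]


def make_target_feat(aatype):
    """Map a list of residue indices to an [n_res][21] one-hot feature."""
    return [one_hot(a) for a in aatype]


target_feat = make_target_feat([0, 5, 20])  # three residues

# Index 21 (the gap "-" in the table above) does not fit into 21 classes.
try:
    one_hot(21)
    gap_rejected = False
except ValueError:
    gap_rejected = True
```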
The z feature (mean, sum, max) has shape [1102, 1102, 128].
3. pseudo_beta_mask and backbone_mask features
For the pseudo_beta_mask feature, refer to the template_pseudo_beta source code (make_pseudo_beta and pseudo_beta_fn) in the template_dgram section:
- It uses the mask information of the CA and CB atoms.
Here, openfold/data/msa_pairing.py#merge_chain_features() merges the multi-chain features and outputs the template features, that is:
[CL] template features, template_aatype : (4, 1102)
[CL] template features, template_all_atom_positions : (4, 1102, 37, 3)
[CL] template features, template_all_atom_mask : (4, 1102, 37)
Note template_all_atom_mask, which holds the mask information for 37 atom types.
The 37 atom types are:
{
'N': 0, 'CA': 1, 'C': 2, 'CB': 3, 'O': 4, 'CG': 5, 'CG1': 6, 'CG2': 7, 'OG': 8, 'OG1': 9, 'SG': 10,
'CD': 11, 'CD1': 12, 'CD2': 13, 'ND1': 14, 'ND2': 15, 'OD1': 16, 'OD2': 17, 'SD': 18, 'CE': 19, 'CE1': 20,
'CE2': 21, 'CE3': 22, 'NE': 23, 'NE1': 24, 'NE2': 25, 'OE1': 26, 'OE2': 27, 'CH2': 28, 'NH1': 29, 'NH2': 30,
'OH': 31, 'CZ': 32, 'CZ2': 33, 'CZ3': 34, 'NZ': 35, 'OXT': 36}
The backbone_mask feature only considers the mask information of three atom types: N, CA, and C. Note:
- It is generally the same as pseudo_beta_mask, because a residue is usually either present or absent as a whole.
# openfold/utils/all_atom_multimer.py
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    a = rc.atom_order["N"]
    b = rc.atom_order["CA"]
    c = rc.atom_order["C"]
    rigid_mask = mask[..., a] * mask[..., b] * mask[..., c]
    rigid = make_transform_from_reference(
        a_xyz=positions[..., a],
        b_xyz=positions[..., b],
        c_xyz=positions[..., c],
    )
    return rigid, rigid_mask
pseudo_beta_mask and backbone_mask have the same shape, [1, 1102].
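The relationship between the two masks can be checked on a few toy residues. The sketch below uses hypothetical helper names and a truncated four-atom mask (N, CA, C, CB) instead of all 37 atom slots. It shows that the masks agree for fully present or fully absent residues and can differ when only the CB atom is missing.

```python
N, CA, C, CB = 0, 1, 2, 3  # first four indices of the 37-atom order


def backbone_mask(atom_mask):
    """1 only if N, CA and C are all present (product of the three masks)."""
    return atom_mask[N] * atom_mask[CA] * atom_mask[C]


def pseudo_beta_mask(is_gly, atom_mask):
    """Mask of CA for glycine, of CB for every other residue type."""
    return atom_mask[CA] if is_gly else atom_mask[CB]


present = [1.0, 1.0, 1.0, 1.0]     # fully resolved residue
absent = [0.0, 0.0, 0.0, 0.0]      # residue missing from the template
missing_cb = [1.0, 1.0, 1.0, 0.0]  # backbone resolved, side chain not

results = [
    (backbone_mask(m), pseudo_beta_mask(False, m))
    for m in (present, absent, missing_cb)
]
```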
4. multichain_mask_2d feature
This one is simple: it is the intra-chain mask, which is true where two residues belong to the same chain. Source code:
# openfold/model/model.py
multichain_mask_2d = (
    asym_id[..., None] == asym_id[..., None, :]
)  # [N_res, N_res]
multichain_mask_2d has shape [1102, 1102].
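A toy version of the same comparison makes the block structure visible. The function name same_chain_mask and the three-residue asym_id below are made up for illustration. Residues 0 and 1 belong to chain 1 and residue 2 to chain 2.

```python
def same_chain_mask(asym_id):
    """mask[i][j] is True when residues i and j share the same chain id."""
    return [[a == b for b in asym_id] for a in asym_id]


mask = same_chain_mask([1, 1, 2])
# Rows: residue i; columns: residue j.
# [[True,  True,  False],
#  [True,  True,  False],
#  [False, False, True ]]
```

Template pair features are therefore only kept within a chain, and inter-chain entries are masked out.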
5. unit_vector feature
unit_vector is a Vec3Array object, computed from backbone frames whose rotations are Rot3Array objects. It is a unit vector that encodes relative direction. Source code:
# openfold/model/embedders.py
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
    atom_pos,
    single_template_feats["template_all_atom_mask"],
    single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
# openfold/utils/all_atom_multimer.py
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array, b_xyz: geometry.Vec3Array, c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide the
    coordinates in the non-standard way, the A atom will end up in the negative
    y-axis rather than in the positive y-axis. You need to take care of such
    cases in your code.

    Args:
        a_xyz: A Vec3Array.
        b_xyz: A Vec3Array.
        c_xyz: A Vec3Array.

    Returns:
        A Rigid3Array which, when applied to coordinates in a canonicalized
        reference frame, will give coordinates approximately equal
        the original coordinates (in the global frame).
    """
    rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, a_xyz - b_xyz)
    return geometry.Rigid3Array(rotation, b_xyz)

# openfold/utils/geometry/rotation_matrix.py
@classmethod
def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
        e0: Vector
        e1: Vector

    Returns:
        Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0.
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
Explanation: the unit vector of the displacement of the alpha carbon atom of all residues within the local frame of each residue. These local frames are computed in the same way as for the target structure.
Calculation logic:
unit_vector has three components, x, y, and z, which together give the shape [1, 1102, 1102, 3].
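The calculation can be traced for one pair of residues with plain tuples. The sketch below uses hypothetical helper names, hand-picked coordinates, and an assumed epsilon in normalize. It builds the local frame from N, CA, and C by the same Gram-Schmidt steps as from_two_vectors, expresses another residue's CA in that frame (the inverse transform), and normalizes the result.

```python
import math


def sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def scale(a, s):
    return tuple(x * s for x in a)


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0])


def normalize(a, eps=1e-6):
    # eps keeps the zero displacement (i == j) finite; its value is an assumption.
    return scale(a, 1.0 / max(math.sqrt(dot(a, a)), eps))


def backbone_frame(n, ca, c):
    """Axes e0 (along CA->C), e1 (towards N, orthogonal to e0), e2 = e0 x e1."""
    e0 = normalize(sub(c, ca))
    v = sub(n, ca)
    e1 = normalize(sub(v, scale(e0, dot(v, e0))))
    e2 = cross(e0, e1)
    return (e0, e1, e2), ca  # rotation axes and translation (the CA position)


def local_unit_vector(frame, point):
    """Direction of `point` as seen from the residue's local frame."""
    (e0, e1, e2), origin = frame
    d = sub(point, origin)
    return normalize((dot(e0, d), dot(e1, d), dot(e2, d)))


frame = backbone_frame(n=(1.0, 1.0, 0.0), ca=(0.0, 0.0, 0.0), c=(1.0, 0.0, 0.0))
uv = local_unit_vector(frame, (0.0, 3.0, 4.0))    # CA of another residue
self_uv = local_unit_vector(frame, (0.0, 0.0, 0.0))  # the residue itself
```

Repeating this for every ordered pair of residues gives one x, y, z triple per pair, which matches the three components described above.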