PSP - Protein Complex Structure Prediction: Template Pair logic and feature analysis

Welcome to follow my CSDN: https://spike.blog.csdn.net/
This article: https://spike.blog.csdn.net/article/details/134328447

In protein complex structure prediction, templates play an important role by providing prior information about the three-dimensional structure of the predicted result. In the multi-chain case, templates must be paired, i.e. Template Pair. The core function is template_pair_embedder; combined with the AlphaFold2 paper, this article analyzes its specific input and output features.

The core logic is template_pair_embedder(); its input and output feature dimensions are:

[CL] TemplateEmbedderMultimer - template_dgram: torch.Size([1, 1102, 1102, 39])
[CL] TemplateEmbedderMultimer - z: torch.Size([1102, 1102, 128])
[CL] TemplateEmbedderMultimer - pseudo_beta_mask: torch.Size([1, 1102])
[CL] TemplateEmbedderMultimer - backbone_mask: torch.Size([1, 1102])
[CL] TemplateEmbedderMultimer - multichain_mask_2d: torch.Size([1102, 1102])
[CL] TemplateEmbedderMultimer - unit_vector: torch.Size([1, 1102, 1102])
[CL] TemplateEmbedderMultimer - pair_act: torch.Size([1, 1102, 1102, 64])

The relevant function calls:

# openfold/model/embedders.py
pair_act = self.template_pair_embedder(
    template_dgram,
    aatype_one_hot,
    z,
    pseudo_beta_mask,
    backbone_mask,
    multichain_mask_2d,
    unit_vector,
)

t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)  # convert from the c_t dimension to the c_z dimension, updating z
template_embeds["template_pair_embedding"] = t

# openfold/model/model.py
template_embeds = self.template_embedder(
    template_feats,
    z,
    pair_mask.to(dtype=z.dtype),
    no_batch_dims,
    chunk_size=self.globals.chunk_size,
    multichain_mask_2d=multichain_mask_2d,
    use_fa=self.globals.use_fa,
)
z = z + template_embeds["template_pair_embedding"]  # line 13 in Alg.2
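
To make the aggregation step concrete, here is a minimal shape sketch (not OpenFold code): toy sizes, a plain torch.nn.Linear standing in for self.linear_t, and n_res reduced from the real 1102 to 8:

import torch

n_templ, n_res, c_t, c_z = 4, 8, 64, 128     # toy sizes; the real run has n_res = 1102

# per-template pair activations, as produced by template_pair_embedder
t = torch.randn(n_templ, n_res, n_res, c_t)

t = torch.sum(t, dim=-4) / n_templ           # average over the template dimension -> [n_res, n_res, c_t]
t = torch.nn.functional.relu(t)
linear_t = torch.nn.Linear(c_t, c_z)         # stand-in for self.linear_t (c_t -> c_z)
t = linear_t(t)                              # [n_res, n_res, c_z]

z = torch.randn(n_res, n_res, c_z)
z = z + t                                    # update the pair representation
print(t.shape, z.shape)                      # torch.Size([8, 8, 128]) torch.Size([8, 8, 128])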

Logic diagram:

[Figure: Template Pairing]


1. template_dgram feature

The template_dgram feature is computed by binning the distance between each pseudo-beta point and every other pseudo-beta point into 39 bins in total. The bin width of the template distogram is 1.25 Å, i.e. (50.75 - 3.25) / 38 = 1.25:

template_dgram = dgram_from_positions(
    template_positions,
    inf=self.config.inf,
    **self.config.distogram,
)

def dgram_from_positions(
    pos: torch.Tensor,
    min_bin: float = 3.25,
    max_bin: float = 50.75,
    no_bins: float = 39,
    inf: float = 1e8,
):
    dgram = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    return dgram

The 39 lower bin edges are spaced evenly from 3.25 Å to 50.75 Å, so the bin width is (50.75 - 3.25) / 38 = 1.25 Å.
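
As a quick sanity check, here is a minimal binning sketch in plain PyTorch (toy coordinates, not OpenFold code), mirroring dgram_from_positions above:

import torch

min_bin, max_bin, no_bins = 3.25, 50.75, 39
print((max_bin - min_bin) / (no_bins - 1))              # 1.25, the bin width

# toy pseudo-beta coordinates for 3 residues
pos = torch.tensor([[0.0, 0.0, 0.0],
                    [4.0, 0.0, 0.0],
                    [0.0, 10.0, 0.0]])
d2 = torch.sum((pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, no_bins) ** 2  # squared lower bin edges
upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
dgram = ((d2 > lower) * (d2 < upper)).float()           # one-hot over 39 bins
print(dgram.shape)                                      # torch.Size([3, 3, 39])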

The template_positions feature comes from the template pseudo-beta features, as below:

  • Input: template_pseudo_beta and template_pseudo_beta_mask
  • Returns: template_positions, i.e. the values of template_pseudo_beta
template_positions, pseudo_beta_mask = (
    single_template_feats["template_pseudo_beta"],
    single_template_feats["template_pseudo_beta_mask"],
)

The processing of pseudo_beta features comes from openfold/data/data_transforms_multimer.py, that is:

  • Inputs: template_aatype, template_all_atom_positions, template_all_atom_mask
  • Principle: select the coordinates of the CB atom (CA for glycine) together with the corresponding mask

The call chain in the source code is as follows:

# Input features; the model predicts the structure
# run_pretrained_openfold.py
processed_feature_dict, _ = feature_processor.process_features(
    feature_dict, is_multimer, mode="predict"
)
output_dict = predict_structure_single_dev(
    args,
    model_name,
    current_model,
    fasta_path,
    processed_feature_dict,
    config,
)
  
# openfold/data/feature_pipeline.py
processed_feature, label = np_example_to_features_multimer(
    np_example=raw_features,
    config=self.config,
    mode=mode,
)

# openfold/data/feature_pipeline.py
features, label = input_pipeline_multimer.process_tensors_from_config(
    tensor_dict,
    cfg.common,
    cfg[mode],
    cfg.data_module,
)

# openfold/data/input_pipeline_multimer.py
nonensembled = nonensembled_transform_fns(
    common_cfg,
    mode_cfg,
)
tensors = compose(nonensembled)(tensors)

# openfold/data/input_pipeline_multimer.py
operators.extend(
    [
        data_transforms_multimer.make_atom14_positions,
        data_transforms_multimer.atom37_to_frames,
        data_transforms_multimer.atom37_to_torsion_angles(""),
        data_transforms_multimer.make_pseudo_beta(""),
        data_transforms_multimer.get_backbone_frames,
        data_transforms_multimer.get_chi_angles,
    ]
)

# openfold/data/data_transforms_multimer.py
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ["", "template_"]
    (
        protein[prefix + "pseudo_beta"],
        protein[prefix + "pseudo_beta_mask"],
    ) = pseudo_beta_fn(
        protein["template_aatype" if prefix else "aatype"],
        protein[prefix + "all_atom_positions"],
        protein["template_all_atom_mask" if prefix else "all_atom_mask"],
    )
    return protein
  
# openfold/data/data_transforms_multimer.py
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features."""
    if aatype.shape[0] > 0:
        is_gly = torch.eq(aatype, rc.restype_order["G"])
        ca_idx = rc.atom_order["CA"]
        cb_idx = rc.atom_order["CB"]
        pseudo_beta = torch.where(
            torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
            all_atom_positions[..., ca_idx, :],
            all_atom_positions[..., cb_idx, :],
        )
    else:
        pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3)
    if all_atom_mask is not None:
        if aatype.shape[0] > 0:
            pseudo_beta_mask = torch.where(
                is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
            )
        else:
            pseudo_beta_mask = torch.zeros_like(aatype).float()
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta

template_pseudo_beta_mask: Mask indicating if the beta carbon (alpha carbon for glycine) atom has coordinates for the template at this residue.
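
A toy usage sketch of pseudo_beta_fn (assuming the function is importable from the module shown above and that rc is openfold.np.residue_constants; glycine falls back to CA because it has no CB):

import torch
from openfold.np import residue_constants as rc
from openfold.data.data_transforms_multimer import pseudo_beta_fn

aatype = torch.tensor([rc.restype_order["G"], rc.restype_order["A"]])   # Gly, Ala
all_atom_positions = torch.randn(2, 37, 3)
all_atom_mask = torch.ones(2, 37)

pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask)
print(pseudo_beta.shape, pseudo_beta_mask.shape)        # [2, 3], [2]

# glycine takes the CA coordinate, alanine takes the CB coordinate
assert torch.allclose(pseudo_beta[0], all_atom_positions[0, rc.atom_order["CA"]])
assert torch.allclose(pseudo_beta[1], all_atom_positions[1, rc.atom_order["CB"]])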

The template_dgram feature has shape [1, 1102, 1102, 39], i.e.:

[Figure: template_dgram]


2. z feature

The z feature is passed in directly as an input; it comes from protein["target_feat"], which is built from protein["aatype"] (the multimer target_feat does not use between_segment_residues), that is:

  • target_feat: torch.Size([1102, 21]); the 21 = 20 + 1 residue types include the unknown type X but exclude the gap "-".
  • target_feat ([1102, 21]) is projected by a linear layer into the c_z dimension, i.e. 128.
  • It is then expanded to the [L, L, c_z] dimension via an outer-sum operation; z is the Pair Representation, with shape [1102, 1102, 128].
# openfold/model/embedders.py
def forward(
    self,
    batch,
    z,
    padding_mask_2d,
    templ_dim,
    chunk_size,
    multichain_mask_2d,
    use_fa=False,
):
  
# openfold/model/model.py
template_embeds = self.template_embedder(
    template_feats,
    z,
    pair_mask.to(dtype=z.dtype),
    no_batch_dims,
    chunk_size=self.globals.chunk_size,
    multichain_mask_2d=multichain_mask_2d,
    use_fa=self.globals.use_fa,
)

# openfold/model/model.py
m, z = self.input_embedder(feats)

# openfold/model/embedders.py#InputEmbedderMultimer
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    # ...
    Returns:
        msa_emb:
            [*, N_clust, N_res, C_m] MSA embedding
        pair_emb:
            [*, N_res, N_res, C_z] pair embedding

    """
    tf = batch["target_feat"]
    msa = batch["msa_feat"]

    # [*, N_res, c_z]
    tf_emb_i = self.linear_tf_z_i(tf)
    tf_emb_j = self.linear_tf_z_j(tf)

    # [*, N_res, N_res, c_z]
    pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
    pair_emb = pair_emb + self.relpos(batch)  # compute relative position encoding

    # [*, N_clust, N_res, c_m]
    n_clust = msa.shape[-3]
    tf_m = (
        self.linear_tf_m(tf)
        .unsqueeze(-3)
        .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
    )
    msa_emb = self.linear_msa_m(msa) + tf_m

    return msa_emb, pair_emb
  
# openfold/data/data_transforms_multimer.py
def create_target_feat(batch):
    """Create the target features"""
    batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(
        torch.float32
    )
    return batch

# openfold/data/input_pipeline_multimer.py
operators.extend(
    [
        data_transforms_multimer.cast_to_64bit_ints,
        # todo: randomly_replace_msa_with_unknown may need to be confirmed and tried in training.
        # data_transforms_multimer.randomly_replace_msa_with_unknown(0.0),
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
        data_transforms_multimer.make_atom14_masks,
    ]
)
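
To make the outer-sum step concrete, here is a minimal shape sketch (not OpenFold code), with plain torch.nn.Linear layers standing in for linear_tf_z_i / linear_tf_z_j and a toy n_res:

import torch

n_res, c_z = 6, 128                          # toy size; the real run has n_res = 1102
aatype = torch.randint(0, 21, (n_res,))
tf = torch.nn.functional.one_hot(aatype, 21).float()    # target_feat, [n_res, 21]

linear_tf_z_i = torch.nn.Linear(21, c_z)
linear_tf_z_j = torch.nn.Linear(21, c_z)
tf_emb_i = linear_tf_z_i(tf)                 # [n_res, c_z]
tf_emb_j = linear_tf_z_j(tf)                 # [n_res, c_z]

# broadcasted outer sum: [n_res, 1, c_z] + [1, n_res, c_z] -> [n_res, n_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
print(pair_emb.shape)                        # torch.Size([6, 6, 128])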

Framework diagram of InputEmbedderMultimer:
[Figure: InputEmbedderMultimer]

Among them, the amino acid vocabulary is as follows; target_feat keeps the 21 residue types at indices 0-20 and excludes the gap "-":

ID_TO_HHBLITS_AA = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}

The z feature (visualized as mean, sum, and max) has shape [1102, 1102, 128], i.e.:

[Figure: z feature]


3. pseudo_beta_mask and backbone_mask features

For the pseudo_beta_mask feature, refer to the template_pseudo_beta source code in the template_dgram section above:

  • Note the mask information of the CA/CB atoms.
# openfold/data/data_transforms_multimer.py
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ["", "template_"]
    (
        protein[prefix + "pseudo_beta"],
        protein[prefix + "pseudo_beta_mask"],
    ) = pseudo_beta_fn(
        protein["template_aatype" if prefix else "aatype"],
        protein[prefix + "all_atom_positions"],
        protein["template_all_atom_mask" if prefix else "all_atom_mask"],
    )
    return protein
  
# openfold/data/data_transforms_multimer.py
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features."""
    if aatype.shape[0] > 0:
        is_gly = torch.eq(aatype, rc.restype_order["G"])
        ca_idx = rc.atom_order["CA"]
        cb_idx = rc.atom_order["CB"]
        pseudo_beta = torch.where(
            torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
            all_atom_positions[..., ca_idx, :],
            all_atom_positions[..., cb_idx, :],
        )
    else:
        pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3)
    if all_atom_mask is not None:
        if aatype.shape[0] > 0:
            pseudo_beta_mask = torch.where(
                is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
            )
        else:
            pseudo_beta_mask = torch.zeros_like(aatype).float()
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta

Among them, openfold/data/msa_pairing.py#merge_chain_features() merges the per-chain features and outputs the template features, that is:

[CL] template features, template_aatype : (4, 1102)
[CL] template features, template_all_atom_positions : (4, 1102, 37, 3)
[CL] template features, template_all_atom_mask : (4, 1102, 37)

Note template_all_atom_mask, that is, the mask information of 37 atoms.

Among them, the 37 atom types (atom_order) are:

{
    'N': 0, 'CA': 1, 'C': 2, 'CB': 3, 'O': 4, 'CG': 5, 'CG1': 6, 'CG2': 7, 'OG': 8, 'OG1': 9, 'SG': 10,
    'CD': 11, 'CD1': 12, 'CD2': 13, 'ND1': 14, 'ND2': 15, 'OD1': 16, 'OD2': 17, 'SD': 18, 'CE': 19, 'CE1': 20,
    'CE2': 21, 'CE3': 22, 'NE': 23, 'NE1': 24, 'NE2': 25, 'OE1': 26, 'OE2': 27, 'CH2': 28, 'NH1': 29, 'NH2': 30,
    'OH': 31, 'CZ': 32, 'CZ2': 33, 'CZ3': 34, 'NZ': 35, 'OXT': 36,
}

The backbone_mask feature only considers the mask information of three atom types: N, CA, and C. Refer to:

  • It is generally the same as pseudo_beta_mask, since a residue's atoms are usually either all present or all absent.
# openfold/utils/all_atom_multimer.py
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    a = rc.atom_order["N"]
    b = rc.atom_order["CA"]
    c = rc.atom_order["C"]

    rigid_mask = mask[..., a] * mask[..., b] * mask[..., c]

    rigid = make_transform_from_reference(
        a_xyz=positions[..., a],
        b_xyz=positions[..., b],
        c_xyz=positions[..., c],
    )

    return rigid, rigid_mask
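
A tiny sketch of the rigid_mask (i.e. backbone_mask) computation: a residue is kept only if its N, CA, and C atoms are all resolved (assuming rc is openfold.np.residue_constants):

import torch
from openfold.np import residue_constants as rc

mask = torch.ones(3, 37)
mask[1, rc.atom_order["CA"]] = 0.0           # residue 1 is missing its CA atom

a, b, c = rc.atom_order["N"], rc.atom_order["CA"], rc.atom_order["C"]
backbone_mask = mask[..., a] * mask[..., b] * mask[..., c]
print(backbone_mask)                         # tensor([1., 0., 1.])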

pseudo_beta_mask and backbone_mask are similar; both have shape [1, 1102], i.e.:

[Figure: pseudo_beta_mask and backbone_mask]


4. multichain_mask_2d feature

This one is simple: it is the intra-chain mask, i.e. True for residue pairs that belong to the same chain. Source code:

# openfold/model/model.py
multichain_mask_2d = (
    asym_id[..., None] == asym_id[..., None, :]
)  # [N_res, N_res]
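
A toy example of the same expression (the asym_id values here are hypothetical):

import torch

asym_id = torch.tensor([1, 1, 1, 2, 2])      # chain 1 has 3 residues, chain 2 has 2
multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]
print(multichain_mask_2d.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]])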

multichain_mask_2d has shape [1102, 1102], i.e.:

[Figure: multichain_mask_2d]


5. unit_vector feature

unit_vector is a Vec3Array: the unit (normalized) displacement vector of every residue's CA atom expressed in each residue's local backbone frame (a Rigid3Array). Source code:

# openfold/model/embedders.py
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
    atom_pos,
    single_template_feats["template_all_atom_mask"],
    single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()

# openfold/utils/all_atom_multimer.py
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    a = rc.atom_order["N"]
    b = rc.atom_order["CA"]
    c = rc.atom_order["C"]

    rigid_mask = mask[..., a] * mask[..., b] * mask[..., c]

    rigid = make_transform_from_reference(
        a_xyz=positions[..., a],
        b_xyz=positions[..., b],
        c_xyz=positions[..., c],
    )

    return rigid, rigid_mask

# openfold/utils/all_atom_multimer.py
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array, b_xyz: geometry.Vec3Array, c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide the
    coordinates in the non-standard way, the A atom will end up in the negative
    y-axis rather than in the positive y-axis. You need to take care of such
    cases in your code.

    Args:
        a_xyz: A Vec3Array.
        b_xyz: A Vec3Array.
        c_xyz: A Vec3Array.

    Returns:
        A Rigid3Array which, when applied to coordinates in a canonicalized
        reference frame, will give coordinates approximately equal
        the original coordinates (in the global frame).
    """
    rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, a_xyz - b_xyz)
    return geometry.Rigid3Array(rotation, b_xyz)

# openfold/utils/geometry/rotation_matrix.py
@classmethod
def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
        e0: Vector
        e1: Vector
    Returns:
        Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0.
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

Explanation: the unit vector of the displacement of the alpha carbon (CA) atom of all residues within the local frame of each residue. These local frames are computed in the same way as for the target structure.
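
A plain-torch sketch of this computation (not the OpenFold geometry classes; toy sizes). The frame construction mirrors Rot3Array.from_two_vectors with e0 = C - CA and e1 = N - CA, and the inverse rigid of residue i is applied to the CA of residue j:

import torch

def backbone_frames(n, ca, c):
    # Gram-Schmidt: e0 (C - CA) on +x, e1 (N - CA) projected into the xy plane
    e0 = torch.nn.functional.normalize(c - ca, dim=-1)
    e1 = n - ca
    e1 = torch.nn.functional.normalize(e1 - (e1 * e0).sum(-1, keepdim=True) * e0, dim=-1)
    e2 = torch.cross(e0, e1, dim=-1)
    rot = torch.stack([e0, e1, e2], dim=-1)  # [L, 3, 3], columns are the frame axes
    return rot, ca                           # rotation and translation (CA)

L = 4                                        # toy size; the real run has 1102 residues
n, ca, c = torch.randn(L, 3), torch.randn(L, 3), torch.randn(L, 3)
rot, trans = backbone_frames(n, ca, c)

# rigid_i^{-1}(ca_j) = R_i^T @ (ca_j - t_i), then normalize to a unit vector
rigid_vec = torch.einsum("iab,ija->ijb", rot, ca[None, :, :] - trans[:, None, :])
unit_vector = torch.nn.functional.normalize(rigid_vec, dim=-1)
print(unit_vector.shape)                     # torch.Size([4, 4, 3])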

Calculation logic:

[Figure: Rigid frame construction]

unit_vector has three components (x, y, z), giving shape [1, 1102, 1102, 3], i.e.:

[Figure: unit_vector]

