PSP - 蛋白质复合物结构预测 Template Pair 特征 Mask 可视化

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134333419

在蛋白质复合物结构预测中,TemplatePairEmbedderMultimer 层负责构建 Template Pair 特征,其主要逻辑如下:

  • 将 template_dgram、pseudo_beta_mask_2d、aatype_one_hot、backbone_mask_2d、unit_vector(x/y/z) 特征,以及经过 LayerNorm 的 query_embedding,分别通过各自的 linear 层映射后累加到一起。
  • 其中,pseudo_beta_mask_2d 与 backbone_mask_2d 都会先乘以 multichain_mask_2d 进行固定掩码(选择单链区域),再分别用于掩码 template_dgram 与 unit_vector(x/y/z)。
  • 输出维度为 [1, 1102, 1102, 64],其中 64 是 linear 层的输出维度 c_out。

源码如下:

def forward(
    self,
    template_dgram: torch.Tensor,
    aatype_one_hot: torch.Tensor,
    query_embedding: torch.Tensor,
    pseudo_beta_mask: torch.Tensor,
    backbone_mask: torch.Tensor,
    multichain_mask_2d: torch.Tensor,
    unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
    """Build the template pair activation as a sum of per-feature projections.

    Each template feature is projected by its own linear layer, and the
    results are accumulated into one pair activation. The 2D pseudo-beta and
    backbone masks are built as outer products of the 1D masks. Both are
    then multiplied by ``multichain_mask_2d`` before they are used to mask
    the distogram and the unit-vector components.

    Args:
        template_dgram: Template distogram. It is multiplied by the 2D
            pseudo-beta mask, with a trailing axis broadcast over its last
            dimension. Assumes [..., N_res, N_res, n_bins]; TODO confirm.
        aatype_one_hot: One-hot residue types, cast to the distogram dtype.
            Assumes [..., N_res, n_types]; TODO confirm.
        query_embedding: Query pair embedding. It passes through
            ``query_embedding_layer_norm`` and ``query_embedding_linear``
            before being added.
        pseudo_beta_mask: Per-residue mask. Its outer product with itself
            forms the 2D pseudo-beta mask.
        backbone_mask: Per-residue mask. Its outer product with itself forms
            the 2D backbone mask.
        multichain_mask_2d: 2D mask multiplied into both 2D masks above.
        unit_vector: Three coordinate components, unpacked as x, y, z. Each
            is multiplied by the 2D backbone mask.

    Returns:
        The summed pair activation.
    """
    act = 0.0

    # Outer product of the 1D mask with itself -> 2D pair mask.
    pseudo_beta_mask_2d = (
        pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
    )

    # Restrict the pair mask with multichain_mask_2d, then apply it to the
    # distogram. The trailing None broadcasts the mask over the last axis.
    pseudo_beta_mask_2d = pseudo_beta_mask_2d * multichain_mask_2d
    template_dgram = template_dgram * pseudo_beta_mask_2d[..., None]

    act += self.dgram_linear(template_dgram)
    # The (restricted) mask itself is also embedded as a one-channel feature.
    act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])

    aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
    # Two broadcast views of the residue types. Assuming aatype_one_hot is
    # [..., N_res, n_types]:
    #   - [..., None, :, :] varies along the second residue axis of the pair.
    #   - [..., None, :] varies along the first residue axis of the pair.
    act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
    act += self.aatype_linear_2(aatype_one_hot[..., None, :])

    # Same outer-product + multichain restriction for the backbone mask.
    backbone_mask_2d = backbone_mask[..., None] * backbone_mask[..., None, :]
    backbone_mask_2d = backbone_mask_2d * multichain_mask_2d

    # Mask each unit-vector component; each enters as a one-channel feature.
    x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
    act += self.x_linear(x[..., None])
    act += self.y_linear(y[..., None])
    act += self.z_linear(z[..., None])

    act += self.backbone_mask_linear(backbone_mask_2d[..., None])

    # Normalize the query pair embedding before projecting and adding it.
    query_embedding = self.query_embedding_layer_norm(query_embedding)
    act += self.query_embedding_linear(query_embedding)

    return act

template_dgram 特征:

template_dgram

template_dgram 特征与 multichain_mask_2d 掩码相乘后:

template_dgram mask

backbone_mask_2d 特征:

backbone_mask_2d

backbone_mask_2d 特征与 multichain_mask_2d 掩码相乘后:

backbone_mask_2d mask

写入特征(将以下代码插入 forward 函数末尾、return act 之前,以便访问其中的局部变量),即:

# Debug dump: paste at the end of forward() (before `return act`); it reads
# forward()'s local variables and writes them to a pickle file.
import pickle

# At this point pseudo_beta_mask_2d / backbone_mask_2d have already been
# multiplied by multichain_mask_2d. The pre-mask ("prev") versions are
# therefore rebuilt from the 1D masks, which are still unmodified, instead
# of saving the masked tensor under both names.
pseudo_beta_mask_2d_prev = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
backbone_mask_2d_prev = backbone_mask[..., None] * backbone_mask[..., None, :]

# detach() so .numpy() also works if any tensor is attached to autograd.
tmp_dict = {
    "pseudo_beta_mask_2d_prev": pseudo_beta_mask_2d_prev.detach().cpu().numpy(),
    "pseudo_beta_mask_2d_post": pseudo_beta_mask_2d.detach().cpu().numpy(),
    "template_dgram_post": template_dgram.detach().cpu().numpy(),
    "backbone_mask_2d_prev": backbone_mask_2d_prev.detach().cpu().numpy(),
    "backbone_mask_2d_post": backbone_mask_2d.detach().cpu().numpy(),
}

with open("template_pair_embedder_multimer.pkl", "wb") as f:
    pickle.dump(tmp_dict, f)
logger.info("[CL] saved template_pair_embedder_multimer!")

读取特征,即:

def load_tensor_dict(input_path):
    """
    Load a pickled feature dict from ``input_path`` and print its keys.

    The keys depend on how the file was written. The template-pair dump in
    this article writes:
    ['pseudo_beta_mask_2d_prev', 'pseudo_beta_mask_2d_post',
     'template_dgram_post', 'backbone_mask_2d_prev', 'backbone_mask_2d_post']

    Warning: pickle.load can execute arbitrary code. Only load files you
    produced yourself.

    Args:
        input_path: path to the pickle file.

    Returns:
        The unpickled object (expected to be a dict).
    """
    import pickle
    with open(input_path, "rb") as f:
        obj = pickle.load(f)
    # Single-line replacement field: multi-line ones are a SyntaxError
    # before Python 3.12.
    print(f"[Info] feat_dict: {obj.keys()}")
    return obj
  
def process_template_pair_embedder_multimer_dict(feat_dict, output_dir):
    """Render each dumped template-pair feature to ``<output_dir>/<key>.png``.

    Args:
        feat_dict: dict holding the five keys listed below (KeyError if any
            is missing).
        output_dir: directory for the PNG files.
    """
    print(f"[Info] feat_dict.keys: {feat_dict.keys()}")
    # The output filename is derived from the key, so every feature gets its
    # own image. A hand-typed filename previously let the *_post mask
    # overwrite the *_prev image.
    plots = [
        ("pseudo_beta_mask_2d_prev", draw_tensor_2d),
        ("pseudo_beta_mask_2d_post", draw_tensor_2d),
        ("template_dgram_post", draw_template_dgram),
        ("backbone_mask_2d_prev", draw_tensor_2d),
        ("backbone_mask_2d_post", draw_tensor_2d),
    ]
    for key, draw in plots:
        draw(feat_dict[key], os.path.join(output_dir, f"{key}.png"))
    
def draw_tensor_2d(feat, output_path):
    """Plot an array as a heatmap with a colorbar, save it as PNG, then show it.

    Singleton axes are squeezed first; e.g. a [1, 1102, 1102] pair mask
    becomes [1102, 1102]. The result is expected to be 2D for ``imshow``.

    Args:
        feat: array-like feature to plot.
        output_path: destination PNG path.
    """
    feat = np.squeeze(feat)
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    im = ax.imshow(feat)
    fig.colorbar(im, ax=ax)
    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(output_path, bbox_inches='tight', format='png')
    plt.show()
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
    
def draw_template_dgram(feat, output_path, ncols=7):
    """Plot every distogram bin of a template pair feature in a heatmap grid.

    Args:
        feat: array with singleton leading axes and the bin axis last, e.g.
            template_dgram of shape [1, 1102, 1102, 39]; squeezed first.
        output_path: destination PNG path.
        ncols: number of subplot columns. The default of 7 reproduces the
            original 6 x 7 layout (24 x 15 inches) for 39 bins.
    """
    feat = np.squeeze(feat)
    print(f"[Info] feat: {feat.shape}")
    n_bins = feat.shape[-1]
    # Derive the grid from the actual bin count instead of assuming 39, so
    # fewer bins no longer raise IndexError and extra bins are not dropped.
    nrows = max(1, (n_bins + ncols - 1) // ncols)
    # 2.5 inches per row keeps the original 15-inch height for 6 rows.
    fig, axes = plt.subplots(nrows, ncols, figsize=(24, 2.5 * nrows), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        if i < n_bins:
            im = ax.imshow(feat[..., i], interpolation='none')
            fig.colorbar(im, ax=ax)
        else:
            # Hide the unused cells in the last row.
            ax.set_axis_off()

    fig.savefig(output_path, bbox_inches='tight', format='png')
    plt.show()
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)

猜你喜欢

转载自blog.csdn.net/u012515223/article/details/134333419
psp
今日推荐