Uni-Mol微调(finetune)过程

# 是的,下面的代码可以看作是一种微调(finetune)过程。
# 该代码定义了一个分类头(ClassificationHead)用于对预训练模型进行微调,在微调过程中,使用新的训练数据对分类头进行训练,以实现新任务的分类目标。
# 具体地,register_classification_head方法用于注册一个分类头到预训练模型中,而build_model方法则是用于构建完整的微调模型。
# 在构建微调模型时,build_model 先通过 models.build_model 构建模型,再调用 register_classification_head 为其添加分类头,使模型能够在编码器输出的特征之上完成新任务的分类预测(预训练权重的加载不在这段代码中体现)。

finetune使用的模块

class ClassificationHead(nn.Module):
    """Two-layer classification head applied to the first token's features.

    For every sequence in the batch the head reads the feature vector at
    position 0 and maps it through
    dropout -> linear -> activation -> dropout -> linear.
    """

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        """Create the layers of the head.

        Args:
            input_dim: Size of the incoming feature vectors.
            inner_dim: Output width of the first (hidden) projection.
            num_classes: Output width of the final projection.
            activation_fn: Activation identifier, resolved through
                ``utils.get_activation_fn``.
            pooler_dropout: Dropout probability; the same dropout module is
                applied before each of the two linear layers.
        """
        super().__init__()
        # Attribute names are part of the module's state-dict layout and are
        # read by code that inspects an existing head, so they stay as-is.
        self.dense = nn.Linear(in_features=input_dim, out_features=inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(in_features=inner_dim, out_features=num_classes)

    def forward(self, features, **kwargs):
        """Compute class scores from the position-0 features.

        Args:
            features: Tensor indexed as ``features[:, 0, :]``; presumably
                laid out as (batch, sequence, feature) -- confirm with caller.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The output of the final projection for each batch entry.
        """
        # Position 0 acts as the pooled representation (the <s> token,
        # equivalent to [CLS]).
        first_token = features[:, 0, :]
        hidden = self.activation_fn(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))

 在预训练模型上注册分类头并构建微调模型


def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
):
    """Create a classification head and store it under ``name``.

    If a head with the same name already exists it is replaced. A warning is
    logged when the replacement differs from the existing head in its number
    of classes or in its hidden width.

    Args:
        name: Key under which the head is stored in
            ``self.classification_heads``.
        num_classes: Output size of the new head.
        inner_dim: Hidden width of the new head. A falsy value (the default
            ``None``) selects ``self.args.encoder_embed_dim``.
        **kwargs: Accepted for call compatibility and ignored.
    """
    # Resolve the default once so the re-registration check, the log message
    # and the constructed head all use the effective hidden width. Comparing
    # the raw ``None`` default against the previous head's width would warn
    # even when the new head is identical to the old one.
    effective_inner_dim = inner_dim or self.args.encoder_embed_dim

    if name in self.classification_heads:
        existing_head = self.classification_heads[name]
        prev_num_classes = existing_head.out_proj.out_features
        prev_inner_dim = existing_head.dense.out_features
        if (
            num_classes != prev_num_classes
            or effective_inner_dim != prev_inner_dim
        ):
            logger.warning(
                're-registering head "{}" with num_classes {} (prev: {}) '
                "and inner_dim {} (prev: {})".format(
                    name,
                    num_classes,
                    prev_num_classes,
                    effective_inner_dim,
                    prev_inner_dim,
                )
            )

    self.classification_heads[name] = ClassificationHead(
        input_dim=self.args.encoder_embed_dim,
        inner_dim=effective_inner_dim,
        num_classes=num_classes,
        activation_fn=self.args.pooler_activation_fn,
        pooler_dropout=self.args.pooler_dropout,
    )


def build_model(self, args):
    """Build the model for ``args`` and attach a classification head to it.

    The model is created by ``unicore.models.build_model``. The head's name
    and class count are read from ``self.args`` (not from the ``args``
    parameter).

    Args:
        args: Configuration forwarded to ``models.build_model``.

    Returns:
        The built model, after ``register_classification_head`` was called
        on it.
    """
    # Imported locally, as in the original code.
    from unicore import models as unicore_models

    finetune_model = unicore_models.build_model(args, self)
    attach_head = finetune_model.register_classification_head
    attach_head(
        self.args.classification_head_name,
        num_classes=self.args.num_classes,
    )
    return finetune_model
 
 

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/130201600
今日推荐