文本分类(一) | (3) FastText

项目Github地址

本篇博客主要介绍基于FastText的文本分类算法的原理及实现细节。

目录

1. 分类原理

2.实现细节


1. 分类原理

输入样本是一系列整数索引(x_1,...,x_N),对应词典中相应的词,通过embedding 得到每个词对应的词向量。对样本(文本)中每个词对应的词向量求平均,再通过一个全连接层进行分类即可。 

2.实现细节


class FastText(BasicModule): #继承自BasicModule 其中封装了保存加载模型的接口,BasicModule继承自nn.Module

    def __init__(self, vocab_size,opt): #opt是config类的实例 里面包括所有模型超参数的配置
        super(FastText, self).__init__()


        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, opt.embed_size) #词嵌入矩阵 每一行代表词典中一个词对应的词向量;
        # 词嵌入矩阵可以随机初始化连同分类任务一起训练,也可以用预训练词向量初始化(冻结或微调)

        self.content_fc = nn.Sequential( #可以使用多个全连接层或batchnorm、dropout等 可以把这些模块用Sequential包装成一个大模块
            nn.Linear(opt.embed_size, opt.linear_hidden_size),
            nn.BatchNorm1d(opt.linear_hidden_size),
            nn.ReLU(inplace=True),
            #可以再加一个隐层
            # nn.Linear(opt.linear_hidden_size,opt.linear_hidden_size),
            # nn.BatchNorm1d(opt.linear_hidden_size),
            # nn.ReLU(inplace=True),
            #输出层
            nn.Linear(opt.linear_hidden_size, opt.classes)
        )


    def forward(self, inputs):
        #inputs(batch_size,seq_len)
        embeddings = self.embedding(inputs) # (batch_size, seq_len, embed_size)

        #对seq_len维取平均
        content = torch.mean(embeddings,dim=1) #(batch_size,1,embed_size)

        out = self.content_fc(content.squeeze(1)) #先压缩seq_len维 (batch_size,embed_size) 然后作为全连接层的输入
        #输出 (batch_size,classes)

        return out
发布了365 篇原创文章 · 获赞 714 · 访问量 13万+

猜你喜欢

转载自blog.csdn.net/sdu_hao/article/details/103596315
今日推荐