python3使用fasttext进行中文文本分类

本文完成在python3下,利用fasttext对中文文本进行分类。期间使用了jieba分词。

数据集

本文使用skdjfla在github上的头条新闻数据集.包括新闻数据382688条(只包含标题),共15个类别。

数据预处理

fastext包进行文本分类类似于sklearn,已经进行完整封装,只需要生成指定格式的文本文件即可以开始训练。文件格式如下(__labe__为类别标记)

2018 年 养羊 怎么样 ? __label__0
中国 第一所 私立 研究型 大学 成立 , 2023 年招 本科生 __label__3

在这里简单使用jieba进行分词,并对标签进行数值化(其实应该是没有必要的)

import jieba
from typing import List, Tuple, Dict

def load_dataset(filepath: str = 'data/头条分类数据.txt', sample: bool or int = False) -> Tuple:
	"""读取数据集"""
    texts, labels = [], []
    with open(filepath, 'r') as f:
        for i, line in enumerate(f):
            if sample and i == sample:
                break
            _, _, label, text, _ = line.split('_!_')
            text = ' '.join(jieba.cut(text))
            texts.append(text)
            labels.append(label)
    return texts, labels

# 1.处理数据
texts, labels = utils.load_dataset()
random.seed(0)
random.shuffle(texts)
random.seed(0)
random.shuffle(labels)
texts_train, texts_test, labels_train, labels_test = train_test_split(texts, labels, test_size=0.05, stratify=labels,
                                                                      random_state=0)
label_encoder = preprocessing.LabelEncoder()
labels_train = label_encoder.fit_transform(labels_train)
labels_test = label_encoder.transform(labels_test)
#写入到文本数据中
with open('data/fasttext.train.txt', 'w') as f:
    for i in range(len(texts_train)):
        f.write('%s __label__%d\n' % (texts_train[i], labels_train[i]))
with open('data/fasttext.test.txt', 'w') as f:
    for i in range(len(texts_test)):
        f.write('%s __label__%d\n' % (texts_test[i],labels_test[i]))

模型训练和测试

模型训练直接使用train_supervised函数即可,函数的测试使用model.test即可,但是只能获取指定指标,因此使用model.predict函数获取测试集文本的类别,然后使用sklearn计算F1值。注意,这里model.predict函数默认返回的是最大可能的类别和其概率,也可以返回前topK的类别和概率。

# 2.训练模型
model = fasttext.train_supervised('data/fasttext.train.txt',epoch=10)
print(model.words)
print(model.labels)
model.save_model("data/model_filename.bin")

# 验证模型
model = fasttext.load_model("data/model_filename.bin")
texts_test, labels_test = [], []
with open('data/fasttext.test.txt', 'r') as f:
    for line in f:
        *text, label = line.strip().split(' ')
        text = ' '.join(text)
        texts_test.append(text)
        labels_test.append(label)
        
label_encoder = preprocessing.LabelEncoder()
labels_test = label_encoder.fit_transform(labels_test)
predits = list(zip(*(model.predict(texts_test)[0])))[0]
predits = label_encoder.transform(predits)

score = metrics.f1_score(predits, labels_test, average='weighted')
print('weighted f1-score : %.03f' % score)

完整代码见github

猜你喜欢

转载自blog.csdn.net/lovoslbdy/article/details/104797667
今日推荐