Chinese Text Classification with fastText

Step 1: Install fastText

Reference: https://github.com/facebookresearch/fastText#building-fasttext

On Linux:
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .
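
Alternatively, if you only need the Python bindings, installing the prebuilt package from PyPI should also work:
$ pip install fasttext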

After installation, verify the import. Create a test.py file and write:
import fastText.FastText as fasttext  # the IDE may flag this line with a red underline

Save, exit, and run:
python3 test.py
If no error is reported, the installation succeeded.
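
Note that the module name differs across versions: recent pip builds expose the module as plain fasttext rather than fastText.FastText. A small version-tolerant sketch:

try:
    import fastText.FastText as fasttext  # older source builds
except ImportError:
    import fasttext  # newer package layout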

Step 2: Prepare the dataset

  • I use Tsinghua's news dataset (THUCNews) here; since the full dataset is large, only a subset is used.
  • Data link: netdisk share, extraction code: byoi (data.txt is the dataset, stopwords.txt is the stop-word list)
  • The downloaded data looks as follows (screenshot omitted): each line of data.txt is one labeled, pre-segmented sample of the form __label__<tag> , <space-separated tokens>.
  • The corresponding labels are listed below (since only a small part of the data is used, data.txt contains only some of them); a sketch of applying this mapping follows.
mapper_tag = {
    '财经': 'Finance',
    '彩票': 'Lottery',
    '房产': 'Property',
    '股票': 'Shares',
    '家居': 'Furnishing',
    '教育': 'Education',
    '科技': 'Technology',
    '社会': 'Sociology',
    '时尚': 'Fashion',
    '时政': 'Affairs',
    '体育': 'Sports',
    '星座': 'Constellation',
    '游戏': 'Game',
    '娱乐': 'Entertainment'
}
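
As a quick illustration, the mapping turns a Chinese category name into the English tag that fastText sees (to_fasttext_line is a hypothetical helper, assuming you were converting raw '<Chinese category>, <segmented text>' pairs yourself):

def to_fasttext_line(cn_label, segmented_text):
    # prepend the English tag in fastText's __label__ format
    return f"__label__{mapper_tag[cn_label]} , {segmented_text}"

print(to_fasttext_line('财经', '上海 天然橡胶 期价 周三 再创 年内 新高'))
# __label__Finance , 上海 天然橡胶 期价 周三 再创 年内 新高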

Step 3: Data preprocessing

  • Since data.txt has already been segmented and had its stop words removed, the only remaining work is splitting it into a training set and a test set.
  • Utility code for segmentation and stop-word removal (you do not need to run this part for the pipeline below):
import re
from types import MethodType, FunctionType

import jieba


def clean_txt(raw):
    # keep only digits, ASCII letters and CJK characters; everything else becomes a space
    fil = re.compile(r"[^0-9a-zA-Z\u4e00-\u9fa5]+")
    return fil.sub(' ', raw)

def seg(sentence, sw, apply=None):
    # optionally clean the sentence first, then segment with jieba and drop stop words
    if isinstance(apply, FunctionType) or isinstance(apply, MethodType):
        sentence = apply(sentence)
    return ' '.join([i for i in jieba.cut(sentence) if i.strip() and i not in sw])

def stop_words():
    with open('stopwords.txt', 'r', encoding='utf-8') as swf:
        return [line.strip() for line in swf]

# Processing a single sentence:
content = '上海天然橡胶期价周三再创年内新高,主力合约突破21000元/吨重要关口。'
res = seg(content.lower().replace('\n', ''), stop_words(), apply=clean_txt)
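
Printing res shows the cleaned, space-separated tokens; the exact output depends on the jieba version and the stop-word list, but it looks roughly like this:

print(res)
# e.g.: 上海 天然橡胶 期价 周三 再创 年内 新高 主力 合约 突破 21000 元 吨 重要 关口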
    
  • Split the data (I first convert the txt file to a csv file, which makes the later bookkeeping easier):
from random import shuffle

import pandas as pd

class _MD(object):
    """A minimal defaultdict-like helper: missing keys get a type-appropriate default."""
    mapper = {
        str: '',
        int: 0,
        list: list,
        dict: dict,
        set: set,
        bool: False,
        float: .0
    }

    def __init__(self, obj, default=None):
        self.dict = {}
        assert obj in self.mapper, \
            'unsupported type'
        self.t = obj
        if default is None:
            return
        assert isinstance(default, obj), \
            f'default ({default}) must be {obj}'
        self.v = default

    def __setitem__(self, key, value):
        self.dict[key] = value

    def __getitem__(self, item):
        if item not in self.dict and hasattr(self, 'v'):
            # an explicit default was given at construction time
            self.dict[item] = self.v
            return self.v
        elif item not in self.dict:
            # fall back to the type's natural default (callables are instantiated)
            if callable(self.mapper[self.t]):
                self.dict[item] = self.mapper[self.t]()
            else:
                self.dict[item] = self.mapper[self.t]
            return self.dict[item]
        return self.dict[item]


def defaultdict(obj, default=None):
    # intentionally shadows collections.defaultdict to allow a per-key default value
    return _MD(obj, default)


class TransformData(object):
    def to_csv(self, handler, output, index=False):
        dd = defaultdict(list)
        for line in handler:
            label, content = line.split(',', 1)
            # caution: str.strip('__label__') strips any of the characters '_labe'
            # from both ends, so e.g. 'Finance' becomes 'Financ' and 'Game' becomes
            # 'Gam' -- this is visible in the per-label results further below
            dd[label.strip('__label__').strip()].append(content.strip())

        df = pd.DataFrame()
        for key in dd.dict:
            col = pd.Series(dd[key], name=key)
            df = pd.concat([df, col], axis=1)
        return df.to_csv(output, index=index, encoding='utf-8')


def split_train_test(source, auth_data=False):
    # 80% for training by default; 98% when the data is known to be clean
    if not auth_data:
        train_proportion = 0.8
    else:
        train_proportion = 0.98

    basename = source.rsplit('.', 1)[0]
    train_file = basename + '_train.txt'
    test_file = basename + '_test.txt'

    handle = pd.read_csv(source, index_col=False, low_memory=False)
    train_data_set = []
    test_data_set = []
    for head in list(handle.head()):  # list(df.head()) yields the column (label) names
        train_num = int(len(handle[head].dropna()) * train_proportion)
        sub_list = [f'__label__{head} , {item.strip()}\n' for item in handle[head].dropna().tolist()]
        train_data_set.extend(sub_list[:train_num])
        test_data_set.extend(sub_list[train_num:])
    shuffle(train_data_set)
    shuffle(test_data_set)

    with open(train_file, 'w', encoding='utf-8') as trainf,\
        open(test_file, 'w', encoding='utf-8') as testf:
        for tds in train_data_set:
            trainf.write(tds)
        for i in test_data_set:
            testf.write(i)

    return train_file, test_file

# Convert to csv
td = TransformData()
handler = open('data.txt', encoding='utf-8')
td.to_csv(handler, 'data.csv')
handler.close()

# Split the csv file; this produces two files (data_train.txt and data_test.txt)
train_file, test_file = split_train_test('data.csv', auth_data=True)
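
A quick sanity check (a sketch, using the file names generated above) that the split looks right:

# peek at one training line and count the examples in each split
with open(train_file, encoding='utf-8') as f:
    print(f.readline().strip())   # e.g. __label__Technology , <tokens ...>
print(sum(1 for _ in open(train_file, encoding='utf-8')),
      sum(1 for _ in open(test_file, encoding='utf-8')))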

Step 4: Train the model

import os

import numpy as np
import fastText.FastText as fasttext

def train_model(ipt=None, opt=None, model='', dim=100, epoch=5, lr=0.1, loss='softmax'):
    np.set_printoptions(suppress=True)
    if os.path.isfile(model):
        # reuse a previously trained model if the file already exists
        classifier = fasttext.load_model(model)
    else:
        classifier = fasttext.train_supervised(ipt, label='__label__', dim=dim, epoch=epoch,
                                         lr=lr, wordNgrams=2, loss=loss)
        """
          Train a supervised model and return a model object.

          @param input:           path to the training data file
          @param lr:              learning rate
          @param dim:             dimensionality of the word vectors
          @param ws:              context window size (used by the cbow model)
          @param epoch:           number of epochs
          @param minCount:        word-frequency threshold; rarer words are filtered out at initialization
          @param minCountLabel:   label-frequency threshold; rarer labels are filtered out at initialization
          @param minn:            minimum character length when building subwords
          @param maxn:            maximum character length when building subwords
          @param neg:             number of negatives sampled
          @param wordNgrams:      maximum length of word n-grams
          @param loss:            loss function: softmax; ns: negative sampling; hs: hierarchical softmax
          @param bucket:          vocabulary expansion size; [A, B]: A holds vectors for words in the corpus, B for words outside it
          @param thread:          number of threads; each thread processes a slice of the input, and thread 0 reports the loss
          @param t:               sampling threshold
          @param lrUpdateRate:    rate of learning-rate updates
          @param label:           label prefix
          @param verbose:         verbosity level
          @param pretrainedVectors: path to a pretrained word-vector file; words found in the file are not initialized randomly
          @return model object
        """
        classifier.save_model(opt)
    return classifier

dim = 100
lr = 0.5
epoch = 5
model = f'data_dim{dim}_lr{lr}_iter{epoch}.model'

classifier = train_model(ipt='data_train.txt',
                         opt=model,
                         model=model,
                         dim=dim, epoch=epoch, lr=lr
                         )

result = classifier.test('data_test.txt')
print(result)

# Overall result (number of test samples, precision, recall):
(9885, 0.9740010116337886, 0.9740010116337886)
  • The score looks quite high. Since the whole test set is evaluated at once, fastText only reports aggregate numbers, and precision equals recall here. Below we compute per-label precision, recall and F1.
def cal_precision_and_recall(file='data_test.txt'):
    # relies on the seg/clean_txt helpers above and the trained classifier;
    # counters start at 1 (the custom defaultdict's explicit default) to avoid division by zero
    precision = defaultdict(int, 1)   # correct predictions per true label
    recall = defaultdict(int, 1)      # predictions made per predicted label
    total = defaultdict(int, 1)       # actual samples per true label
    with open(file) as f:
        for line in f:
            label, content = line.split(',', 1)
            total[label.strip().strip('__label__')] += 1
            # sw='' effectively disables stop-word filtering here
            labels2 = classifier.predict([seg(sentence=content.strip(), sw='', apply=clean_txt)])
            pre_label, sim = labels2[0][0][0], labels2[1][0][0]
            recall[pre_label.strip().strip('__label__')] += 1

            if label.strip() == pre_label.strip():
                precision[label.strip().strip('__label__')] += 1

    print('precision', precision.dict)
    print('recall', recall.dict)
    print('total', total.dict)
    for sub in precision.dict:
        # caution: by the usual definitions these two are swapped --
        # pre = TP / actual is conventionally recall, and rec = TP / predicted
        # is conventionally precision; F1 is unaffected since it is symmetric
        pre = precision[sub] / total[sub]
        rec = precision[sub] / recall[sub]
        F1 = (2 * pre * rec) / (pre + rec)
        print(f"{sub.strip('__label__')}  precision: {pre}  recall: {rec}  F1: {F1}")
  • Results ('Financ' and 'Gam' are artifacts of the str.strip('__label__') character-stripping noted above):
precision {'Technology': 983, 'Education': 972, 'Shares': 988, 'Affairs': 975, 'Entertainment': 991, 'Financ': 982, 'Furnishing': 975, 'Gam': 841, 'Sociology': 946, 'Sports': 978}
recall {'Technology': 992, 'Education': 1013, 'Shares': 1007, 'Affairs': 995, 'Entertainment': 1022, 'Financ': 1001, 'Furnishing': 997, 'Gam': 854, 'Sociology': 1025, 'Sports': 989}
total {'Technology': 1001, 'Education': 1001, 'Shares': 1001, 'Affairs': 1001, 'Entertainment': 1001, 'Financ': 1001, 'Furnishing': 1001, 'Gam': 876, 'Sociology': 1001, 'Sports': 1001, 'Property': 11}

Technology  precision: 0.9820179820179821  recall: 0.9909274193548387  F1: 0.9864525840441545
Education  precision: 0.971028971028971  recall: 0.9595261599210266  F1: 0.9652432969215492
Shares  precision: 0.987012987012987  recall: 0.9811320754716981  F1: 0.9840637450199202
Affairs  precision: 0.974025974025974  recall: 0.9798994974874372  F1: 0.9769539078156312
Entertainment  precision: 0.99000999000999  recall: 0.9696673189823874  F1: 0.9797330696984675
Financ  precision: 0.981018981018981  recall: 0.981018981018981  F1: 0.981018981018981
Furnishing  precision: 0.974025974025974  recall: 0.9779338014042126  F1: 0.975975975975976
Gam  precision: 0.9600456621004566  recall: 0.9847775175644028  F1: 0.9722543352601155
Sociology  precision: 0.945054945054945  recall: 0.9229268292682927  F1: 0.9338598223099703
Sports  precision: 0.977022977022977  recall: 0.9888776541961577  F1: 0.9829145728643216

The results are quite respectable; fastText is powerful indeed…
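
To spot-check the model on a single sentence (a sketch; the input must go through the same preprocessing as the training data, and the predicted label shown is only illustrative):

text = seg('上海天然橡胶期价周三再创年内新高'.lower(), stop_words(), apply=clean_txt)
labels, probs = classifier.predict(text, k=1)
print(labels[0], probs[0])   # e.g. __label__Shares and its probability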

Putting it all together:

def main(source):
    basename = source.rsplit('.', 1)[0]
    csv_file = basename + '.csv'

    td = TransformData()
    handler = open(source)
    td.to_csv(handler, csv_file)
    handler.close()

    train_file, test_file = split_train_test(csv_file, auth_data=True)

    dim = 100
    lr = 0.5
    epoch = 5
    model = f'data/data_dim{dim}_lr{lr}_iter{epoch}.model'

    classifier = train_model(ipt=train_file,
                             opt=model,
                             model=model,
                             dim=dim, epoch=epoch, lr=lr
                             )

    result = classifier.test(test_file)
    print(result)

    cal_precision_and_recall(test_file)


if __name__ == '__main__':
    main('data.txt')
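
Running the script end to end (assuming data.txt and stopwords.txt sit next to it, and that the data/ directory for the model file exists) produces data.csv, data_train.txt, data_test.txt, the saved model, and the overall plus per-label metrics shown above.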

Reposted from blog.csdn.net/weixin_43977375/article/details/90200837