lda代码 自己的文本

版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/qq_32768743/article/details/89487546
# coding=utf-8
import os
import sys
import numpy as np
import matplotlib
import scipy
import matplotlib.pyplot as plt
from sklearn import feature_extraction
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import HashingVectorizer

if __name__ == "__main__":

    # 存储读取语料 一行预料为一个文档
    corpus = []
    for line in open('test.txt', 'r').readlines():
        # print line
        corpus.append(line.strip())
    # print corpus

    # 将文本中的词语转换为词频矩阵 矩阵元素a[i][j] 表示j词在i类文本下的词频
    vectorizer = CountVectorizer()
    print
    vectorizer

    X = vectorizer.fit_transform(corpus)
    analyze = vectorizer.build_analyzer()
    weight = X.toarray()

    print
    len(weight)
    print(weight[:5, :5])

    # LDA算法
    print
    'LDA:'
    import numpy as np
    import lda
    import lda.datasets

    model = lda.LDA(n_topics=2, n_iter=500, random_state=1)
    model.fit(np.asarray(weight))  # model.fit_transform(X) is also available
    topic_word = model.topic_word_  # model.components_ also works

    # 文档-主题(Document-Topic)分布
    doc_topic = model.doc_topic_
    print("type(doc_topic): {}".format(type(doc_topic)))
    print("shape: {}".format(doc_topic.shape))

    # 输出前10篇文章最可能的Topic
    label = []
    for n in range(10):
        topic_most_pr = doc_topic[n].argmax()
        label.append(topic_most_pr)
        print("doc: {} topic: {}".format(n, topic_most_pr))

    # 计算文档主题分布图
    import matplotlib.pyplot as plt

    f, ax = plt.subplots(6, 1, figsize=(8, 8), sharex=True)
    for i, k in enumerate([0, 1, 2, 3, 8, 9]):
        ax[i].stem(doc_topic[k, :], linefmt='r-',
                   markerfmt='ro', basefmt='w-')
        ax[i].set_xlim(-1, 2)  # x坐标下标
        ax[i].set_ylim(0, 1.2)  # y坐标下标
        ax[i].set_ylabel("Prob")
        ax[i].set_title("Document {}".format(k))
    ax[5].set_xlabel("Topic")
    plt.tight_layout()
    plt.show()

文本为

新春 备 年货 新年 联欢晚会
新春 节目单 春节 联欢晚会 红火
大盘 下跌 股市 散户
下跌 股市 赚钱
金猴 新春 红火 新年
新车 新年 年货 新春
股市 反弹 下跌
股市 散户 赚钱
新年 春节 联欢晚会
大盘 下跌 散户

在这里插入图片描述
参考:
https://blog.csdn.net/Eastmount/article/details/50891162

猜你喜欢

转载自blog.csdn.net/qq_32768743/article/details/89487546
LDA