I. Introduction dataset
Source: Today's headlines client
Data format is as follows:
6551700932705387022_!_101_!_news_culture_!_京城最值得你来场文化之旅的博物馆_!_保利集团,马未都,中国科学技术馆,博物馆,新中国
6552368441838272771_!_101_!_news_culture_!_发酵床的垫料种类有哪些?哪种更好?_!_
6552407965343678723_!_101_!_news_culture_!_上联:黄山黄河黄皮肤黄土高原。怎么对下联?_!_
6552332417753940238_!_101_!_news_culture_!_林徽因什么理由拒绝了徐志摩而选择梁思成为终身伴侣?_!_
6552475601595269390_!_101_!_news_culture_!_黄杨木是什么树?_!_
Each piece of data behavior to _! _ Split fields, from front to back are the news ID, classification code (see below), the category name (see below), News string (containing only the title), News Keywords
Classification code and name:
100 民生 故事 news_story
101 文化 文化 news_culture
102 娱乐 娱乐 news_entertainment
103 体育 体育 news_sports
104 财经 财经 news_finance
106 房产 房产 news_house
107 汽车 汽车 news_car
108 教育 教育 news_edu
109 科技 科技 news_tech
110 军事 军事 news_military
112 旅游 旅游 news_travel
113 国际 国际 news_world
114 证券 股票 stock
115 农业 三农 news_agriculture
116 电竞 游戏 news_game
github address: https: //github.com/fate233/toutiao-text-classfication-dataset
Data given in resource classification results:
Test Loss: 0.57, Test Acc: 83.81%
precision recall f1-score support
news_story 0.66 0.75 0.70 848
news_culture 0.57 0.83 0.68 1531
news_entertainment 0.86 0.86 0.86 8078
news_sports 0.94 0.91 0.92 7338
news_finance 0.59 0.67 0.63 1594
news_house 0.84 0.89 0.87 1478
news_car 0.92 0.90 0.91 6481
news_edu 0.71 0.86 0.77 1425
news_tech 0.85 0.84 0.85 6944
news_military 0.90 0.78 0.84 6174
news_travel 0.58 0.76 0.66 1287
news_world 0.72 0.69 0.70 3823
stock 0.00 0.00 0.00 53
news_agriculture 0.80 0.88 0.84 1701
news_game 0.92 0.87 0.89 6244
avg / total 0.85 0.84 0.84 54999
Here we have to use deeplearning4j to implement a convolution structure to classify the data set, see if I can get better results.
Second, the network can be used convolution reason text processing
CNN is ideal for processing image data, previous article "deeplearning4j-- convolution neural network identification code" introduced on CNN verification code identification. CNN will use this blog for text classification, before the start of what we first talk about the convolution operation intuitive in nature to do something Yes. Convolution operation can be seen as essentially dot product of two vectors, the two vectors with the larger the dot product, and after relu MaxPooling, extracted with essentially the same direction most convolution kernel structure, this "structure" is actually a number of lines on the picture.
CNN then the text can be used to deal with it? The answer is yes, the text after each word represented by a vector, arranged in turn, becomes a two-dimensional map, as shown below, to see the direction of the red arrow (that is, the direction of the text), two sentences with a after the figure shows, the same elements will appear, it can be treated with CNN.
Third, the text processing convolution structure
So, how to design the CNN network structure? :( papers address the following figure: https: //arxiv.org/abs/1408.5882)
important point:
1, the direction of movement of the convolution kernel must direction sentences
2, wherein each of the convolution kernel extract as a row vector of N
3, MaxPooling manipulation of objects is that each Feature Map, i.e. selecting a maximum value from a vector of each of the N rows
4, the maximum value of all the selected pick up, after several Fully Connected layer, classifies
Fourth, data preprocessing and term vectors
1, segmentation tools: HanLP
2, the data format processed as follows :( category code _! _ Words, wherein, between words separated by a space, _! _ Of delimiter)
Data preprocessing code is as follows:
public static void main(String[] args) throws Exception {
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
new FileInputStream(new File("/toutiao_cat_data/toutiao_cat_data.txt")), "UTF-8"));
OutputStreamWriter writerStream = new OutputStreamWriter(
new FileOutputStream("/toutiao_cat_data/toutiao_data_type_word.txt"), "UTF-8");
BufferedWriter writer = new BufferedWriter(writerStream);
String line = null;
long startTime = System.currentTimeMillis();
while ((line = bufferedReader.readLine()) != null) {
String[] array = line.split("_!_");
StringBuilder stringBuilder = new StringBuilder();
for (Term term : HanLP.segment(array[3])) {
if (stringBuilder.length() > 0) {
stringBuilder.append(" ");
}
stringBuilder.append(term.word.trim());
}
writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
}
writer.flush();
writer.close();
System.out.println(System.currentTimeMillis() - startTime);
bufferedReader.close();
}
Fifth, the word vector representation
1、one-hot
Orthogonal vectors to represent each word, said this does not reflect the relationship between words and words, then two sentences, in order to reuse the same convolution kernels, it must appear exactly the same words can, in fact, we ask the model can extrapolate, even similar structure can also be extracted, then word2vec can solve this problem.
2、word2vec
word2vec can fully consider the relationship between words and words, similar words, there must be some dimensions by more recent. Then it considers the relationship between the statement of the word, there are two training word2vec, skipgram and cbow, we use the following cbow to train word vector, the result will persist down, you get toutiao.vec file, the next change re-load the file to get the word vector representation, as follows:
String filePath = new ClassPathResource("toutiao_data_word.txt").getFile().getAbsolutePath();
SentenceIterator iter = new BasicLineIterator(filePath);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
VocabCache<VocabWord> cache = new AbstractCache<>();
WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>().vectorLength(100)
.useAdaGrad(false).cache(cache).build();
log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW")
.minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8).iterate(iter)
.tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();
vec.fit();
WordVectorSerializer.writeWord2VecModel(vec, "/toutiao_cat_data/toutiao.vec");
Six, CNN network structure
CNN network structure is as follows:
Description:
1, cnn3, cnn4, cnn5, cnn6 convolution kernel of size (3, vectorSize), (4, vectorSize), (5, vectorSize), (6, vectorSize), steps 1, i.e. 3 respectively read, 4,5,6 word, feature extraction
2, cnn3-stride2, cnn4-stride2, cnn5-stride2, cnn6-stride2 size of the convolution kernel (3, vectorSize), (4, vectorSize), (5, vectorSize), (6, vectorSize), steps 2
3, the results of the two groups were combined convolution kernel convolution, and respectively merge1 merge2, are 4-dimensional tensor shape respectively (batchSize, depth1 + depth2 + depth3, height / 1,1), (batchSize, depth1 + depth2 + depth3, height / 2,1), in particular: convolution herein mode ConvolutionMode.Same
4、merge1、2分别经过MaxPooling,这里用的是GlobalPoolingLayer,和平台的Pooling层不同,这里会从指定维度中,取一个最大值,所以经过GlobalPoolingLayer之后,merge1、2分别变成2维张量,形状为(batchSize,depth1+depth2+depth3),那么GlobalPoolingLayer是如何求Max的呢?源码如下:
private INDArray activateHelperFullArray(INDArray inputArray, int[] poolDim) {
switch (poolingType) {
case MAX:
return inputArray.max(poolDim);
case AVG:
return inputArray.mean(poolDim);
case SUM:
return inputArray.sum(poolDim);
case PNORM:
//P norm: https://arxiv.org/pdf/1311.1780.pdf
//out = (1/N * sum( |in| ^ p) ) ^ (1/p)
int pnorm = layerConf().getPnorm();
INDArray abs = Transforms.abs(inputArray, true);
Transforms.pow(abs, pnorm, false);
INDArray pNorm = abs.sum(poolDim);
return Transforms.pow(pNorm, 1.0 / pnorm, false);
default:
throw new RuntimeException("Unknown or not supported pooling type: " + poolingType + " " + layerId());
}
}
5、两边GlobalPoolingLayer结果再接起来,丢给全连接网络,经过softmax分类器进行分类
6、fc层,用了0.5的dropout防止过拟合,在下面的代码中可以看到。
完整代码如下:
public class CnnSentenceClassificationTouTiao {
public static void main(String[] args) throws Exception {
List<String> trainLabelList = new ArrayList<>();// 训练集label
List<String> trainSentences = new ArrayList<>();// 训练集文本集合
List<String> testLabelList = new ArrayList<>();// 测试集label
List<String> testSentences = new ArrayList<>();//// 测试集文本集合
Map<String, List<String>> map = new HashMap<>();
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
new FileInputStream(new File("/toutiao_cat_data/toutiao_data_type_word.txt")), "UTF-8"));
String line = null;
int truncateReviewsToLength = 0;
Random random = new Random(123);
while ((line = bufferedReader.readLine()) != null) {
String[] array = line.split("_!_");
if (map.get(array[0]) == null) {
map.put(array[0], new ArrayList<String>());
}
map.get(array[0]).add(array[1]);// 将样本中所有数据,按照类别归类
int length = array[1].split(" ").length;
if (length > truncateReviewsToLength) {
truncateReviewsToLength = length;// 求样本中,句子的最大长度
}
}
bufferedReader.close();
for (Map.Entry<String, List<String>> entry : map.entrySet()) {
for (String sentence : entry.getValue()) {
if (random.nextInt() % 5 == 0) {// 每个类别抽取20%作为test集
testLabelList.add(entry.getKey());
testSentences.add(sentence);
} else {
trainLabelList.add(entry.getKey());
trainSentences.add(sentence);
}
}
}
int batchSize = 64;
int vectorSize = 100;
int nEpochs = 10;
int cnnLayerFeatureMaps = 50;
PoolingType globalPoolingType = PoolingType.MAX;
Random rng = new Random(12345);
Nd4j.getMemoryManager().setAutoGcWindow(5000);
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
.activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
.convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
.addLayer("cnn3",
new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn4",
new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn5",
new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn6",
new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn3-stride2",
new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn4-stride2",
new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn5-stride2",
new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn6-stride2",
new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
.addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
"merge1")
.addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
.addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
"merge2")
.addLayer("fc",
new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
"globalPool1", "globalPool2")
.addLayer("out",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(15).build(),
"fc")
.setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
.build();
ComputationGraph net = new ComputationGraph(config);
net.init();
System.out.println(net.summary());
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("/toutiao_cat_data/toutiao.vec");
System.out.println("Loading word vectors and creating DataSetIterators");
DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList,
trainSentences, rng);
DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList,
testSentences, rng);
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
net.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
// net.setListeners(new ScoreIterationListener(100),
// new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
net.fit(trainIter, nEpochs);
}
private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
List<String> lableList, List<String> sentences, Random rng) {
LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, lableList, rng);
return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
.minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
.build();
}
}
代码说明:
1、代码分两部分,第一部分是数据预处理,分出20%测试集、80%作为训练集
2、第二部分为网络的基本结构代码
网络参数详细如下:
===============================================================================================================================================
VertexName (VertexType) nIn,nOut TotalParams ParamsShape Vertex Inputs
===============================================================================================================================================
input (InputVertex) -,- - - -
cnn3 (ConvolutionLayer) 1,50 15050 W:{50,1,3,100}, b:{1,50} [input]
cnn4 (ConvolutionLayer) 1,50 20050 W:{50,1,4,100}, b:{1,50} [input]
cnn5 (ConvolutionLayer) 1,50 25050 W:{50,1,5,100}, b:{1,50} [input]
cnn6 (ConvolutionLayer) 1,50 30050 W:{50,1,6,100}, b:{1,50} [input]
cnn3-stride2 (ConvolutionLayer) 1,50 15050 W:{50,1,3,100}, b:{1,50} [input]
cnn4-stride2 (ConvolutionLayer) 1,50 20050 W:{50,1,4,100}, b:{1,50} [input]
cnn5-stride2 (ConvolutionLayer) 1,50 25050 W:{50,1,5,100}, b:{1,50} [input]
cnn6-stride2 (ConvolutionLayer) 1,50 30050 W:{50,1,6,100}, b:{1,50} [input]
merge1 (MergeVertex) -,- - - [cnn3, cnn4, cnn5, cnn6]
merge2 (MergeVertex) -,- - - [cnn3-stride2, cnn4-stride2, cnn5-stride2, cnn6-stride2]
globalPool1 (GlobalPoolingLayer) -,- 0 - [merge1]
globalPool2 (GlobalPoolingLayer) -,- 0 - [merge2]
fc-merge (MergeVertex) -,- - - [globalPool1, globalPool2]
fc (DenseLayer) 400,200 80200 W:{400,200}, b:{1,200} [fc-merge]
out (OutputLayer) 200,15 3015 W:{200,15}, b:{1,15} [fc]
-----------------------------------------------------------------------------------------------------------------------------------------------
Total Parameters: 263615
Trainable Parameters: 263615
Frozen Parameters: 0
===============================================================================================================================================
DL4J的UIServer界面如下,这里我给定的端口号为9001,打开web界面可以看到平均loss的详情,梯度更新的详情等
http://localhost:9001/train/overview
七、掩模
句子有长有短,CNN将如何处理呢?
处理的办法其实很暴力,将一个minibatch中的最长句子找到,new出最大长度的张量,多余值用掩模掩掉即可,废话不多说,直接上代码
if(sentencesAlongHeight){
featuresMask = Nd4j.create(currMinibatchSize, 1, maxLength, 1);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.slice(i).assign(1.0);
} else {
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0)).assign(1.0);
}
}
} else {
featuresMask = Nd4j.create(currMinibatchSize, 1, 1, maxLength);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.slice(i).assign(1.0);
} else {
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
}
}
}
这里为什么有个if呢?生成句子张量的时候,可以任意指定句子的方向,可以沿着矩阵中height的方向,也可以是width的方向,方向不同,填掩模的那一维也就不同。
八、结果
运行了10个Epoch结果如下:
========================Evaluation Metrics========================
# of classes: 15
Accuracy: 0.8420
Precision: 0.8362 (1 class excluded from average)
Recall: 0.7783
F1 Score: 0.8346 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 15 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [12]
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
----------------------------------------------------------------------------
973 35 114 2 9 8 11 19 14 6 19 11 0 22 13 | 0 = 0
17 4636 250 37 51 16 14 151 47 29 232 36 0 82 44 | 1 = 1
103 176 6980 108 16 8 31 62 83 41 53 77 0 36 163 | 2 = 2
9 78 244 6692 37 9 52 59 33 27 57 54 0 10 96 | 3 = 3
7 52 36 31 4072 96 101 107 581 20 64 108 0 135 37 | 4 = 4
12 18 22 8 150 3061 27 36 53 2 100 16 0 56 2 | 5 = 5
17 38 71 26 94 13 6443 43 174 31 121 39 0 32 34 | 6 = 6
17 157 93 49 62 20 34 4793 85 14 58 36 0 49 31 | 7 = 7
1 45 71 21 436 30 195 138 7018 48 54 49 0 45 148 | 8 = 8
24 74 84 47 24 1 57 50 68 3963 45 431 0 9 65 | 9 = 9
9 165 90 21 40 37 61 40 42 21 3428 111 0 78 30 | 10 = 10
47 78 173 52 114 20 48 67 93 320 140 4097 0 48 29 | 11 = 11
0 0 0 0 60 0 1 0 5 0 0 0 0 0 0 | 12 = 12
35 105 31 6 139 37 34 61 79 11 153 35 0 3187 12 | 13 = 13
14 36 210 128 31 2 19 20 164 44 38 15 0 19 5183 | 14 = 14
平均准确率0.8420,比原资源中给定的结果略好,F1 score要略差一点,混淆矩阵中,有一个类别,无法被预测到,是因为样本中改类别数据量本身很少,难以抓到共性特征。这里参数如果精心调节一番,迭代更多次数,理论上会有更好的表现。
九、后记
读Deeplearning4j是一种享受,优雅的架构,清晰的逻辑,多种设计模式,扩展性强,将有后续博客,对dl4j源码进行剖析。
快乐源于分享。
此博客乃作者原创, 转载请注明出处