Deep learning practice 26-(PyTorch) Building TextCNN for multi-label text classification

Hello everyone, I am Weixue AI. Today I will introduce deep learning practice 26: building TextCNN in PyTorch for multi-label text classification. TextCNN is a deep learning model for text classification based on convolutional neural networks (CNNs). Its main idea is to use convolution operations to extract useful features from text and then use those features to predict the text's categories.

TextCNN treats text as one-dimensional sequence data: each word is embedded into a vector space, producing a sequence of word vectors. TextCNN then extracts key features through stacked convolutional and pooling layers and converts them into a fixed-size vector, which is finally fed to a fully connected layer for classification. The advantage of TextCNN is that it captures local features (such as key n-grams) in text very effectively, which improves classification accuracy. In addition, TextCNN trains relatively quickly and scales well.
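To make the data flow concrete, here is a sketch of the tensor shapes as one batch passes through the model, assuming the hyperparameters used later in this article (batch size 32, sequence length 20, embedding dimension 100, filter sizes [2, 3, 4], 100 filters each, 4 labels):

# Shape walkthrough for one batch (illustrative, matching the code below):
# token ids          : (32, 20)
# after embedding    : (32, 20, 100)    one 100-d vector per token
# after unsqueeze(1) : (32, 1, 20, 100) add a channel axis for Conv2d
# after conv, fs=3   : (32, 100, 18, 1) 100 feature maps of length 20-3+1
# after squeeze(3)   : (32, 100, 18)
# after max-pooling  : (32, 100)        keep the max of each feature map
# after concatenation: (32, 300)        3 filter sizes x 100 filters
# after the fc layer : (32, 4)          one score per label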

TextCNN for multi-label classification

1. Library package import

import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from collections import Counter

2. Define parameters

max_length = 20
batch_size = 32
embedding_dim = 100
num_filters = 100
filter_sizes = [2, 3, 4]
num_classes = 4
learning_rate = 0.001
num_epochs = 2000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. Dataset processing functions

def load_data(file_path):
    # The CSV is GBK-encoded with "text" and "label" columns;
    # multiple labels in one cell are joined by "-".
    df = pd.read_csv(file_path, encoding='gbk')
    texts = df['text'].tolist()
    labels = df['label'].apply(lambda x: x.split("-")).tolist()
    return texts, labels

def preprocess_text(text):
    # Strip punctuation, lowercase, and split on whitespace.
    text = re.sub(r'[^\w\s]', '', text)
    return text.strip().lower().split()

def build_vocab(texts, max_size=10000):
    word_counts = Counter()
    for text in texts:
        word_counts.update(preprocess_text(text))
    # Indices 0 and 1 are reserved for the padding and unknown-word tokens.
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for i, (word, count) in enumerate(word_counts.most_common(max_size - 2)):
        vocab[word] = i + 2
    return vocab

def encode_text(text, vocab):
    tokens = preprocess_text(text)
    return [vocab.get(token, vocab["<UNK>"]) for token in tokens]

def pad_text(encoded_text, max_length):
    return encoded_text[:max_length] + [0] * max(0, max_length - len(encoded_text))

def encode_label(labels, label_set):
    # Convert each sample's list of label strings into a multi-hot vector over label_set.
    encoded_labels = []
    for label in labels:
        encoded_label = [0] * len(label_set)
        for l in label:
            if l in label_set:
                encoded_label[label_set.index(l)] = 1
        encoded_labels.append(encoded_label)
    return encoded_labels
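
As a quick illustration of how these helpers fit together, here is a hypothetical toy run (the vocabulary and sentence are made up, not taken from the real dataset):

# Hypothetical toy example of the helpers above:
toy_vocab = {"<PAD>": 0, "<UNK>": 1, "deep": 2, "learning": 3}
ids = encode_text("Deep learning!", toy_vocab)  # punctuation stripped, lowercased -> [2, 3]
padded = pad_text(ids, 5)                       # right-padded with 0 (<PAD>) -> [2, 3, 0, 0, 0]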

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return torch.tensor(self.texts[index], dtype=torch.long), torch.tensor(self.labels[index], dtype=torch.float32)

texts, labels = load_data("data_qa.csv")
vocab = build_vocab(texts)
label_set = ["人工智能", "卷积神经网络", "大数据", "ChatGPT"]  # artificial intelligence, convolutional neural network, big data, ChatGPT

encoded_texts = [pad_text(encode_text(text, vocab), max_length) for text in texts]
encoded_labels = encode_label(labels, label_set)

X_train, X_test, y_train, y_test = train_test_split(encoded_texts, encoded_labels, test_size=0.2, random_state=42)
#print(X_train,y_train)

train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
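
Before training, an optional sanity check can confirm the batch shapes (assuming the training split has at least batch_size rows):

# Optional sanity check: one batch of token ids and multi-hot targets.
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # expected: torch.Size([32, 20]) torch.Size([32, 4])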

Sample dataset (text | label):

How Artificial Intelligence Affects Import and Export Trade: An Empirical Test Based on National Data | Artificial Intelligence
Generative Artificial Intelligence: ChatGPT's Transformative Impact, Risk Challenges and Coping Strategies | Artificial Intelligence-ChatGPT
Research on the Relationship between Artificial Intelligence and the Free and Comprehensive Development of Human Beings: Based on Marx's Thought of Labor Liberation | Artificial Intelligence
A Study on the Influencing Factors of Middle School Students' Continuance Intention to Use Artificial Intelligence Technology | Artificial Intelligence
Discussion on the Application of Artificial Intelligence Technology in the Field of Aerospace Equipment | Artificial Intelligence
Ethical Reflections on AI-Enabled Education | Artificial Intelligence
The Myth of Artificial Intelligence: Distinguishing ChatGPT from the Transcendent Digital Labor "Subject" | Artificial Intelligence-ChatGPT
Challenges and Opportunities of Artificial Intelligence (ChatGPT) for Postgraduate Education in the Social Sciences | Artificial Intelligence-ChatGPT
The Realistic Picture of Artificial Intelligence Boosting Educational Reform: An Analysis of Teachers' Coping Strategies for ChatGPT | Artificial Intelligence-ChatGPT
Smart Admission and the Death of Democracy: Risks and Challenges of Democratic Politics in the Age of Artificial Intelligence | Artificial Intelligence
Analysis and Enlightenment of the Research Status of Artificial Intelligence Writing in China | Artificial Intelligence
AI Regulation: Theory, Models and Trends | Artificial Intelligence
Written Talk on the "Application and Regulation of the New-Generation Artificial Intelligence Technology ChatGPT" | Artificial Intelligence-ChatGPT
The Economic and Social Impact of the Development of the New Generation of Artificial Intelligence Technology ChatGPT | Artificial Intelligence-ChatGPT
ChatGPT Empowers Labor Education: Picture Presentation and Practice Strategies | Artificial Intelligence-ChatGPT
Artificial Intelligence Chatbots: An Analysis Based on ChatGPT and Microsoft Bing | Artificial Intelligence-ChatGPT
The Biden Administration's Crackdown on China's Artificial Intelligence Industry and China's Response | Artificial Intelligence
Research on the Application of Artificial Intelligence Technology in Modern Agricultural Machinery | Artificial Intelligence
Research on the Impact of Artificial Intelligence on China's Manufacturing Innovation: Evidence from the Application of Industrial Robots | Artificial Intelligence
Application of Artificial Intelligence Technology in Electronic Product Design | Artificial Intelligence
Intelligent Content Generation such as ChatGPT and the Intelligent Transformation Facing the News Publishing Industry | Artificial Intelligence-ChatGPT
Research on Intelligent Image Recognition and Classification of Crops Based on Convolutional Neural Networks | Artificial Intelligence-Convolutional Neural Network
Research on an Improved Method of Image Classification Based on Convolutional Neural Networks | Artificial Intelligence-Convolutional Neural Network

A sample can carry multiple labels here; multiple labels are separated by the "-" symbol.
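
For example (an illustration using the label_set defined above, not real dataset output):

# "人工智能-ChatGPT".split("-") -> ["人工智能", "ChatGPT"]
# With label_set = ["人工智能", "卷积神经网络", "大数据", "ChatGPT"],
# encode_label maps it to the multi-hot vector [1, 0, 0, 1].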

4. Build the model

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, num_classes, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # One Conv2d per filter size; each kernel spans the full embedding width.
        self.convs = nn.ModuleList([nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)

    def forward(self, x):
        x = self.embedding(x)                                          # (B, L, E)
        x = x.unsqueeze(1)                                             # (B, 1, L, E)
        x = [torch.relu(conv(x)).squeeze(3) for conv in self.convs]    # (B, F, L-fs+1) each
        x = [torch.max_pool1d(i, i.size(2)).squeeze(2) for i in x]     # (B, F) each
        x = torch.cat(x, 1)                                            # (B, F * len(filter_sizes))
        x = self.dropout(x)
        logits = self.fc(x)
        # Sigmoid gives an independent probability per label (multi-label setting).
        return torch.sigmoid(logits)
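
To verify the architecture before training, a quick smoke test can be run with random token ids (a minimal sketch; the random input merely stands in for real data):

# Optional smoke test: push one dummy batch through an untrained model.
m = TextCNN(len(vocab), embedding_dim, num_filters, filter_sizes, num_classes)
dummy = torch.randint(0, len(vocab), (batch_size, max_length))
print(m(dummy).shape)  # torch.Size([32, 4]); each entry is a per-label probability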

5. Model training

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct_preds = 0  # number of correct predictions
    total_preds = 0    # total number of predictions
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Approximate accuracy: compare only the single highest-scoring predicted
        # label against the highest-scoring target label. This is a rough proxy
        # in the multi-label setting; the F1/precision/recall computed during
        # evaluation are the proper multi-label metrics.
        predicted_labels = torch.argmax(outputs, dim=1)
        targets = torch.argmax(targets, dim=1)

        correct_preds += (predicted_labels == targets).sum().item()
        total_preds += len(targets)

    accuracy = correct_preds / total_preds  # accuracy over the epoch
    return running_loss / len(dataloader), accuracy  # average loss and accuracy

def evaluate(model, dataloader, device):
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for inputs, target in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            preds.extend(outputs.cpu().numpy())
            targets.extend(target.numpy())
    return np.array(preds), np.array(targets)

def calculate_metrics(preds, targets, threshold=0.5):
    preds = (preds > threshold).astype(int)
    f1 = f1_score(targets, preds, average="micro")
    precision = precision_score(targets, preds, average="micro")
    recall = recall_score(targets, preds, average="micro")
    return {"f1": f1, "precision": precision, "recall": recall}

model = TextCNN(len(vocab), embedding_dim, num_filters, filter_sizes, num_classes).to(device)
criterion = nn.BCELoss()  # binary cross-entropy over per-label probabilities (pairs with the sigmoid output)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    # Train every epoch; evaluate and log every 20 epochs.
    train_loss, accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
    if epoch % 20 == 0:
        print(f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {accuracy:.4f}")

        preds, targets = evaluate(model, test_loader, device)
        metrics = calculate_metrics(preds, targets)
        print(f"Epoch: {epoch + 1}, F1: {metrics['f1']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}")
The later part of the training log looks like this:

...
Epoch: 1821, Train Loss: 0.0055, Train Accuracy: 0.8837
Epoch: 1821, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1841, Train Loss: 0.0064, Train Accuracy: 0.9070
Epoch: 1841, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1861, Train Loss: 0.0047, Train Accuracy: 0.8837
Epoch: 1861, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1881, Train Loss: 0.0058, Train Accuracy: 0.8605
Epoch: 1881, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1901, Train Loss: 0.0064, Train Accuracy: 0.8488
Epoch: 1901, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1921, Train Loss: 0.0062, Train Accuracy: 0.8140
Epoch: 1921, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1941, Train Loss: 0.0059, Train Accuracy: 0.8953
Epoch: 1941, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1961, Train Loss: 0.0053, Train Accuracy: 0.8488
Epoch: 1961, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1981, Train Loss: 0.0055, Train Accuracy: 0.8488
Epoch: 1981, F1: 0.9429, Precision: 0.9429, Recall: 0.9429

You can train on your own dataset as well; just prepare it in the same format: a text column and a label column, with multiple labels joined by "-". For predicting labels on a new text, see the sketch below.
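
For inference on a single new text, a helper along the following lines can be used (a sketch, not part of the original code; note that preprocess_text splits on whitespace, so Chinese text should be word-segmented beforehand, e.g. with jieba, to match the training vocabulary):

# Hypothetical inference helper (not part of the original code):
def predict(text, model, vocab, label_set, threshold=0.5):
    model.eval()
    ids = pad_text(encode_text(text, vocab), max_length)
    x = torch.tensor([ids], dtype=torch.long).to(device)
    with torch.no_grad():
        probs = model(x).squeeze(0).cpu().numpy()
    # Return every label whose predicted probability clears the threshold.
    return [label for label, p in zip(label_set, probs) if p > threshold]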
