机器学习赛事-科大讯飞基于论文摘要的文本分类与关键词抽取挑战赛

科大讯飞 基于论文摘要的文本分类与关键词抽取挑战赛

一、赛事背景

医学领域的文献库中蕴含了丰富的疾病诊断和治疗信息,如何高效地从海量文献中提取关键信息,进行疾病诊断和治疗推荐,对于临床医生和研究人员具有重要意义。

二、赛事任务

本任务分为两个子任务:
1. 文献分类:机器通过对论文摘要等信息的理解,判断该论文是否属于医学领域的文献。
2. 关键词抽取:提取出该论文的关键词。

解决方案示例:传统 NLP / BERT 基线(PyTorch 示例代码如下)

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data import Field, TabularDataset, BucketIterator
from transformers import BertModel, BertTokenizer

# Load the pre-trained BERT encoder and its tokenizer once at module level;
# both task heads below (PaperClassifier, KeywordExtractor) share this encoder.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# 定义分类模型
class PaperClassifier(nn.Module):
    """Binary classifier: does the paper belong to the medical domain?

    Wraps the shared module-level BERT encoder and maps its pooled [CLS]
    representation to two class logits (0 = non-medical, 1 = medical).
    """

    def __init__(self):
        super(PaperClassifier, self).__init__()
        self.bert = bert_model  # shared pre-trained encoder (module-level)
        self.fc = nn.Linear(768, 2)  # 768 = hidden size of bert-base; 2 classes

    def forward(self, text):
        # NOTE(review): assumes the tuple-style return of transformers < 4.x
        # (sequence_output, pooled_output) -- confirm installed version.
        _, pooled_output = self.bert(text)
        # Return raw logits. train_model() uses nn.CrossEntropyLoss, which
        # applies log-softmax internally; the original applied nn.Softmax here
        # first, double-normalizing the output and flattening gradients.
        # torch.argmax in predict() is unaffected by removing the softmax.
        return self.fc(pooled_output)

# 定义关键词提取模型
class KeywordExtractor(nn.Module):
    """Per-token keyword scorer.

    Wraps the shared module-level BERT encoder and emits one keyword logit
    per input token; a positive logit marks the token as a keyword.
    """

    def __init__(self):
        super(KeywordExtractor, self).__init__()
        self.bert = bert_model  # shared pre-trained encoder (module-level)
        self.fc = nn.Linear(768, 1)  # one keyword logit per token position

    def forward(self, text):
        # Use the per-token hidden states, not the pooled output: the
        # original scored pooled_output, yielding a single scalar per
        # document, so predict() could never score individual tokens.
        # NOTE(review): assumes the tuple-style return of transformers < 4.x.
        sequence_output, _ = self.bert(text)
        # (batch, seq_len, 1) -> (batch, seq_len)
        return self.fc(sequence_output).squeeze(-1)

# 预处理数据
def preprocess_data():
    """Load the train/test CSVs, build vocabularies, and return bucketed iterators.

    Returns:
        (train_iterator, test_iterator): torchtext BucketIterators over the
        'data/train.csv' and 'data/test.csv' splits.
    """
    # Column definitions: abstract text, binary label, keyword sequence.
    text_field = Field(sequential=True, tokenize=tokenizer.tokenize,
                       lower=True, include_lengths=True)
    label_field = Field(sequential=False, use_vocab=False)
    keywords_field = Field(sequential=True, tokenize='spacy', lower=True)

    columns = [('text', text_field),
               ('label', label_field),
               ('keywords', keywords_field)]

    train_data, test_data = TabularDataset.splits(
        path='data',
        train='train.csv',
        test='test.csv',
        format='csv',
        fields=columns,
    )

    # Vocabularies come from the training split only.
    text_field.build_vocab(train_data)
    keywords_field.build_vocab(train_data)

    # Bucketing by text length keeps padding within each batch minimal.
    return BucketIterator.splits(
        (train_data, test_data),
        batch_size=32,
        sort_key=lambda example: len(example.text),
        sort_within_batch=True,
    )

# 训练模型
def train_model(model, train_iterator):
    """Jointly train the classifier and keyword extractor for 10 epochs.

    Args:
        model: nn.ModuleList of [PaperClassifier, KeywordExtractor].
        train_iterator: torchtext iterator yielding batches with
            .text (tensor, lengths), .label and .keywords attributes.

    An nn.ModuleList has no forward() of its own, so the two sub-modules
    must be invoked individually -- the original called ``model(text)``,
    which raises TypeError at the first batch.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    classifier_criterion = nn.CrossEntropyLoss()
    extractor_criterion = nn.BCEWithLogitsLoss()

    classifier, extractor = model[0], model[1]
    model.train()

    for epoch in range(10):
        for batch in train_iterator:
            text, text_lengths = batch.text
            labels = batch.label
            # batch.keywords is a LongTensor, presumably (seq_len, batch)
            # with torchtext's default batch_first=False -- TODO confirm.
            # Transpose to (batch, seq_len) and cast to float targets for
            # BCEWithLogitsLoss. (Replaces the original's convoluted
            # torch.tensor([getattr(batch, 'keywords')]).squeeze().T.float().)
            keywords = batch.keywords.t().float()

            optimizer.zero_grad()
            class_output = classifier(text)
            extract_output = extractor(text)
            classifier_loss = classifier_criterion(class_output, labels)
            extractor_loss = extractor_criterion(extract_output, keywords)
            # Simple unweighted sum of the two task losses.
            loss = classifier_loss + extractor_loss
            loss.backward()
            optimizer.step()

# 使用模型进行预测
def predict(model, sentence):
    """Classify *sentence* and extract its keywords.

    Args:
        model: nn.ModuleList of [PaperClassifier, KeywordExtractor].
        sentence: raw input string.

    Returns:
        (label, keywords): label is an int (1 = medical domain), keywords
        are the tokens whose extractor score is positive.
    """
    model.eval()

    # Encode with [CLS]/[SEP] special tokens; batch dimension of 1.
    indexed_tokens = tokenizer.encode(sentence, add_special_tokens=True)
    tokens_tensor = torch.tensor([indexed_tokens])

    with torch.no_grad():
        # model is an nn.ModuleList, which is not callable -- the original
        # ``model(tokens_tensor)`` raises TypeError. Invoke each head.
        class_output = model[0](tokens_tensor)
        extract_output = model[1](tokens_tensor)
        predicted_label = torch.argmax(class_output, dim=1)
        # Positive score => keyword. NOTE(review): tokenizer.tokenize()
        # omits the [CLS]/[SEP] added by encode(), so token/score alignment
        # is shifted by one special token -- confirm intended behaviour.
        predicted_keywords = [
            token
            for token, score in zip(tokenizer.tokenize(sentence), extract_output[0])
            if score > 0
        ]

    return predicted_label.item(), predicted_keywords

# 主函数
if __name__ == "__main__":
    train_iterator, _ = preprocess_data()
    
    classifier_model = PaperClassifier()
    keyword_model = KeywordExtractor()
    model = nn.ModuleList([classifier_model, keyword_model])
    
    train_model(model, train_iterator)
    
    example_sentence = "This is a medical paper about cancer detection."
    predicted_label, predicted_keywords = predict(model, example_sentence)
    
    if predicted_label == 1:
        print("1")
    else:
        print("0")
    
    print("提取的关键词:", predicted_keywords)

猜你喜欢

转载自blog.csdn.net/weixin_42452716/article/details/131426764