BiLSTM+Attention 实现关系抽取

一、关系抽取任务简介

1.1 任务定义

关系抽取(Relation Extraction, RE) 是从自然语言文本中识别出实体之间的语义关系的任务。给定一个句子和一个实体对,模型需要判断这两个实体之间是否存在某种预定义的语义关系。

输入

  • 一个句子,如:“Elon Musk is the CEO of Tesla.”
  • 一对实体,如 (“Elon Musk”, “Tesla”)

输出

  • 实体之间的语义关系,如:“CEO_of”

二、为什么使用 BiLSTM + Attention

2.1 BiLSTM 的作用

  • LSTM(长短期记忆网络) 是一种 RNN,用于捕捉文本中的时序依赖关系。
  • BiLSTM(双向 LSTM) 同时从 前向后向 编码输入序列的上下文信息,非常适合处理自然语言中的上下文语义理解。

2.2 Attention 的作用

  • 在每个时间步,BiLSTM 输出的是一个隐藏状态序列,不同位置的信息对关系的判断贡献不一致
  • Attention 机制可以帮助模型 聚焦于与关系判断最相关的词或片段,提升模型的解释性与性能。

三、BiLSTM + Attention 的模型结构

3.1 输入表示层

  • 对输入句子进行分词,并将每个词转换为向量(embedding)。
  • 可以使用:
    • 词嵌入(如 GloVe、Word2Vec)
    • 预训练模型的 embedding(如 BERT 的输出)
    • 也可加入实体位置信息,如相对距离特征。

3.2 BiLSTM 编码层

  • 输入嵌入序列送入 BiLSTM,得到每个时间步的 前向与后向隐藏状态拼接向量
    h i = [ h → i ; h ← i ] h_i = [\overrightarrow{h}_i ; \overleftarrow{h}_i] hi=[h i;h i]
  • 输出维度为 [batch_size, seq_len, 2 * hidden_size]

3.3 Attention 聚合层

  • 对 BiLSTM 的输出序列应用注意力机制:
    • 学习一个注意力权重向量 α \alpha α,用来加权求和所有时间步的输出:
      α i = exp ⁡ ( e i ) ∑ j exp ⁡ ( e j ) , e i = w T tanh ⁡ ( W h i + b ) \alpha_i = \frac{\exp(e_i)}{\sum_j \exp(e_j)},\quad e_i = w^T \tanh(W h_i + b) αi=jexp(ej)exp(ei),ei=wTtanh(Whi+b)
    • 最终聚合成一个向量 v = ∑ i α i h i v = \sum_i \alpha_i h_i v=iαihi,表示句子的全局表示。

3.4 分类层

  • 将注意力输出的上下文向量 v v v 送入一个 全连接层 + Softmax,输出每种关系的概率分布:
    y = softmax ( W v + b ) y = \text{softmax}(Wv + b) y=softmax(Wv+b)

四、训练方式

4.1 监督训练

  • 使用 交叉熵损失函数(Cross Entropy Loss) 作为训练目标:
    L = − ∑ i y i log ⁡ ( y ^ i ) \mathcal{L} = -\sum_{i} y_i \log(\hat{y}_i) L=iyilog(y^i)
    • y i y_i yi:真实标签的 one-hot 向量
    • y ^ i \hat{y}_i y^i:模型预测的概率分布

4.2 数据格式(适用于句子级关系分类)

每个训练样本包含:

  • 句子文本
  • 实体对
  • 实体之间的关系标签

可选增强方式:

  • 在输入中加入 实体位置标记(例如用特殊符号标记 e1、e2)
  • 使用 位置嵌入实体类型嵌入

五、模型示意图

             输入句子 + 实体对(如 "Elon Musk is the CEO of Tesla.")
                                ↓
                   Token Embedding + Entity Embedding
                                ↓
                           BiLSTM 编码
                                ↓
                    Attention 权重加权求和
                                ↓
                      上下文向量(全局表示)
                                ↓
                       全连接层 + Softmax 分类
                                ↓
                      预测关系(如 "CEO_of")

六、模型优势与适用场景

6.1 优势

特点 描述
上下文感知 BiLSTM 可以捕捉上下文中前后语义信息
聚焦关键部分 Attention 可以专注于对关系判断最相关的词
结构简单、可解释性强 Attention 权重可以解释模型的“注意力”来源

6.2 适用场景

  • 给定实体对的关系分类任务
  • 少数据任务(小样本)
  • 需要可解释性的关系抽取任务
  • 文本关系分类、事件三元组分类

七、模型局限与扩展方向

7.1 局限性

  • 不能同时抽取多个三元组(无法处理三元组重叠问题)
  • 输入固定长度(无法适应长文本关系跨句推理)
  • 仅适用于句子级分类,不能做端到端的联合抽取

7.2 可扩展方向

  • 加入 实体位置嵌入实体类型嵌入
  • BERT、ERNIE 等预训练模型结合(即 BiLSTM 替换为 Transformer)
  • 构建联合抽取模型(加入实体识别模块)

八、代码实现

1、数据样例(简化版)

sample_data = [
    {
    
    
        "text": "Elon Musk is the CEO of Tesla.",
        "entity1": "Elon Musk",
        "entity2": "Tesla",
        "relation": "CEO_of"
    },
    {
    
    
        "text": "Steve Jobs founded Apple.",
        "entity1": "Steve Jobs",
        "entity2": "Apple",
        "relation": "founded"
    }
]

relation2id = {
    
    
    "CEO_of": 0,
    "founded": 1,
    "no_relation": 2
}

2、模型结构(BiLSTM + Attention)

import torch
import torch.nn as nn

class BiLSTMAttentionRE(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_classes, vocab_size, padding_idx=0):
        super(BiLSTMAttentionRE, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, input_ids, attention_mask):
        embeds = self.embedding(input_ids)  # [batch, seq_len, embed_dim]
        lstm_out, _ = self.lstm(embeds)     # [batch, seq_len, hidden*2]
        
        attn_weights = torch.softmax(self.attention(lstm_out).squeeze(-1).masked_fill(attention_mask == 0, -1e9), dim=1)  # [batch, seq_len]
        context = torch.sum(lstm_out * attn_weights.unsqueeze(-1), dim=1)  # [batch, hidden*2]
        logits = self.classifier(context)
        return logits

3、词表与预处理函数

from collections import Counter
from transformers import BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def build_vocab(data):
    word_counter = Counter()
    for d in data:
        words = tokenizer.tokenize(d["text"])
        word_counter.update(words)
    vocab = {
    
    word: idx+2 for idx, (word, _) in enumerate(word_counter.most_common())}
    vocab["[PAD]"] = 0
    vocab["[UNK]"] = 1
    return vocab

def encode_sentence(text, vocab, max_len=32):
    tokens = tokenizer.tokenize(text)
    token_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens][:max_len]
    attention_mask = [1] * len(token_ids)
    # padding
    while len(token_ids) < max_len:
        token_ids.append(vocab["[PAD]"])
        attention_mask.append(0)
    return torch.tensor(token_ids), torch.tensor(attention_mask)

4、训练代码

# 准备数据
vocab = build_vocab(sample_data)
num_classes = len(relation2id)

X, attention_masks, y = [], [], []
for item in sample_data:
    x, attn = encode_sentence(item["text"], vocab)
    X.append(x)
    attention_masks.append(attn)
    y.append(torch.tensor(relation2id[item["relation"]]))

X = torch.stack(X)
attention_masks = torch.stack(attention_masks)
y = torch.stack(y)

# 模型、损失、优化器
model = BiLSTMAttentionRE(embedding_dim=128, hidden_dim=64, num_classes=num_classes, vocab_size=len(vocab)).to("cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# 训练过程
for epoch in range(10):
    model.train()
    logits = model(X, attention_masks)
    loss = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {
      
      epoch+1}, Loss: {
      
      loss.item():.4f}")

5、预测代码

def predict(model, text):
    model.eval()
    x, attn = encode_sentence(text, vocab)
    x = x.unsqueeze(0)
    attn = attn.unsqueeze(0)
    with torch.no_grad():
        logits = model(x, attn)
        pred = torch.argmax(logits, dim=1).item()
    return pred

test_text = "Elon Musk is the CEO of Tesla."
pred_id = predict(model, test_text)
id2rel = {
    
    v: k for k, v in relation2id.items()}
print("预测关系:", id2rel[pred_id])

八、总结

模块 功能 关键技术
输入编码 将句子和实体转为向量 Token Embedding
编码层 捕捉上下文 BiLSTM
聚合层 提取关键语义 Attention
输出层 预测关系标签 全连接层 + Softmax
训练方式 监督学习 交叉熵损失

BiLSTM + Attention 是关系分类任务中经典的结构之一,具有良好的上下文建模能力和较强的可解释性,是从基础模型到复杂深度关系抽取系统的重要起点。