一、关系抽取任务简介
1.1 任务定义
关系抽取(Relation Extraction, RE) 是从自然语言文本中识别出实体之间的语义关系的任务。给定一个句子和一个实体对,模型需要判断这两个实体之间是否存在某种预定义的语义关系。
输入:
- 一个句子,如:“Elon Musk is the CEO of Tesla.”
- 一对实体,如 (“Elon Musk”, “Tesla”)
输出:
- 实体之间的语义关系,如:“CEO_of”
二、为什么使用 BiLSTM + Attention
2.1 BiLSTM 的作用
- LSTM(长短期记忆网络) 是一种 RNN,用于捕捉文本中的时序依赖关系。
- BiLSTM(双向 LSTM) 同时从 前向 和 后向 编码输入序列的上下文信息,非常适合处理自然语言中的上下文语义理解。
2.2 Attention 的作用
- 在每个时间步,BiLSTM 输出的是一个隐藏状态序列,不同位置的信息对关系的判断贡献不一致。
- Attention 机制可以帮助模型 聚焦于与关系判断最相关的词或片段,提升模型的解释性与性能。
三、BiLSTM + Attention 的模型结构
3.1 输入表示层
- 对输入句子进行分词,并将每个词转换为向量(embedding)。
- 可以使用:
- 词嵌入(如 GloVe、Word2Vec)
- 预训练模型的 embedding(如 BERT 的输出)
- 也可加入实体位置信息,如相对距离特征。
3.2 BiLSTM 编码层
- 输入嵌入序列送入 BiLSTM,得到每个时间步的 前向与后向隐藏状态拼接向量:
h i = [ h → i ; h ← i ] h_i = [\overrightarrow{h}_i ; \overleftarrow{h}_i] hi=[hi;hi] - 输出维度为
[batch_size, seq_len, 2 * hidden_size]
。
3.3 Attention 聚合层
- 对 BiLSTM 的输出序列应用注意力机制:
- 学习一个注意力权重向量 α \alpha α,用来加权求和所有时间步的输出:
α i = exp ( e i ) ∑ j exp ( e j ) , e i = w T tanh ( W h i + b ) \alpha_i = \frac{\exp(e_i)}{\sum_j \exp(e_j)},\quad e_i = w^T \tanh(W h_i + b) αi=∑jexp(ej)exp(ei),ei=wTtanh(Whi+b) - 最终聚合成一个向量 v = ∑ i α i h i v = \sum_i \alpha_i h_i v=∑iαihi,表示句子的全局表示。
- 学习一个注意力权重向量 α \alpha α,用来加权求和所有时间步的输出:
3.4 分类层
- 将注意力输出的上下文向量 v v v 送入一个 全连接层 + Softmax,输出每种关系的概率分布:
y = softmax ( W v + b ) y = \text{softmax}(Wv + b) y=softmax(Wv+b)
四、训练方式
4.1 监督训练
- 使用 交叉熵损失函数(Cross Entropy Loss) 作为训练目标:
L = − ∑ i y i log ( y ^ i ) \mathcal{L} = -\sum_{i} y_i \log(\hat{y}_i) L=−i∑yilog(y^i)- y i y_i yi:真实标签的 one-hot 向量
- y ^ i \hat{y}_i y^i:模型预测的概率分布
4.2 数据格式(适用于句子级关系分类)
每个训练样本包含:
- 句子文本
- 实体对
- 实体之间的关系标签
可选增强方式:
- 在输入中加入 实体位置标记(例如用特殊符号标记 e1、e2)
- 使用 位置嵌入、实体类型嵌入
五、模型示意图
输入句子 + 实体对(如 "Elon Musk is the CEO of Tesla.")
↓
Token Embedding + Entity Embedding
↓
BiLSTM 编码
↓
Attention 权重加权求和
↓
上下文向量(全局表示)
↓
全连接层 + Softmax 分类
↓
预测关系(如 "CEO_of")
六、模型优势与适用场景
6.1 优势
特点 | 描述 |
---|---|
上下文感知 | BiLSTM 可以捕捉上下文中前后语义信息 |
聚焦关键部分 | Attention 可以专注于对关系判断最相关的词 |
结构简单、可解释性强 | Attention 权重可以解释模型的“注意力”来源 |
6.2 适用场景
- 给定实体对的关系分类任务
- 少数据任务(小样本)
- 需要可解释性的关系抽取任务
- 文本关系分类、事件三元组分类
七、模型局限与扩展方向
7.1 局限性
- 不能同时抽取多个三元组(无法处理三元组重叠问题)
- 输入固定长度(无法适应长文本关系跨句推理)
- 仅适用于句子级分类,不能做端到端的联合抽取
7.2 可扩展方向
- 加入 实体位置嵌入、实体类型嵌入
- 与 BERT、ERNIE 等预训练模型结合(即 BiLSTM 替换为 Transformer)
- 构建联合抽取模型(加入实体识别模块)
八、代码实现
1、数据样例(简化版)
sample_data = [
{
"text": "Elon Musk is the CEO of Tesla.",
"entity1": "Elon Musk",
"entity2": "Tesla",
"relation": "CEO_of"
},
{
"text": "Steve Jobs founded Apple.",
"entity1": "Steve Jobs",
"entity2": "Apple",
"relation": "founded"
}
]
relation2id = {
"CEO_of": 0,
"founded": 1,
"no_relation": 2
}
2、模型结构(BiLSTM + Attention)
import torch
import torch.nn as nn
class BiLSTMAttentionRE(nn.Module):
def __init__(self, embedding_dim, hidden_dim, num_classes, vocab_size, padding_idx=0):
super(BiLSTMAttentionRE, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
self.attention = nn.Linear(hidden_dim * 2, 1)
self.classifier = nn.Linear(hidden_dim * 2, num_classes)
def forward(self, input_ids, attention_mask):
embeds = self.embedding(input_ids) # [batch, seq_len, embed_dim]
lstm_out, _ = self.lstm(embeds) # [batch, seq_len, hidden*2]
attn_weights = torch.softmax(self.attention(lstm_out).squeeze(-1).masked_fill(attention_mask == 0, -1e9), dim=1) # [batch, seq_len]
context = torch.sum(lstm_out * attn_weights.unsqueeze(-1), dim=1) # [batch, hidden*2]
logits = self.classifier(context)
return logits
3、词表与预处理函数
from collections import Counter
from transformers import BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def build_vocab(data):
word_counter = Counter()
for d in data:
words = tokenizer.tokenize(d["text"])
word_counter.update(words)
vocab = {
word: idx+2 for idx, (word, _) in enumerate(word_counter.most_common())}
vocab["[PAD]"] = 0
vocab["[UNK]"] = 1
return vocab
def encode_sentence(text, vocab, max_len=32):
tokens = tokenizer.tokenize(text)
token_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens][:max_len]
attention_mask = [1] * len(token_ids)
# padding
while len(token_ids) < max_len:
token_ids.append(vocab["[PAD]"])
attention_mask.append(0)
return torch.tensor(token_ids), torch.tensor(attention_mask)
4、训练代码
# 准备数据
vocab = build_vocab(sample_data)
num_classes = len(relation2id)
X, attention_masks, y = [], [], []
for item in sample_data:
x, attn = encode_sentence(item["text"], vocab)
X.append(x)
attention_masks.append(attn)
y.append(torch.tensor(relation2id[item["relation"]]))
X = torch.stack(X)
attention_masks = torch.stack(attention_masks)
y = torch.stack(y)
# 模型、损失、优化器
model = BiLSTMAttentionRE(embedding_dim=128, hidden_dim=64, num_classes=num_classes, vocab_size=len(vocab)).to("cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# 训练过程
for epoch in range(10):
model.train()
logits = model(X, attention_masks)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {
epoch+1}, Loss: {
loss.item():.4f}")
5、预测代码
def predict(model, text):
model.eval()
x, attn = encode_sentence(text, vocab)
x = x.unsqueeze(0)
attn = attn.unsqueeze(0)
with torch.no_grad():
logits = model(x, attn)
pred = torch.argmax(logits, dim=1).item()
return pred
test_text = "Elon Musk is the CEO of Tesla."
pred_id = predict(model, test_text)
id2rel = {
v: k for k, v in relation2id.items()}
print("预测关系:", id2rel[pred_id])
八、总结
模块 | 功能 | 关键技术 |
---|---|---|
输入编码 | 将句子和实体转为向量 | Token Embedding |
编码层 | 捕捉上下文 | BiLSTM |
聚合层 | 提取关键语义 | Attention |
输出层 | 预测关系标签 | 全连接层 + Softmax |
训练方式 | 监督学习 | 交叉熵损失 |
BiLSTM + Attention 是关系分类任务中经典的结构之一,具有良好的上下文建模能力和较强的可解释性,是从基础模型到复杂深度关系抽取系统的重要起点。