ERICA-提升预训练语言模型实体与关系理解的统一框架解读

本文最后更新于:2024年8月14日 下午

ERICA-提升预训练语言模型实体与关系理解的统一框架解读

这篇被ACL 2021主会录用的文章中,清华大学联合腾讯微信模式识别中心与UIUC,提出了一种新颖的对比学习框架ERICA,帮助PLM深入了解文本中的实体及实体间关系。具体来说,作者提出了两个辅助性预训练任务来帮助PLM更好地理解实体和实体间关系

  • 实体区分任务,给定头实体和关系,推断出文本中正确的尾实体
  • 关系判别任务,区分两个关系在语义上是否接近,这在长文本情景下涉及复杂的关系推理

实验结果表明,ERICA在不引入额外神经网络参数的前提下,仅仅对PLM进行少量的额外训练,就可以提升典型PLM(例如BERT 和 RoBERTa)在多种自然语言理解任务上(包括关系抽取、实体类别区分、问题回答等)的性能。尤其是在低资源(low-resource)的设定下,性能的提升更加明显。

论文链接:https://arxiv.org/abs/2012.15022

Github仓库:https://github.com/thunlp/ERICA

ERICA-colab

如果是想简单运行测试(白嫖算力)的同学,我将我的colab修改版公开了,可以在google colab等在线jupyter notebook平台直接运行

Github仓库:https://github.com/chaoers/ERICA-colab

Colab: https://colab.research.google.com/drive/1x0ZQF8KZqSczzLCEt3xXkbpWLWmzOgBW?usp=sharing

主要是修改了一些错误代码和添加了一些参数以便可以在jupyter notebook上直接运行。(这个源代码仓库有的地方真的写的很反人类)

相关工作及创新点

传统的预训练模型(PLM)没有对文本中的关系事实进行显式建模,而这些关系事实对于理解文本至关重要。为了解决这个问题,一些研究人员试图改进 PLM 的架构、预训练任务等,以更好地理解实体之间的关系。但是它们通常只对文本中的句子级别的单个关系进行建模,不仅忽略了长文本场景下多个实体之间的复杂关系,也忽略了对实体本身的理解,例如图1中所展现的,对于长文本来说,为了让PLM更加充分理解地单个实体,我们需要考虑该实体和其他实体之间的复杂关系;而这些复杂的关系的理解通常涉及复杂的推理链,往往需要综合多个句子的信息得出结论。针对这两个痛点,本文提出了实体区分任务和关系区分任务来增强PLM对于实体和实体间关系的理解。

640

论文细节解读

本文主要工作集中在预训练(PLM)部分,主要创新点就是提出了两个辅助性预训练任务来帮助PLM更好地理解实体和实体间关系:

  • 实体区分任务(Entity Discrimination Task, ED),给定头实体和关系,推断出文本中正确的尾实体。
  • 关系区分任务(Relation Discrimination Task, RD),区分两个关系在语义上是否接近,这在长文本情景下涉及复杂的关系推理。

数据集处理及实体表示

640-20210717165035122

ERICA的训练依赖于大规模文档级远程监督数据,该数据的构造有三个阶段:首先从wikipedia中爬取文本段落,然后用命名实体识别工具(例如spacy)进行实体标注,将所有获得的实体和wikidata中标注的实体对应上,并利用远程监督(distant supervision)信号获得实体之间可能存在的关系,最终保留长度在128到512之间,含有多于4个实体,实体间多于4个远程监督关系的段落。注意这些远程监督的关系中存在大量的噪声,而大规模的预训练可以一定程度上实现降噪。作者也开源了由100万个文档组成的大规模远程监督预训练数据。

640-20210717164757362

鉴于每个实体可能在段落中出现多次,并且每次出现时对应的描述(mention)可能也不一样,本文在使用PLM对tokenize后的段落进行编码后,取每个描述的所有token均匀池化后的结果作为该描述的表示,接着对于全文中该实体所有的描述进行第二次均匀池化,得到该实体在该文档中的表示;对于两个实体,它们之间的关系表示为两个实体表示的简单拼接。以上是最简单的实体/实体间关系的表示方法,不需要引入额外的神经网络参数。

实体区分任务 Entity Discrimination Task

实体区分任务旨在给定头实体和关系,从当前文档中寻找正确的尾实体。例如在图2中,Sinaloa和Mexico具有country的远程关系,于是作者将关系country和头实体Sinaloa拼接在原文档的前面作为提示(prompt),在此条件下区分正确的尾实体的任务可以在对比学习的框架下转换成拉近头实体和正确尾实体的实体表示的距离,推远头实体和文档中其它实体(负样本)的实体表示的距离

640-20210717155911682

具体loss公式为:$$\mathcal{L}{ED} = - \sum{t^i_{jk}\in \mathcal{T}^+} \log \frac{exp(cos(e_{ij}, e_{ik}/\tau))}{\sum^{|\varepsilon_i|}{l=1, l\neq j} exp(cos(e{ij}, e_{il}/\tau))}$$

关系区分任务 Relation Discrimination Task

关系区分任务旨在区分两个关系的表示在语义空间上的相近程度。由于作者采用文档级而非句子级的远程监督,文档中的关系区分涉及复杂的推理链。具体而言,作者随机采样多个文档,并从每个文档中得到多个关系表示,这些关系可能只涉及句子级别的推理,也可能涉及跨句子的复杂推理。之后基于对比学习框架,根据远程监督的标签在关系空间中对不同的关系表示进行训练,如前文所述,每个关系表示均由文档中的两个实体表示构成。正样本即具有相同远程监督标签的关系表示,负样本与此相反。作者在实验中还发现进一步引入不具有远程监督关系的实体对作为负样本可以进一步提升模型效果。由于进行对比训练的两个关系表示可能来自于多个文档,也可能来自于单个文档,因此文档间/跨文档的关系表示交互都得到了实现。巧妙的是,对于涉及复杂推理的关系,该方法不需要显示地构建推理链,而是“强迫”模型理解这些关系并在顶层的关系语义空间中区分这些关系。

具体loss公式为:

$$\mathcal{L}_{RD}^{\mathcal{T}_1, \mathcal{T}2} = - \sum{t_A \in \mathcal{T}1, t_B \in \mathcal{T}2} \log \frac{exp(cos(r{t_A}, r{t_B}/\tau))}{\mathcal{Z}}$$

$$\mathcal{Z} = \sum^N_{t_C \in \tau / t_A} exp(cos(r_{t_A}, r_{t_C}/\tau))$$

$$\mathcal{L}{RD} = \mathcal{L}{RD}^{\tau_s^+, \tau_s^+} + \mathcal{L}{RD}^{\tau_s^+, \tau_c^+} + \mathcal{L}{RD}^{\tau_c^+, \tau_s^+} + \mathcal{L}_{RD}^{\tau_c^+, \tau_c^+}$$

训练目标整合

为了避免灾难性遗忘,作者将上述两个任务同masked language modeling (MLM)任务一起训练,总的训练目标如下所示

$$\mathcal{L} = \mathcal{L}{ED} + \mathcal{L}{RD} + \mathcal{L}_{MLM}$$

关键代码解读

数据预处理

本部分的关键代码在pretrain/prepare_pretrain_data目录下

主要用到以下文件

  • get_distant.py:数据清洗,实体抽取和关系抽取
  • remove_test_set.py:区分训练集和测试集
  • sample_data.py:tokenized化,通过这样预处理

实体关系抽取

实体关系抽取实际上就是通过以下文件进行匹配最后tokenized化

  • all_triple.txt:定义了实体之间的关系
  • all_name_to_Q.json:实体名到类型的一个json
  • all_Q.json:所有实体类型id的

tokenized化

tokenized用的是huggingface框架提供的BertTokenizer, RobertaTokenizer

pretrain/data/DOC/sampled_data/下就是官方给出的一个预处理完的示例结果

预训练模型

本部分的关键代码在pretrain/code/pretrain目录下,我们主要关注的是loss部分,因此只关注这部分代码,主要集中在model.py文件中。看代码发现作者没有直接将ED+RD+MLM loss放到一起训练,而是将RD+MLM定义为get_doc_loss函数,将ED+MLM定义为get_wiki_loss函数,并且在训练时通过参数决定使用哪种loss

1
2
3
4
5
6
7
8
9
10
11
# model.py

def forward(self, batch, doc_loss = 0, wiki_loss = 0):
if doc_loss + wiki_loss != 1:
assert False
if doc_loss == 1:
m_loss_d, r_loss_d = self.get_doc_loss(**batch[0])
return m_loss_d, r_loss_d
elif wiki_loss == 1:
m_loss_w, r_loss_w = self.get_wiki_loss(**batch[1])
return m_loss_w, r_loss_w

RD+MLM loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# model.py

def get_doc_loss(self, context_idxs, h_mapping, t_mapping, relation_label, relation_label_idx, context_masks, rel_mask_pos, rel_mask_neg, pos_num, mlm_mask):
if self.args.doc_loss:
m_input, m_labels = mask_tokens(context_idxs.cpu(), self.tokenizer, mlm_mask.cpu())
m_outputs = self.model(input_ids=m_input, masked_lm_labels=m_labels, attention_mask=context_masks)
m_loss = m_outputs[1]

context_output = m_outputs[0]
start_re_output = torch.matmul(h_mapping, context_output)
end_re_output = torch.matmul(t_mapping, context_output)
hidden = torch.cat([start_re_output, end_re_output], dim = 2)

pair_hidden = []
for i in range(relation_label_idx.size()[0]):
pair_hidden.append(hidden[relation_label_idx[i][0], relation_label_idx[i][1]])
pair_hidden = torch.stack(pair_hidden, dim = 0)

def get_all_pairs_indices(labels, rel_mask_pos, rel_mask_neg):
ref_labels = labels
labels1 = labels.unsqueeze(1)
labels2 = ref_labels.unsqueeze(0)
matches = (labels1 == labels2).byte()
diffs = matches ^ 1
matches = matches * rel_mask_pos
diffs = diffs * rel_mask_neg
a1_idx = matches.nonzero()[:, 0].flatten()
p_idx = matches.nonzero()[:, 1].flatten()
a2_idx = diffs.nonzero()[:, 0].flatten()
n_idx = diffs.nonzero()[:, 1].flatten()
return a1_idx, p_idx, a2_idx, n_idx

indices_tuple = get_all_pairs_indices(relation_label, rel_mask_pos, rel_mask_neg)

r_loss = self.ntxloss_doc(pair_hidden, relation_label, indices_tuple, pos_num)
else:
m_loss = 0
r_loss = 0

return m_loss, r_loss

ED+MLM loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# model.py

def get_wiki_loss(self, context_idxs, h_mapping, query_mapping, context_masks, rel_mask_pos, rel_mask_neg, relation_label_idx, start_positions, end_positions, mlm_mask):
m_input, m_labels = mask_tokens(context_idxs.cpu(), self.tokenizer, mlm_mask.cpu())
m_outputs = self.model(input_ids=m_input, masked_lm_labels=m_labels, attention_mask=context_masks)
m_loss = m_outputs[1]
context_output = m_outputs[0]

if self.args.wiki_loss == 1:
query_re_output = torch.matmul(query_mapping.unsqueeze(dim = 1), context_output)
query_re_output = query_re_output.squeeze(dim = 1)
start_re_output = torch.matmul(h_mapping, context_output)
new_start_re_output = []
for i in range(relation_label_idx.size()[0]):
new_start_re_output.append(start_re_output[relation_label_idx[i][0], relation_label_idx[i][1]])
start_re_output = torch.stack(new_start_re_output, dim = 0)
query_re_output = query_re_output

def get_all_pairs_indices(rel_mask_pos, rel_mask_neg):
a1_idx = rel_mask_pos.nonzero()[:, 0].flatten()
p_idx = rel_mask_pos.nonzero()[:, 1].flatten()
a2_idx = rel_mask_neg.nonzero()[:, 0].flatten()
n_idx = rel_mask_neg.nonzero()[:, 1].flatten()
return a1_idx, p_idx, a2_idx, n_idx

indices_tuple = get_all_pairs_indices(rel_mask_pos, rel_mask_neg)
r_loss = self.ntxloss_wiki(query_re_output, start_re_output, indices_tuple)
else:
r_loss = 0
return m_loss, r_loss

其中m_loss就是MLM loss,用huggingface框架的BertForMaskedLM或RobertaForMaskedLM实现的

r_loss对应的是RD/ED loss都是在对比学习的框架下实现的(正负样例对比进行学习,后面有时间的话会写一点介绍),因此代码相似性很高,计算loss模块都是公用的(ntxent_loss.py)最主要的是正负样例构建方法不同,即两个函数中get_all_pairs_indices部分,对应论文中提出的loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# ntxent_loss.py

import torch
from .generic_pair_loss import GenericPairLoss

class NTXentLoss(GenericPairLoss):

def __init__(self, temperature, **kwargs):
super().__init__(**kwargs, use_similarity=True, mat_based_loss=False)
self.temperature = temperature

def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
a1, _, a2, _ = indices_tuple
if len(a1) > 0 and len(a2) > 0:
pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
neg_pairs = neg_pairs / self.temperature
n_per_p = (a2.unsqueeze(0) == a1.unsqueeze(1)).float()
neg_pairs = neg_pairs*n_per_p
neg_pairs[n_per_p==0] = float('-inf')
max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0].half())
numerator = torch.exp(pos_pairs - max_val).squeeze(1)
denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
log_exp = torch.log((numerator/denominator) + 1e-20)
return torch.mean(-log_exp)
return 0

代码小结

代码部分高度基于huggingface框架,基本就是改了一下loss然后就直接开始训练

  • 以后我们有什么自己的策略想法,想落地实现,其实就是仿效改这个函数,能快速使用transfomers实现,甚至直接二次预训练即可
  • 遇到一些特殊需求需要改huggingface框架也不是不可以,直接下载transformers代码进行需求修改即可

ERICA-提升预训练语言模型实体与关系理解的统一框架解读
https://asteriscus.cat/posts/50a3d0/
作者
Asterisk
发布于
2021年7月17日
许可协议