On building training data for bert+lstm+crf entity recognition
1. BERT+LSTM+CRF has become a common approach to entity recognition. Here BERT can act as a frozen embedding layer, or it can be fine-tuned together with the other models. As is well known, input to BERT must follow a certain format: "[CLS]" and "[SEP]" have to be added around each sentence, an attention mask is needed, and so on. Below, pad_sequences truncates and pads the sentences so that every input has the same length. After the training set is built, download a Chinese pre-trained model, load the model, the vocabulary (vocab) and the parameter configuration, and finally use albert to extract sentence embeddings; such an embedding can then be combined with other models as a downstream task to train for a specific goal.
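Before the full listing, a quick illustration of that input format. This is a minimal sketch; it relies on the tokenizer.encode call with add_special_tokens=True that also appears, commented out, in the code below, and the ids shown match the printed output further down (101 and 102 are [CLS] and [SEP] in the Chinese vocab):

from model.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('configs/vocab.txt')  # assumed vocab path, as in the listing
ids = tokenizer.encode('我爱祖国', add_special_tokens=True)
# -> [101, 2769, 4263, 4862, 1744, 102], i.e. [CLS] 我 爱 祖 国 [SEP]
mask = [int(i > 0) for i in ids]  # attention mask: 1 = real token, 0 = padding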
import torch
from configs.base import config
from model.modeling_albert import BertConfig, BertModel
from model.tokenization_bert import BertTokenizer
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_LEN = 10
if __name__ == '__main__':
    bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
    base_path = os.getcwd()
    VOCAB = base_path + '/configs/vocab.txt'  # your path for model and vocab
    tokenizer = BertTokenizer.from_pretrained(VOCAB)  # tokenizer used to encode the text
    tag2idx = {'[SOS]': 101, '[EOS]': 102, '[PAD]': 0, 'B_LOC': 1, 'I_LOC': 2, 'O': 3}
    sentences = ['我是中华人民共和国国民', '我爱祖国']
    tags = ['O O B_LOC I_LOC I_LOC I_LOC I_LOC I_LOC O O', 'O O O O']

    tokenized_text = [tokenizer.tokenize(sent) for sent in sentences]
    # Use pad_sequences to truncate and pad every sequence to the same length.
    # Note: it cannot process one sample at a time, only 2-d input (a whole batch),
    # so loading the entire corpus into memory at once could be a problem for large data.
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text],
                              maxlen=MAX_LEN - 2,  # leave room for [CLS] and [SEP]
                              truncating='post',
                              padding='post',
                              value=0)
    tag_ids = pad_sequences([[tag2idx.get(tok) for tok in tag.split()] for tag in tags],
                            maxlen=MAX_LEN - 2,
                            padding='post',
                            truncating='post',
                            value=0)
    # BERT sentences need [CLS] (id 101) at the front and [SEP] (id 102) at the end
    input_ids_cls_sep = []
    for input_id in input_ids:
        linelist = []
        linelist.append(101)            # prepend [CLS]
        flag = True
        for tag in input_id:
            if tag > 0:
                linelist.append(tag)
            elif tag == 0 and flag:     # first padded position: insert [SEP] before the padding
                linelist.append(102)
                linelist.append(tag)
                flag = False
            else:
                linelist.append(tag)
        if tag > 0:                     # sequence filled the whole window: append [SEP] at the end
            linelist.append(102)
        input_ids_cls_sep.append(linelist)

    tag_ids_cls_sep = []
    for tag_id in tag_ids:
        linelist = []
        linelist.append(101)
        flag = True
        for tag in tag_id:
            if tag > 0:
                linelist.append(tag)
            elif tag == 0 and flag:
                linelist.append(102)
                linelist.append(tag)
                flag = False
            else:
                linelist.append(tag)
        if tag > 0:
            linelist.append(102)
        tag_ids_cls_sep.append(linelist)

    # attention mask: 1 for real token positions (id > 0), 0 for padding
    attention_masks = [[int(tok > 0) for tok in line] for line in input_ids_cls_sep]

    print('---------------------------')
    print('input_ids:{}'.format(input_ids_cls_sep))
    print('tag_ids:{}'.format(tag_ids_cls_sep))
    print('attention_masks:{}'.format(attention_masks))

    # input_ids = torch.tensor([tokenizer.encode('我 是 中 华 人 民 共 和 国 国 民', add_special_tokens=True)])  # add_special_tokens=True adds [CLS]/[SEP] automatically
    # print('input_ids:{}, size:{}'.format(input_ids, len(input_ids)))
    # print('attention_masks:{}, size:{}'.format(attention_masks, len(attention_masks)))

    inputs_tensor = torch.tensor(input_ids_cls_sep)
    tags_tensor = torch.tensor(tag_ids_cls_sep)
    masks_tensor = torch.tensor(attention_masks)

    train_data = TensorDataset(inputs_tensor, tags_tensor, masks_tensor)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)

    model = BertModel.from_pretrained(config['bert_dir'], config=bert_config)
    model.to(device)
    model.eval()
    with torch.no_grad():
        '''
        note:
        1. If the config sets "output_hidden_states": True and "output_attentions": True,
           the model returns sequence_output, pooled_output, (hidden_states), (attentions)
           for all layers, so: all_hidden_states, all_attentions = model(input_ids)[-2:]
        2. If output_hidden_states and output_attentions are not set,
           only the last layer is returned --> (sequence_output, pooled_output)
        '''
        for index, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            # TensorDataset was built as (inputs, tags, masks), so unpack in the same order
            b_input_ids, b_labels, b_input_mask = batch
            outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)
            print(len(outputs))
            all_hidden_states, all_attentions = outputs[-2:]  # hidden states and attentions of all layers
            print(all_hidden_states[-2].shape)  # shape of the second-to-last layer's hidden states
            print(all_hidden_states[-2])
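The note in the listing assumes output_hidden_states and output_attentions were switched on when the config was loaded. A minimal sketch of one way to do that, assuming from_pretrained accepts these as keyword overrides as in the transformers-style API (otherwise set them directly in the JSON config file):

bert_config = BertConfig.from_pretrained(str(config['albert_config_path']),
                                         share_type='all',
                                         output_hidden_states=True,   # assumed keyword override
                                         output_attentions=True)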
2. Printed results
input_ids:[[101, 2769, 3221, 704, 1290, 782, 3696, 1066, 1469, 102], [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]]
tag_ids:[[101, 3, 3, 1, 2, 2, 2, 2, 2, 102], [101, 3, 3, 3, 3, 102, 0, 0, 0, 0]]
attention_masks:[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
4
torch.Size([2, 10, 768])
tensor([[[-1.1074, -0.0047, 0.4608, ..., -0.1816, -0.6379, 0.2295],
[-0.1930, -0.4629, 0.4127, ..., -0.5227, -0.2401, -0.1014],
[ 0.2682, -0.6617, 0.2744, ..., -0.6689, -0.4464, 0.1460],
...,
[-0.1723, -0.7065, 0.4111, ..., -0.6570, -0.3490, -0.5541],
[-0.2028, -0.7025, 0.3954, ..., -0.6566, -0.3653, -0.5655],
[-0.2026, -0.6831, 0.3778, ..., -0.6461, -0.3654, -0.5523]],
[[-1.3166, -0.0052, 0.6554, ..., -0.2217, -0.5685, 0.4270],
[-0.2755, -0.3229, 0.4831, ..., -0.5839, -0.1757, -0.1054],
[-1.4941, -0.1436, 0.8720, ..., -0.8316, -0.5213, -0.3893],
...,
[-0.7022, -0.4104, 0.5598, ..., -0.6664, -0.1627, -0.6270],
[-0.7389, -0.2896, 0.6083, ..., -0.7895, -0.2251, -0.4088],
[-0.0351, -0.9981, 0.0660, ..., -0.4606, 0.4439, -0.6745]]])
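To complete the bert+lstm+crf pipeline from the introduction, the extracted embeddings can be fed into a BiLSTM whose projected outputs serve as CRF emission scores. Below is a minimal sketch, not the repository's code: the layer sizes are hypothetical, the CRF comes from the third-party pytorch-crf package, and the [SOS]/[EOS] label ids (101/102 in tag2idx above) are assumed to have already been remapped into a contiguous 0..num_tags-1 range (remapped_labels is a hypothetical name for that result):

import torch
import torch.nn as nn
from torchcrf import CRF  # third-party: pip install pytorch-crf

num_tags = 6  # hypothetical: PAD, B_LOC, I_LOC, O plus remapped [SOS]/[EOS]
lstm = nn.LSTM(input_size=768, hidden_size=128, bidirectional=True, batch_first=True)
hidden2tag = nn.Linear(2 * 128, num_tags)      # BiLSTM output -> emission scores
crf = CRF(num_tags, batch_first=True)

embeddings = all_hidden_states[-2]             # (batch, seq_len, 768) from the loop above
lstm_out, _ = lstm(embeddings)                 # (batch, seq_len, 256)
emissions = hidden2tag(lstm_out)               # (batch, seq_len, num_tags)
mask = b_input_mask.bool()                     # CRF skips padded positions
loss = -crf(emissions, remapped_labels, mask=mask)   # negative log-likelihood for training
best_paths = crf.decode(emissions, mask=mask)        # Viterbi decoding at inference time

During training, loss.backward() would update the LSTM, the projection and the CRF transition matrix; whether the albert weights are updated too depends on whether they are frozen or fine-tuned, as discussed in the introduction.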