最近在做BERT的fine-tune工作,记录一下阅读项目https://github.com/weizhepei/BERT-NER时梳理的训练pipline,该项目基于Google的Transformers代码构建

前置知识

bert的DataLoader简介(真的很简介)

https://zhuanlan.zhihu.com/p/384469908

yield介绍

https://www.runoob.com/w3cnote/python-yield-used-analysis.html

这是一种提高代码复用性的方法

带yield的函数被称为 generator(生成器),调用next()方法可使其执行至函数内部的yield处中断并返回一个迭代值

Pipeline

训练部分

① 运行build_dataset_tags.py将原始数据集处理为txt文本保存(生成原始数据集文本)

②数据流

注:“XX.py--->”代表该过程由XX.py发起

1、train.py--->class DataLoader[data_loader.py]--->train_data(d)

通过data_loader.py中的load_data,再调用load_sentences_tags

load_sentences_tags返回一个字典d,包含:

  • 使用tokenizer对原始句子的token

  • token对应的id

  • token对应的tag

  • 句子的长度

2、train.py--->train_and_evaluate(train_data, val_data)--->2个generator--->evaluate()[evaluate.py]

此处生成的两个生成器分别用于在训练和测试时以迭代方式获取batch数据

3、train.py--->evaluate(generator)[evaluate.py]--->batch_data, batch_token_starts, batch_tags--->将batch输入model[在train.py处实例化]中--->loss、batch_output、batch_tags--->计算出F1值返回给train_and_evaluate()

在得到F1值后,根据设置的参数决定是否满足停止训练的条件

数据迭代器

data_loader.py--->data_iterator(train/val/test_data)

--->计算会产生的batch的数量(由train/val/test_data中记录的句子长度size和class DataLoader中人为设置的batch_size参数决定)--->提取train/val/test_data中的sentences、tags

# 计算batch数
if data['size'] % self.batch_size == 0:
BATCH_NUM = data['size']//self.batch_size
else:
BATCH_NUM = data['size']//self.batch_size + 1 # one pass over data
# 提取一个batch,由batch_size个sentences构成
for i in range(BATCH_NUM):
# fetch sentences and tags
if i * self.batch_size < data['size'] < (i+1) * self.batch_size:
sentences = [data['data'][idx] for idx in order[i*self.batch_size:]]
if not interMode:
tags = [data['tags'][idx] for idx in order[i*self.batch_size:]]
else:
sentences = [data['data'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]
if not interMode:
tags = [data['tags'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]

--->计算batch中最大的句子长度--->将数据转换为np矩阵(numpy array)

--->将数据拷贝到另一个np矩阵,使得所有数据的长度与最大句子长度保持一致(即完成了padding)

# prepare a numpy array with the data, initialising the data with pad_idx
# batch_data的形状为:最长句子长度X最长句子长度(batch_len X batch_len),元素全为0 batch_data = self.token_pad_idx * np.ones((batch_len, max_subwords_len))
batch_token_starts = [] # copy the data to the numpy array
for j in range(batch_len):
cur_subwords_len = len(sentences[j][0])
if cur_subwords_len <= max_subwords_len:
batch_data[j][:cur_subwords_len] = sentences[j][0]
else:
batch_data[j] = sentences[j][0][:max_subwords_len]
token_start_idx = sentences[j][-1]
token_starts = np.zeros(max_subwords_len)
token_starts[[idx for idx in token_start_idx if idx < max_subwords_len]] = 1
batch_token_starts.append(token_starts)
max_token_len = max(int(sum(token_starts)), max_token_len)

--->将所有索引格式的(我理解就是numpy array形式的)数据转换为torch LongTensors

--->返回batch_data, batch_token_starts, batch_tags(这就是用于直接输入模型的数据)

最新文章

  1. 【摘】linux中fstab解说
  2. 数据库的char(n)
  3. Codeforces 424C(异或)
  4. linux权限管理_ACL权限
  5. SMTP邮件发送命令
  6. cocos2dx 2.14使用UUID
  7. -_-#【jQuery插件】Spinner 数字选择器
  8. 基于mini2440的看门狗(裸机)
  9. c语言求最大公约数和最小公倍数
  10. 什么场景Hbase
  11. Android简单逐帧动画Frame的实现(二)
  12. 面向对象重写(override)与重载(overload)区别
  13. kali系统破解WPA密码实战
  14. win7+ ubuntu 双系统
  15. MySql入门(1)
  16. LINUX改变文件大小
  17. Mysql:索引实战
  18. 1.7Oob对象的创建局部变量
  19. java String[] 初始化
  20. Kafka、RabbitMQ、RocketMQ等消息中间件的对比

热门文章

  1. Table.LastN保留后面N….Last…(Power Query 之 M 语言)
  2. bootstrap.css 进度条没有动画效果
  3. CF355B Vasya and Public Transport 题解
  4. CF792A New Bus Route 题解
  5. SQL Server日志恢复还原数据
  6. DG修复:异常关库导致的数据库启动失败ORA-01110及GAP修复
  7. centos使用yum安装报错: 另一个应用程序是:PackageKit
  8. c++11之字符串格式化
  9. 【LeetCode】935. Knight Dialer 解题报告(Python)
  10. 【LeetCode】52. N-Queens II 解题报告(Python & C+)