基础信息说明

  • 本文以Seq2SeqTrainer作为实例,来讨论其模型训练时的数据加载方式
  • 预训练模型:opus-mt-en-zh
  • 数据集:本地数据集
  • 任务:en-zh 机器翻译

数据加载

Trainer的数据加载方式主要分为两种:基于torch.utils.data.Dataset的方式加载 和 基于huggingface自带的Datasets的方式(下文用huggingface / Datasets表示)加载。以下是一些需要注意的点:(1)Seq2SeqTrainer()的train_dataset和eval_dataset参数的所传实参应为字典类型;(2)该字典实参的keys应当覆盖模型运行所需要的数据参数(本文需要包括的有:'input_ids', 'attention_mask', 'labels');(3)使用huggingface / Datasets方法加载时,传给train_dataset和eval_dataset的字典实参中,多余的key(未在模型运行所需输入参数列表中)及其相关数据数,将会在训练之前被剔除。

torch.utils.data.Dataset

重载Dataset类(dataset.py)

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset class CDNDataset(Dataset):
def __init__(self, samples):
super(CDNDataset, self).__init__()
self.samples = samples def __getitem__(self, ite):
res = {k_: v_[ite]for k_, v_ in self.samples.items()}
return res def __len__(self):
return len(self.samples['labels'])

加载引用(main.py 后文代码同属于本文件)

from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataset import CDNDataset

读取数据

# 读取训练集
with open('raw_data/txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
train_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
# 将tokenized的中文序列对应的input_ids作为输入数据的标签
train_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128,
padding=True,truncation=True)["input_ids"]
fr_en.close()
fr_zh.close()
train_data = CDNDataset(train_data) # 读取验证集
with open('raw_data/test_txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/test_txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
dev_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
dev_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128,
padding=True,truncation=True)["input_ids"]
fr_en.close()
fr_zh.close()
dev_data = CDNDataset(dev_data)

huggingface / Datasets

修改main.py中数据集读取部分的代码

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
"""利用load_dataset()来读取数据:
- 该方法支持.txt、.csv、.json等文件格式
- 返回结果是一个字典类型
- 读取.txt文件时,若不指定名称,这key为"text", 且会返回文本中的样本数(段落数)
- 在读取.json文件时,若所有样本放在一个josn文件中,则返回的样本数为1(无法优雅地调用train_test_split()进行数据集分割),名称为默认名或者最层字典所 对应的keys;
- 将每个json文件仅存放一个样本,并把这些文件放在某一目录,可使利用load_dataset()正确计算出样本数。但该目录下每个.json文件命名风格要一致(例如:txt1.json、txt2.json、、、),文件名差异较大的话,系统会只读取某一类命名格式相近的文件中的数据。
"""
books = load_dataset("raw_data", data_dir='test_en', name='translation') books = books["train"].train_test_split(test_size=0.15) source_lang = "en"
target_lang = "zh"
prefix = "translate English to Chinese: " # 其实我也还没搞懂为啥要加这样一个前缀 def preprocess_function(examples):
inputs = [prefix + example[source_lang] for example in examples["translation"]]
targets = [example[target_lang] for example in examples["translation"]]
model_inputs = tokenizer(inputs, max_length=128, truncation=True) with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, truncation=True) model_inputs["labels"] = labels["input_ids"]
return model_inputs tokenized_books = books.map(preprocess_function, batched=True)

模型及参数加载

tokenizer = AutoTokenizer.from_pretrained("opus-mt-en-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("opus-mt-en-zh")
#使用huggingface/Datasets方式加载数据时,可以用DataCollator达到批处理的效果
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) #用torch.utils.data.Dataset方式加载时,不需要 training_args = Seq2SeqTrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=2,
fp16=True,
)

模型训练

本文以Seq2SeqTrainer作为实例来进行介绍。

trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=dev_data,
tokenizer=tokenizer,
data_collator=data_collator, #用torch.utils.data.Dataset方式加载时,此参数不需要
)

补充说明

  • Seq2SeqTrainer()中的train_dataset和eval_dataset参数仅支持torch.utils.data.Datasethuggingface / Datasets类型的传入实参。

  • torch.utils.data.Dataset类型的实参传入Seq2SeqTrainer()后,会在后序过程直接调用torch.utils.data.DataLoader,与常规pytorch操作相同

  • huggingface / Datasets类型的实参传入Seq2SeqTrainer()后,在后序过程会,先剔除多余的键及其值。至于torch.utils.data.Dataset类型的实参中若包含多余的键及其值,程序会不会报错暂没有测试过。获取模型所需输入参数列表的程序如下:

        def _set_signature_columns_if_needed(self):
    if self._signature_columns is None:
    # Inspect model forward signature to keep only the arguments it accepts.
    signature = inspect.signature(self.model.forward)
    self._signature_columns = list(signature.parameters.keys())
    # Labels may be named label or label_ids, the default data collator handles that.
    self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
  • 即使传递给train_dataset和eval_dataset的数据时字典类型,也存在一种做法使得基于torch.utils.data.Dataset加载数据的方式会报异常

    重载Dataset类(dataset.py)

    # -*- coding: utf-8 -*-
    from torch.utils.data import Dataset class CDNDataset(Dataset):
    def __init__(self, samples):
    super(CDNDataset, self).__init__()
    self.samples = samples def __getitem__(self, ite):
    return self.samples[ite] def __len__(self):
    return len(self.samples)

    main.py中数据读取部分

    train_data = []
    with open('raw_data/txt_en.txt', 'r', encoding='utf-8') as fr_en,
    open('raw_data/txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
    for en_, zh_ in zip(fr_en, fr_zh):
    data = tokenizer(en_.strip(), max_length=128, padding=True, truncation=True, return_tensors='pt')
    data["labels"] = tokenizer(zh_.strip(), max_length=128, padding=True,
    truncation=True, return_tensors='pt')['input_ids']
    train_data.append(data)
    fr_en.close()
    fr_zh.close() train_data = CDNDataset(train_data)
    dev_data = []
    with open('raw_data/test_txt_en.txt', 'r', encoding='utf-8') as fr_en,
    open('raw_data/test_txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
    for en_, zh_ in zip(fr_en, fr_zh):
    data = tokenizer(en_.strip(), max_length=128, padding=True, truncation=True, return_tensors='pt')
    data["labels"] = tokenizer(zh_.strip(), max_length=128, padding=True,
    truncation=True, return_tensors='pt')['input_ids']
    dev_data.append(data)
    fr_en.close()
    fr_zh.close()
    dev_data = CDNDataset(dev_data)

    报错信息:

    "Unable to create tensor, you should probably activate truncation and/or padding "

    ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length.

    说明:

    现将前一个基于torch.utils.data.Dataset加载数据的方式的案例叫作method1,当前抛出异常的案例叫作method2,两者相比:

    • dataset.py中__getitem__()的返回类型都是字典,每次也都是返回一个样本
    • 在main.py中:
      • method1将所有序列样本存入一个list中,然后对该list进行了一次tokenize,最后在CDNDataset类的__getitem__()中根据索引ite组合成一个样本的字典格式,并返回
      • method2中是先对每个像本序列分别作tokenize,再将各个样本tokenize后得到的字典存入一个list,最后在CDNDataset类的__getitem__()中根据索引ite返回各个样本对应的字典
    • 所报错误信息说没有进行paddingtruncation, 但事实上我做了,故而不知道是啥问题,望各位大佬不吝赐教。谢过!

最新文章

  1. js最详细的基础,jquery 插件最全的教材
  2. 创建第一个JBPM6项目并且运行自带的helloword例子(JBPM6学习之三)
  3. JAVA - 大数类详解
  4. 实现dom元素拖动
  5. 2016.04.09 使用Powerdesigner进行创建数据库的概念模型并转为物理模型
  6. Devlop Win 8 and Windows Phone App for Microsoft Dynamics CRM
  7. .NET基础拾遗(5)反射1
  8. [Jobdu] 题目1506:求1+2+3+...+n
  9. kafka消息传输时的对象转字符串时所需 -json String 转list 、set、 Long、 String 、map 与json Iterator遍历
  10. 《Spark大数据处理:技术、应用与性能优化》【PDF】
  11. 3.3.4 PCI设备进行DMA写时发生Cache命中
  12. cesium 之加载地形图 Terrain 篇(附源码下载)
  13. Java 关于cannot resolve symbol 'log'报错问题
  14. HBuilder + PHP开发环境配置
  15. 深度学习框架比较TensorFlow、Theano、Caffe、SciKit-learn、Keras
  16. oracle 新建用户后赋予的权限语句
  17. AndroidStudio连不上Android设备真机
  18. SerialPort.h SerialPort.cpp
  19. sqli-labs(三)
  20. 讲讲亿级PV的负载均衡架构

热门文章

  1. Flutter异常监控 - 叁 | 从bugsnag源码学习如何追溯异常产生路径
  2. Keil 5(Keil C51)安装与注册 [ 图文教程 ]
  3. Java 进阶P-5.1+P-5.2
  4. 构建api gateway之 健康检查
  5. T02 ExtractSubject 项目开发总结
  6. 【Oculus Interaction SDK】(七)使用射线进行交互(物体 & UI)
  7. JAVA虚拟机09---垃圾回收---经典垃圾回收器
  8. 接口介绍以及postman的基本使用
  9. 2021级《JAVA语言程序设计》上机考试试题9
  10. ABAP 辨析ON INPUT|REQUEST|CHAIN-INPUT|CHAIN-REQUEST