使用 Transformers 在你自己的数据集上训练文本分类模型
最近实在是有点忙,没啥时间写博客了。趁着周末水一文,把最近用 huggingface transformers 训练文本分类模型时遇到的一个小问题说下。
背景
之前只闻 transformers 超厉害超好用,但是没有实际用过。之前涉及到 bert 类模型都是直接手写或是在别人的基础上修改。但这次由于某些原因,需要快速训练一个简单的文本分类模型。其实这种场景应该挺多的,例如简单的 POC 或是临时测试某些模型。
我的需求很简单:用我们自己的数据集,快速训练一个文本分类模型,验证想法。
我觉得如此简单的一个需求,应该有模板代码。但实际去搜的时候发现,官方文档什么时候变得这么多这么庞大了?还多了个 Trainer
API?瞬间让我想起了 Pytorch Lightning 那个坑人的同名 API。但可能是时间原因,找了一圈没找到适用于自定义数据集的代码,都是用的官方、预定义的数据集。
所以弄完后,我决定简单写一个文章,来说下这原本应该极其容易解决的事情。
数据
假设我们数据的格式如下:
0 第一个句子
1 第二个句子
0 第三个句子
即每一行都是 label sentence
的格式,中间空格分隔。并且我们已将数据集分成了 train.txt
和 val.txt
。
代码
加载数据集
首先使用 datasets
加载数据集:
from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})
加载后的 dataset
是一个 DatasetDict
对象:
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 3
})
test: Dataset({
features: ['text'],
num_rows: 3
})
})
类似 tf.data
,此后我们需要对其进行 map
,对每一个句子进行 tokenize、padding、batch、shuffle:
def tokenize_function(examples):
labels = []
texts = []
for example in examples['text']:
split = example.split(' ', maxsplit=1)
labels.append(int(split[0]))
texts.append(split[1])
tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
tokenized['labels'] = labels
return tokenized
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)
根据数据集格式不同,我们可以在 tokenize_function
中随意自定义处理过程,以得到 text 和 labels。注意 batch_size
和 max_length
也是在此处指定。处理完我们便得到了可以输入给模型的训练集和测试集。
训练
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
你可以根据情况修改训练 batchsize per_device_train_batch_size
。
完整代码
完整代码见 GitHub。
END
最新文章
- spring注解 @Transactional
- 大话设计模式C++版——抽象工厂模式
- express-5 质量保证(2)
- BZOJ 2433 智能车比赛(计算几何+最短路)
- openerp 经典收藏 记录规则 – 销售只能看到自己的客户,经理可以看到全部(转载)
- Google文档
- JAVA跑马灯实现1
- 【Beta】Scrum10
- Volley源码学习笔记
- linux cpu load学习笔记
- UVa - 102 - Ecological Bin Packing
- mybatis中#{}与${}的区别
- 循环队列搜索 Search in Rotated Sorted Array
- docker使用ssh远程连接容器(没钱买服务器又不想安装虚拟机患者必备)
- Knockout.Js官网学习(checked 绑定)
- MUSIC分辨率与克拉美罗下界的关系
- C语言--第六周作业评分和总结(5班)
- spring中的控制反转
- RAC日常维护命令
- How to Install Xcode, Homebrew, Git, RVM, Ruby &; Rails on Snow Leopard, Lion, Mountain Lion, and Mavericks
热门文章
- Filter Pattern 2 (dubbo的实现方式)
- Python学习笔记组织文件之将美国风格日期的文件改名为欧洲风格的日期
- JavaScript 错误 throw、try、catch
- K8S informer机制
- 易语言json
- SpringBoot2.2.2+SpringCloud-Hoxton.SR1整合eureka/gateway
- 无法加载文件 D:\lunwen\nodejs\node_global\vue.ps1,因为在此系统上禁止运行脚本。visual code页面vue ui启动失败
- 肖sir_ 杭州_阿里和蚂蚁和菜鸟和支付宝面试题集锦
- ubuntu22.04安装mysql5.7
- scala调用fastjson JSON.toJSONString()序列化对象出错