Transformers 库常见的用例 | 三
作者|huggingface
编译|VK
来源|Github
本章介绍使用Transformers库时最常见的用例。可用的模型允许许多不同的配置,并且在用例中具有很强的通用性。这里介绍了最简单的方法,展示了诸如问答、序列分类、命名实体识别等任务的用法。
这些示例利用Auto Model
,这些类将根据给定的checkpoint实例化模型,并自动选择正确的模型体系结构。有关详细信息,请查看:AutoModel
文档。请随意修改代码,使其更具体,并使其适应你的特定用例。
- 为了使模型能够在任务上良好地执行,必须从与该任务对应的checkpoint加载模型。这些checkpoint通常是在大量数据上预先训练的,并针对特定任务进行微调。这意味着:并非所有模型都针对所有任务进行了微调。如果要对特定任务的模型进行微调,可以利用examples目录中的
run\$task.py
脚本。 - 微调模型是在特定的数据集上微调的。此数据集可能与你的用例和域重叠,也可能不重叠。如前所述,你可以利用示例脚本来微调模型,也可以创建自己的训练脚本。
为了对任务进行推理,库提供了几种机制:
- 管道是非常易于使用的抽象,只需要两行代码。
- 直接将模型与Tokenizer(PyTorch/TensorFlow)结合使用来使用模型的完整推理。这种机制稍微复杂,但是更强大。
这里展示了两种方法。
请注意,这里介绍的所有任务都利用了在预训练模型针对特定任务进行微调后的模型。加载未针对特定任务进行微调的checkpoint时,将只加载transformer层,而不会加载用于该任务的附加层,从而随机初始化该附加层的权重。这将产生随机输出。
序列分类
序列分类是根据已经给定的类别然后对序列进行分类的任务。序列分类的一个例子是GLUE数据集,它就是完全基于该任务的。如果你想在GLUE序列分类任务上微调模型,可以利用run_GLUE.py
或run_tf_GLUE.py
脚本。
下面是一个使用管道进行情绪分析的例子:识别该序列是积极的还是消极的。它利用sst2上的微调模型,这是一个GLUE任务。
from transformers import pipeline
nlp = pipeline("sentiment-analysis")
print(nlp("I hate you"))
print(nlp("I love you"))
这将返回一个标签(“积极”或“消极”)和一个分数,如下所示:
[{'label': 'NEGATIVE', 'score': 0.9991129}]
[{'label': 'POSITIVE', 'score': 0.99986565}]
下面是一个使用模型进行序列分类的示例,以确定两个序列是否是彼此的解释。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 从这两句话中构建一个序列,使用正确的特定于模型的分隔符标记类型id和注意力掩码(encode()和encode_plus()处理这个问题)
- 将这个序列传递到模型中,以便将其分类到两个可用的类中的一个:0(不是解释)和1(是解释)
- 计算结果的softmax获取类的概率
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")
paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
TensorFlow代码
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="tf")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="tf")
paraphrase_classification_logits = model(paraphrase)[0]
not_paraphrase_classification_logits = model(not_paraphrase)[0]
paraphrase_results = tf.nn.softmax(paraphrase_classification_logits, axis=1).numpy()[0]
not_paraphrase_results = tf.nn.softmax(not_paraphrase_classification_logits, axis=1).numpy()[0]
print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
这将输出以下结果:
Should be paraphrase
not paraphrase: 10%
is paraphrase: 90%
Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%
抽取式问答
抽取式问答是从给定问题的文本中抽取答案的任务。问答数据集的一个例子是SQuAD数据集,它完全基于该任务。如果你想在团队任务中微调模型,可以利用run_SQuAD.py
。
下面是一个使用管道进行问答的示例:从给定问题的文本中提取答案。它利用了一个小队的微调模型。
from transformers import pipeline
nlp = pipeline("question-answering")
context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
a model on a SQuAD task, you may leverage the `run_squad.py`.
"""
print(nlp(question="What is extractive question answering?", context=context))
print(nlp(question="What is a good example of a question answering dataset?", context=context))
这将返回从文本中提取的答案,一个置信度,以及“开始”和“结束”值,这些值是提取的答案在文本中的位置。
{'score': 0.622232091629833, 'start': 34, 'end': 96, 'answer': 'the task of extracting an answer from a text given a question.'}
{'score': 0.5115299158662765, 'start': 147, 'end': 161, 'answer': 'SQuAD dataset,'}
下面是一个使用模型和Tokenizer回答问题的示例。该过程如下:
- 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
- 定义一段文本和几个问题。
- 遍历问题并根据文本和当前问题构建一个序列,使用正确的模型特定分隔符标记类型id和注意力掩码将此序列传递到模型中。这将输出整个序列标记(问题和文本)的开始位置和结束位置的一系列分数。
- 计算结果的softmax以获得从标记的开始位置和停止位置对应的概率
- 将这些标记转换为字符串。
- 打印结果
Pytorch代码
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
text = r"""
最新文章
- Android基础总结(三)
- 【十大经典数据挖掘算法】Apriori
- 【Windows编程】系列第九篇:剪贴板使用
- 试图加载格式不正确的程序。 (异常来自 HRESULT:0x8007000B)
- Java设计模式-抽象工厂模式(Abstract Factory )
- 三大框架ssh
- spin_lock、spin_lock_irq、spin_lock_irqsave区别【转】
- winform自定义文件程序-- 不允许所请求的注册表访问权(ZSSQL)
- ajax跨域请求--jsonp实例
- 第22题 Rotate List
- PHP GD 库 缩略图 添加水印
- wiringPi库的pwm配置及使用说明
- Mac 下 python 环境问题
- 数据库 价格字段 设置 decimal(8,2),价格为100W,只显示999999.99
- 设计师们做UI设计和交互设计、界面设计等一般会去什么网站呢?
- Django settings介绍
- Monitor WMIExportsToC++Use DiskCleanup bypass UAC
- 解决百度ueditor支持iframe框架页面的视频播放问题
- 「翻译」一篇redis文章引发的翻译——JVM能支持多少线程?
- Mysqldump参数大全(参数来源于mysql5.5.19源码)
热门文章
- Python计算给定日期的周内的某一天
- 《深入理解 Java 虚拟机》读书笔记:类文件结构
- .NET平台编程语言的衰败
- 7-6 jmu_python_最大公约数&;最小公倍数 (10 分)
- 关于地址栏url的一些小结
- vue中的自定义分页插件组件
- Codeforces Round #626 (Div. 2, based on Moscow Open Olympiad in Informatics)
- Ubuntu 系统下如何安装pip3工具
- 【猫狗数据集】pytorch训练猫狗数据集之创建数据集
- Node的require和module.exports