transformer模型转torchscript格式
2024-08-27 03:14:30
from transformers import BertModel, BertTokenizer, BertConfig
import torch enc = BertTokenizer.from_pretrained("bert-base-uncased") # 输入文本tokenize
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text) # 将一个token置为mask
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 创建虚拟输入
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors] # 初始化模型时将torchscript参数置为True
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True) # 初始化模型
model = BertModel(config) # 模型置为eval模式
model.eval() # 也可以从pretrained初始化模型
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True) # 创建trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt") # 加载模型
loaded_model = torch.jit.load("traced_model.pt")
loaded_model.eval() all_encoder_layers, pooled_output = loaded_model(dummy_input) # 使用traced model进行推理
traced_model(tokens_tensor, segments_tensors)
最新文章
- 教你9个提升 Wordpress 网站安全性的方法
- typedef 与指针、多维数组
- QT QDateTime类、QTimer类
- PHP如何释放内存之unset销毁变量并释放内存详解
- SVN迁移到Git的过程(+ 一些技巧
- 针对 SQL Server 2008 在Windows Server 2008上的访问配置 Windows 防火墙
- 从0开始学Java——JSP&;Servlet——如何在Eclipse中配置Web容器为tomcat
- ActionBarCompat
- android147 360 程序锁fragment
- Spring在代码中获取bean的几种方式(转:http://www.dexcoder.com/selfly/article/326)
- SVProgressHUD的使用
- smartforms换页,
- 从头开始学JavaScript (十二)——Array类型
- (转)java提高篇(四)-----理解java的三大特性之多态
- Beta版本冲刺前期计划及安排
- 四、Jedis操作Redis
- vim 学习笔记系列(前言)
- MyBatis-parameterType 取出入参值
- C#中Post请求的两种方式发送参数链和Body的
- Docker学习笔记之Docker的Build 原理