python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发、低延迟的要求,存在需要用c++接口的情况。除了将模型导出为ONNX外,pytorch1.0给出了新的解决方案:pytorch 训练模型 - 通过torch script中间脚本保存模型 -- C++加载模型。最近工作需要尝试做了转换,总结一下步骤和遇到的坑。

用torch script把torch模型转成c++接口可读的模型有两种方式:trace && script. trace比script简单,但只适合结构固定的网络模型,即forward中没有控制流的情况,因为trace只会保存运行时实际走的路径。如果forward函数中有控制流,需要用script方式实现。

trace顾名思义,就是沿着数据运算的路径走一遍,官方例子:


import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

script稍复杂,主要改三处:

1. Model由之前继承 nn.Model 改为继承 torch.jit.ScriptModule

2. forward函数前加 @torch.jit.script_method

3. 其他需要调用的函数前加 @torch.jit.script

踩过的坑&&解决方法:

A. torch script默认函数或方法的参数都是Tensor类型的,如果不是需要说明,不然调用非Tensor参数时会报类型不符的编译错误。

python3可以直接:

def example_func(param_1: Tensor, param_2: int, param_3: List[int]):

python2需要用type注释:

def example_func(param_1, param_2, param_3):

#type: (Tensor, int, List[int]) -> Tensor

B. model的方法中forward加@torch.jit.script_method, __init__函数不用

C. 前面说过,torch scrip支持的函数是pytorch的子集,意味着有一部分函数不支持,例如: not boolean,pass, List的切片赋值,CPU和GPU切换的value.to( ), 需要想办法绕过去。看github上讨论区说新版好像已经支持not操作了,没有验证。

结论:pytorch 1.0目前的预览版还有比较多优化的空间,至少是在torch script支持的函数集合上,不建议使用,等稳定版发布再看看吧。

  

原创内容,转载请注明出处。

参考资料:

https://pytorch.org/docs/master/jit.html

https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html

最新文章

  1. 【gulp】工作中的实战
  2. [ASP.NET Core] Static File Middleware
  3. 苹果开发者账号申请时报错提示错误:Legal Entity Name
  4. Erlang数据类型的表示和实现(4)——boxed 对象
  5. js模板
  6. java环境中基于jvm的两大语言:scala,groovy
  7. 监听mysql是否挂了
  8. maven主仓库中找不到restlet的解决办法
  9. PrintWriter 和 BufferedWriter 写入文件.
  10. EPZS搜索过程
  11. expr的简单应用
  12. HDU 1885 Key Task 国家压缩+搜索
  13. Linux下hosts、host.conf、resolv.conf
  14. Spring 之 示例(Java之负基础实战)
  15. Java通过链表实现队列
  16. 安装python的第三方Pillow库
  17. java中java.exe,javac 在editplus中的配置
  18. Luogu5289 十二省联考2019字符串问题(后缀数组+拓扑排序+线段树/主席树/KDTree)
  19. [No000014A]Linux简介与shell编程
  20. 在使用springMVC时,页面报的404异常

热门文章

  1. 2.Diango学习
  2. qml: 打包 和 发布
  3. C++ MFC------ 快捷键
  4. websocket实现简单的通信
  5. bzoj1027 状压dp
  6. org.apache.catalina.LifecycleException: Failed to start component [StandardEngine[Catalina].Standard
  7. Xstart Insatll And Usage
  8. 5句话搞定ES5作用域
  9. win7安装linux CentOS7双系统实践
  10. bash guide