TensorFlow Frontend前端
TensorFlow Frontend前端
TensorFlow前端有助于将TensorFlow模型导入TVM。
Supported versions:
- 1.12 and below
Tested models:
- Inception (V1/V2/V3/V4)
- Resnet (All)
- Mobilenet (V1/V2 All)
- Vgg (16/19)
- BERT (Base/3-layer)
Preparing a Model for Inference准备推理模型
Remove Unneeded Nodes删除不需要的节点
导出过程将删除许多不需要进行推理的节点,但不幸的是会留下一些剩余的节点。应该手动删除的节点:
- Dropout, including Dropout and DropoutWrapper
- Assert
Convert None Dimensions to Constants将无尺寸Dimensions转换为常数
TVM对动态张量形状的支持最少。None应将尺寸替换为常量。例如,模型可以接受带有shape的输入(None,20)。这应转换为的形状(1,20)。应该相应地修改模型,以确保这些形状在整个图形中都匹配。
Export
TensorFlow前端需要冻结的protobuf(.pb)或保存的模型作为输入。不支持检查点(.ckpt)。TensorFlow前端所需的graphdef,可以从活动会话中提取,可以使用TFParser帮助器类提取。
应该导出该模型并进行许多转换,以准备模型进行推理。设置`add_shapes=True`也很重要,因为这会将每个节点的输出形状嵌入到图形中。这是一个给定会话将模型导出为protobuf的函数:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
def export_pb(session):
with tf.gfile.GFile("myexportedmodel.pb", "wb") as f:
inputs = ["myinput1", "myinput2"] # replace with your input names
outputs = ["myoutput1"] # replace with your output names
graph_def = session.graph.as_graph_def(add_shapes=True)
graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs)
graph_def = TransformGraph(
graph_def,
inputs,
outputs,
[
"remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)",
"sort_by_execution_order", # sort by execution order after each transform to ensure correct node ordering
"remove_attribute(attribute_name=_XlaSeparateCompiledGradients)",
"remove_attribute(attribute_name=_XlaCompile)",
"remove_attribute(attribute_name=_XlaScope)",
"sort_by_execution_order",
"remove_device",
"sort_by_execution_order",
"fold_batch_norms",
"sort_by_execution_order",
"fold_old_batch_norms",
"sort_by_execution_order"
]
)
f.write(graph_def.SerializeToString())
Another method is to export and freeze the graph.
Import the Model
Explicit Shape:
确保可以在整个图形中知道形状,将`shape`参数传递给`from_tensorflow`。该词典将输入名称映射到输入形状。
Data Layout
大多数TensorFlow模型以NHWC布局发布。NCHW布局通常提供更好的性能,尤其是在GPU上。该TensorFlow前端可以通过传递参数自动转换模型的数据布局`layout='NCHW'`到`from_tensorflow`。
Best Practices
- 使用静态张量形状代替动态形状(删除`None`尺寸)。
- `TensorArray`目前尚不支持使用静态RNN代替动态RNN。
Supported Ops
- Abs
- Add
- AddN
- All
- Any
- ArgMax
- ArgMin
- AvgPool
- BatchMatMul
- BatchMatMulV2
- BatchNormWithGlobalNormalization
- BatchToSpaceND
- BiasAdd
- BroadcastTo
- Cast
- Ceil
- CheckNumerics
- ClipByValue
- Concat
- ConcatV2
- Conv2D
- Cos
- Tan
- CropAndResize
- DecodeJpeg
- DepthwiseConv2dNative
- DepthToSpace
- Dilation2D
- Equal
- Elu
- Enter
- Erf
- Exit
- Exp
- ExpandDims
- Fill
- Floor
- FloorDiv
- FloorMod
- FusedBatchNorm
- FusedBatchNormV2
- Gather
- GatherNd
- GatherV2
- Greater
- GreaterEqual
- Identity
- IsFinite
- IsInf
- IsNan
- LeakyRelu
- LeftShift
- Less
- LessEqual
- Log
- Log1p
- LoopCond
- LogicalAnd
- LogicalOr
- LogicalNot
- LogSoftmax
- LRN
- LSTMBlockCell
- MatMul
- Max
- MaxPool
- Maximum
- Mean
- Merge
- Min
- Minimum
- MirrorPad
- Mod
- Mul
- Neg
- NextIteration
- NotEqual
- OneHot
- Pack
- Pad
- PadV2
- Pow
- Prod
- Range
- Rank
- RealDiv
- Relu
- Relu6
- Reshape
- ResizeBilinear
- ResizeBicubic
- ResizeNearestNeighbor
- ReverseV2
- RightShift
- Round
- Rsqrt
- Select
- Selu
- Shape
- Sigmoid
- Sign
- Sin
- Size
- Slice
- Softmax
- Softplus
- SpaceToBatchND
- SpaceToDepth,
- Split
- SplitV
- Sqrt
- Square
- SquareDifference
- Squeeze
- StridedSlice
- Sub
- Sum
- Switch
- Tanh
- TensorArrayV3
- TensorArrayScatterV3
- TensorArrayGatherV3
- TensorArraySizeV3
- TensorArrayWriteV3
- TensorArrayReadV3
- TensorArraySplitV3
- TensorArrayConcatV3
- Tile
- TopKV2
- Transpose
- TruncateMod
- Unpack
- UnravelIndex
- Where
- ZerosLike
最新文章
- 前端学HTTP之Web主机托管
- [译] 你该知道的javascript作用域 (javascript scope)(转)
- C#基础
- BizTalk开发系列(九) MAP的连接方法
- 安卓中級教程(2):@InjectView中的對象inject
- Cisco SG300系列交换机划分VLan与普通路由器连接配置
- Bridge 桥模式
- [HDOJ3466]Proud Merchants(贪心+01背包)
- 使用adb devices命令,老是报error:device offline的错误。
- oracle 查看表属主和表空间sql
- DBA查询命令积累——不断更新
- python3学习笔记(0)
- 巨人大哥谈Java中的Synchronized关键字用法
- [LeetCode] Longest Palindromic Subsequence 最长回文子序列
- 恶补web之四:xhtml学习
- Android Studio 3.1.3正式版的新坑。。。
- bzoj4946: [Noi2017]蔬菜 神烦贪心
- 第一章:模型层model layer
- Android——控件AutoCompleteTextView 自动提示
- Python开发【Django】:Admin配置管理
热门文章
- SpringIOC框架简单实现(注解实现)
- Pytorch系列:(四)IO操作
- Portswigger web security academy:Server-side template injection(SSTI)
- 【pytest系列】- pytest测试框架介绍与运行
- String相关介绍
- SQL中那么多函数,Java8为什么还要提供重复的Stream方法,多此一举?
- [DB] Spark Core (1)
- 保存 yum 下载的软件包并制作成本地 yum 源
- BUUCTF(十)[GXYCTF2019]Ping Ping Ping 1
- 配置文件修改java安全级别和站点信息