All-in-one 的Serving分析
2024-10-07 14:53:53
export_func.export(model, sess, signature_name=mission, version=fold + 1)
def export(model, sess, signature_name, export_path=root_path + '/all_in_one/demo/exported_models/', version=1):
# export path
export_path = os.path.join(os.path.realpath(export_path), signature_name, str(version))
print('Exporting trained model to {} ...'.format(export_path)) builder = tf.saved_model.builder.SavedModelBuilder(export_path)
# Build the signature_def_map.
classification_w = tf.saved_model.utils.build_tensor_info(model.w)
# classification_is_training = tf.saved_model.utils.build_tensor_info(model.is_training)
classification_dropout_keep_prob_mlp = tf.saved_model.utils.build_tensor_info(
model.dropout_keep_prob_mlp)
# score
classification_outputs_scores = tf.saved_model.utils.build_tensor_info(model.y) classification_signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs={tf.saved_model.signature_constants.CLASSIFY_INPUTS: classification_w},
outputs={
tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
classification_outputs_scores
},
method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME) # 'tensorflow/serving/classify' prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs={'input_plh': classification_w, 'dropout_keep_prob_mlp':
classification_dropout_keep_prob_mlp,
# 'is_training': classification_is_training
},
outputs={'scores': classification_outputs_scores},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) # 'tensorflow/serving/predict'
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
signature_name: prediction_signature,
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: classification_signature,
})
builder.save()
在signature_def_map中定义了两个,一个是自己设计的别名,一个是默认的。
定义一个解析类。
model_name 是启动服务时明确的model_name
signature_name是在signature_def_map中自己设计的别名对应的输入输出之类的。
def classify(self, sents):
self.sents=self.sents2id(sents)
hostport = '192.168.31.186:6000'
# grpc
host, port = hostport.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
# build request
request = predict_pb2.PredictRequest()
request.model_spec.name = self.model_name
request.model_spec.signature_name = self.signature_name
request.inputs['input_plh'].CopyFrom(
tf.contrib.util.make_tensor_proto(self.sents, dtype=tf.int32))
request.inputs['dropout_keep_prob_mlp'].CopyFrom(
tf.contrib.util.make_tensor_proto(1.0, dtype=tf.float32))
model_result = stub.Predict(request, 60.0)
model_result = np.array(model_result.outputs['scores'].float_val)
model_result = [model_result.tolist()][0]
index, _ =max(enumerate(model_result), key=operator.itemgetter(1))
if index>0:
label = self.label_dict[index-1]
else:
label = ""
# print("index:{}\tlabel:{}".format(index, label))
if self.encode == "part" :
if label:
label=self.part[label]
else:
label = "凌晨"
if self.encode == "type" :
if label:
label=self.type[label]
else:
label = "录像"
if self.encode == "door" and label:
label = self.gate[label] return label
最新文章
- js 的Location对象
- js 滚动的文字(走马灯)
- CSS3新特性学习
- js判断是否为ie6以外的浏览器,若是,则调用相应脚本
- Java 泛型和通配符解惑
- Sprint(第七天11.20)
- centos7下安装vsftpd配置
- 263. Ugly Number
- 中兴电信光纤猫F612管理员密码获取方法
- 【贪心】 BZOJ 3252:攻略
- 汇编 db,dw,dd的区别
- Light OJ 1095 Arrange the Numbers(容斥)
- oc基础 不可变字符串的创建和使用
- Leetcode016 3Sum Closest
- js跳转页面的几种方式
- eclipse工程当中的.classpath 和.project文件什么作用?
- 【转载】C#生成图片的缩略图
- gcc 6.0编译opencv出错
- RN热更新
- 2016年蓝桥杯B组C/C++省赛(预选赛)题目解析
热门文章
- zookeeper+springboot+dubbo简单实现
- 再整理:Visual Studio Code(vscode)下的通用C语言环境搭建
- 基于SkyWalking的分布式跟踪系统 - 微服务监控
- continue和break在while中用法
- [考试反思]1001csp-s模拟测试(b):逃离
- JS- 封装、继承、多态
- N42期-qq-林友埙-第一周作业
- windows下安装nginx和基本配置
- [ZJOI2013]K大数查询——整体二分
- 转载]OK6410之tftp下载内核,nfs挂载文件系统全过程详解[转]