image_test.py

import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch class TestTorch(TestKit): def __init__(self):
super(TestTorch, self).__init__() self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)] self.model = self.MainModel.KitModel(self.args.w)
self.model.eval() def preprocess(self, image_path):
x = super(TestTorch, self).preprocess(image_path)
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0).copy()
self.data = torch.from_numpy(x)
self.data = torch.autograd.Variable(self.data, requires_grad = False) def print_result(self, image_name, top1, top5):
predict = self.model(self.data)
predict = predict.data.numpy()
return super(TestTorch, self).print_result(predict, image_name, top1, top5) def print_intermediate_result(self, layer_name, if_transpose=False):
intermediate_output = self.model.test.data.numpy()
super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose) def inference(self, images): with open(images) as fp_images:
images_file = csv.reader(fp_images, delimiter='\n')
top1 = 0.0
top5 = 0.0
image_count = 0
for image_name in images_file:
image_count += 1
image_path = "../data/imagenet/small_imagenet/"+image_name[0]
self.preprocess(image_path)
temp1, temp5 = self.print_result(image_name[0], top1, top5)
top1 = temp1
top5 = temp5
print("top1's accuracy : %f"%(top1/image_count))
print("top5's accuracy : %f"%(top5/image_count))
# self.print_intermediate_result(None, False)
# self.test_truth() def dump(self, path=None):
if path is None: path = self.args.dump
torch.save(self.model, path)
print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
path, self.args.n, self.args.w)) if __name__=='__main__':
tester = TestTorch()
if tester.args.dump:
tester.dump()
else:
tester.inference(tester.args.image)

image_test_base.py:

  请见上传的代码。 下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz

执行py文件时,需要终端输入参数:

 parser = argparse.ArgumentParser()

        parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type')   # pytorch的测试程序, 这里为image_test.py

        parser.add_argument('-n', type=_text_type, default='kit_imagenet',
help='Network structure file name.') # 模型结构测试文件 parser.add_argument('-s', type=_text_type, help='Source Framework Type',
choices=self.truth.keys()) # 框架类型:pytorch,tensorflow... parser.add_argument('-w', type=_text_type, required=True,
help='Network weights file name') #模型结构文件 parser.add_argument('--image', '-i',
type=_text_type, help='Test image path.',
default="../data/file_list.txt" #图像路径
) parser.add_argument('-l', '--label',
type=_text_type,
default='../data/val.txt',
help='Path of label.') #测试集类别 parser.add_argument('--dump',
type=_text_type,
default=None,
help='Target model path.') # 转化的目标模型文件的保存路径 parser.add_argument('--detect',
type=_text_type,
default=None,
help='Model detection result path.') # tensorflow dump tag
parser.add_argument('--dump_tag',
type=_text_type,
default=None,
help='Tensorflow model dump type',
choices=['SERVING', 'TRAINING'])

最新文章

  1. SQL Server 数据库的维护(二)__触发器
  2. c# json转换实例
  3. Windows下MongoDB环境搭建
  4. App性能提升方法
  5. 【转】六年软件测试感悟-从博彦到VMware
  6. BZOJ 1570: [JSOI2008]Blue Mary的旅行( 二分答案 + 最大流 )
  7. 学会了 C 语言真的可以开发出很多东西吗?
  8. HDU 2671 Can't be easier
  9. JavaBean--删除操作
  10. Autofac学习之三种生命周期:InstancePerLifetimeScope、SingleInstance、InstancePerDependency
  11. 面试题-NSDate\CFAbsoluteTimeGetCurrent\CACurrentMediaTime的区别
  12. JAVA经典算法40题(原题+分析)之原题
  13. Java虚拟机三:OutOfMemoryError异常分析
  14. webstorm快捷键大全(亲自整理)
  15. LuoguP3674 小清新人渣的本愿 && BZOJ4810: [Ynoi2017]由乃的玉米田
  16. 补偿接口中循环一直执行sql的问题
  17. BZOJ4154:[Ipsc2015]Generating Synergy(K-D Tree)
  18. 怎样自己定义注解Annotation,并利用反射进行解析
  19. java安全性-引用-分层-解耦
  20. Jmeter参数化与检查点

热门文章

  1. Educational Codeforces Round 66 (Rated for Div. 2)
  2. Spark 系列(十三)—— Spark Streaming 与流处理
  3. Linux进程间通信—使用共享内存
  4. js json数据保存到本地
  5. 用不上索引的sql
  6. 4.linux下配置Golang的环境变量
  7. weui中的picker使用js进行动态绑定数据
  8. GC是如何判断一个对象为"垃圾"的?被GC判断为"垃圾"的对象一定会被回收吗?
  9. 解决spring-test中Feign问题: No qualifying bean of type 'org.springframework.cloud.openfeign.FeignContext' available
  10. macOS 10.13允许任何来源开启方法