pytorch imagenet测试代码
2024-08-22 19:22:49
image_test.py
import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch class TestTorch(TestKit): def __init__(self):
super(TestTorch, self).__init__() self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)] self.model = self.MainModel.KitModel(self.args.w)
self.model.eval() def preprocess(self, image_path):
x = super(TestTorch, self).preprocess(image_path)
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0).copy()
self.data = torch.from_numpy(x)
self.data = torch.autograd.Variable(self.data, requires_grad = False) def print_result(self, image_name, top1, top5):
predict = self.model(self.data)
predict = predict.data.numpy()
return super(TestTorch, self).print_result(predict, image_name, top1, top5) def print_intermediate_result(self, layer_name, if_transpose=False):
intermediate_output = self.model.test.data.numpy()
super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose) def inference(self, images): with open(images) as fp_images:
images_file = csv.reader(fp_images, delimiter='\n')
top1 = 0.0
top5 = 0.0
image_count = 0
for image_name in images_file:
image_count += 1
image_path = "../data/imagenet/small_imagenet/"+image_name[0]
self.preprocess(image_path)
temp1, temp5 = self.print_result(image_name[0], top1, top5)
top1 = temp1
top5 = temp5
print("top1's accuracy : %f"%(top1/image_count))
print("top5's accuracy : %f"%(top5/image_count))
# self.print_intermediate_result(None, False)
# self.test_truth() def dump(self, path=None):
if path is None: path = self.args.dump
torch.save(self.model, path)
print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
path, self.args.n, self.args.w)) if __name__=='__main__':
tester = TestTorch()
if tester.args.dump:
tester.dump()
else:
tester.inference(tester.args.image)
image_test_base.py:
请见上传的代码。 下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz
执行py文件时,需要终端输入参数:
parser = argparse.ArgumentParser() parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type') # pytorch的测试程序, 这里为image_test.py parser.add_argument('-n', type=_text_type, default='kit_imagenet',
help='Network structure file name.') # 模型结构测试文件 parser.add_argument('-s', type=_text_type, help='Source Framework Type',
choices=self.truth.keys()) # 框架类型:pytorch,tensorflow... parser.add_argument('-w', type=_text_type, required=True,
help='Network weights file name') #模型结构文件 parser.add_argument('--image', '-i',
type=_text_type, help='Test image path.',
default="../data/file_list.txt" #图像路径
) parser.add_argument('-l', '--label',
type=_text_type,
default='../data/val.txt',
help='Path of label.') #测试集类别 parser.add_argument('--dump',
type=_text_type,
default=None,
help='Target model path.') # 转化的目标模型文件的保存路径 parser.add_argument('--detect',
type=_text_type,
default=None,
help='Model detection result path.') # tensorflow dump tag
parser.add_argument('--dump_tag',
type=_text_type,
default=None,
help='Tensorflow model dump type',
choices=['SERVING', 'TRAINING'])
最新文章
- SQL Server 数据库的维护(二)__触发器
- c# json转换实例
- Windows下MongoDB环境搭建
- App性能提升方法
- 【转】六年软件测试感悟-从博彦到VMware
- BZOJ 1570: [JSOI2008]Blue Mary的旅行( 二分答案 + 最大流 )
- 学会了 C 语言真的可以开发出很多东西吗?
- HDU 2671 Can't be easier
- JavaBean--删除操作
- Autofac学习之三种生命周期:InstancePerLifetimeScope、SingleInstance、InstancePerDependency
- 面试题-NSDate\CFAbsoluteTimeGetCurrent\CACurrentMediaTime的区别
- JAVA经典算法40题(原题+分析)之原题
- Java虚拟机三:OutOfMemoryError异常分析
- webstorm快捷键大全(亲自整理)
- LuoguP3674 小清新人渣的本愿 &;&; BZOJ4810: [Ynoi2017]由乃的玉米田
- 补偿接口中循环一直执行sql的问题
- BZOJ4154:[Ipsc2015]Generating Synergy(K-D Tree)
- 怎样自己定义注解Annotation,并利用反射进行解析
- java安全性-引用-分层-解耦
- Jmeter参数化与检查点
热门文章
- Educational Codeforces Round 66 (Rated for Div. 2)
- Spark 系列(十三)—— Spark Streaming 与流处理
- Linux进程间通信—使用共享内存
- js json数据保存到本地
- 用不上索引的sql
- 4.linux下配置Golang的环境变量
- weui中的picker使用js进行动态绑定数据
- GC是如何判断一个对象为";垃圾";的?被GC判断为";垃圾";的对象一定会被回收吗?
- 解决spring-test中Feign问题: No qualifying bean of type 'org.springframework.cloud.openfeign.FeignContext' available
- macOS 10.13允许任何来源开启方法