pytorch学习:准备自己的图片数据
2024-08-26 20:28:08
图片数据一般有两种情况:
1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。
2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。
针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:
一、所有图片放在一个文件夹内
这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。
先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:
import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test)) f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()
经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:
前期工作就装备好了,接着就进入正题了:
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image def default_loader(path):
return Image.open(path).convert('RGB') class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label def __len__(self):
return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。
二、不同类别的图片放在不同的文件夹内
同样先准备数据,这里以flowers数据集为例,下载:
http://download.tensorflow.org/example_images/flower_photos.tgz
花总共有五类,分别放在5个文件夹下。大致如下图:
我的路径是d:/flowers/.
数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
) print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size()) show_batch(batch_x)
plt.axis('off')
plt.show()
就是这样。
最新文章
- iOS Block界面反向传值
- MongoDB数据库用户名和密码的设置
- spring aop实现
- CentOS安装Hypernetes相关问题解法
- IAR MSP430如何生成烧写文件
- Nexus手动更新索引
- json对象与字符串互转
- OpenMp之sections用法
- oracle 导入数据时提示只有 DBA 才能导入由其他 DBA 导出的文件
- PHP的MySQL扩展:MySQL数据库概述
- 关于RGB转换YUV的探讨与实现
- HDU 1199 - Color the Ball 离散化
- 发送cookie
- 在idea的maven项目使用el或jstl表达式
- 笔记12 注入AspectJ切面
- eclipse安装使用fat打jar包
- Arduino IDE for ESP8266 (0) 官方API
- Druid参考配置
- 【cs229-Lecture12】K-means算法
- 《Spring 2之站立会议3》
热门文章
- Jmeter3.2默认自带的HTML报告
- 南京邮电大学//bugkuCTF部分writeup
- 什么是nrm
- 微服务框架——SpringCloud(二)
- bzoj1124_枪战_基环树
- 平时作业五 Java
- Android SDK提供的常用控件Widget “常用控件”“Android原生”
- FTP连接虚拟主机响应220 Welcome to www.net.cn FTP service. (解决的一个问题)
- jQuery 动态绑定插件livequery的用法
- 异步使用委托delegate --- BeginInvoke和EndInvoke方法