文章目录:

1 Dataset基类

PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。

先看一下源码:

这里有一个__getitem__函数,__getitem__函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再讲)。

2 构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

import torch
from torch.utils.data import Dataset,DataLoader class MyDataset(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])
self.label = torch.LongTensor([1,1,0,0]) def __getitem__(self,index):
return self.data[index],self.label[index] def __len__(self):
return len(self.data)

2.1 Init

  • 初始化中,一般是把数据直接保存在这个类的属性中。像是self.data,self.label

2.2 getitem

  • index是一个索引,这个索引的取值范围是要根据__len__这个返回值确定的,在上面的例子中,__len__的返回值是4,所以这个index会在0,1,2,3这个范围内。

3 dataloader

从上文中,我们知道了MyDataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。

继续上面的代码,我们接着写代码:

mydataloader = DataLoader(dataset=mydataset,
batch_size=1)

我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:

for i,(data,label) in enumerate(mydataloader):
print(data,label)

输出结果是:

意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的

我们稍微修改一下上面的DataLoader的参数:

mydataloader = DataLoader(dataset=mydataset,
batch_size=2,
shuffle=True) for i,(data,label) in enumerate(mydataloader):
print(data,label)

结果是:

可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:

两次结果不同,这是因为shuffle=True,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。

【个人感想】

Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:

  1. 一般标签值应该是Long整数的,所以标签的tensor可以用torch.LongTensor(数据)或者用.long()来转化成Long整数的形式。
  2. 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用to()放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for i,(data,label) in enumerate(mydataloader):
data = data.to(device)
label = label.to(device)
print(data,label)

看一下输出:

最新文章

  1. Chrome一直提示“adobe flash player 因过期而遭阻止” ,如何解决?
  2. 获取url传参
  3. iOS10 权限配置
  4. chrome调试本地项目, 引用本地javascript文件
  5. SQLServer学习笔记系列8
  6. XCode设置(怎么让代码收缩)
  7. 用 CSS 做轮播图
  8. Php中正则小结(一)
  9. javascript进击(三)简介
  10. 12个强大的Web服务测试工具
  11. SPL的基本使用
  12. 写一个背景渐变的TextView输入框
  13. C++成员变量与函数内存分配
  14. proc文件系统探索 之 根目录下的文件[二]
  15. 笔记之monkey参数(一)
  16. C# 取得上月月头和月尾、上周的第一天和最后一天。
  17. 【BZOJ4259】残缺的字符串
  18. 基于VMware模拟实现远程主机网络通信
  19. linux自旋锁、互斥锁、信号量
  20. 磁共振中的T1, T2 和 T2*的原理和区别

热门文章

  1. 什么是Cookie、Session、Token?
  2. 【AHOI2009】中国象棋 题解(线性DP+数学)
  3. kafka的学习1
  4. C语言学习笔记之原码反码补码
  5. Linxu系统安装PHP详细教程
  6. Python安装工具
  7. Java基础—面向对象特性
  8. C#LeetCode刷题之#172-阶乘后的零(Factorial Trailing Zeroes)
  9. 自动化特征工程—Featuretools
  10. Qt 信号发射部分 undefined reference to错误