我们在《torch.utils.data.DataLoader与迭代器转换》中介绍了如何使用Pytorch内置的数据集进行论文实现,如torchvision.datasets。下面是加载内置训练数据集的常见操作:

from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
[ToTensor(),
Normalize((0.1307,), (0.3081,))
]
)
train_data = FashionMNIST(
root=RAW_DATA_PATH,
download=True,
train=True,
transform=transform
)

这里的train_data做为dataset对象,它拥有许多熟悉,我们可以通过以下方法获取样本数据的分类类别集合、样本的特征维度、样本的标签集合等信息。

classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets print(classes)
print(num_features)
print(train_labels)

输出如下:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])

但是,我们常常会在训练集的基础上拆分出验证集(或者只用部分数据来进行训练)。我们想到的第一个方法是使用torch.utils.data.random_splitdataset进行划分,下面我们假设划分10000个样本做为训练集,其余样本做为验证集:

from torch.utils.data import random_split
k = 10000
train_data, valid_data = random_split(train_data, [k, len(train_data)-k])

注意我们如果打印train_datavalid_data的类型,可以看到显示:

<class 'torch.utils.data.dataset.Subset'>

已经不再是torchvision.datasets.mnist.FashionMNIST对象,而是一个所谓的Subset对象!此时Subset对象虽然仍然还存有data属性,但是内置的targetclasses属性已经不复存在,比如如果我们强行访问valid_datatarget属性:

valid_target = valid_data.target

就会报如下错误:

'Subset' object has no attribute 'target'

但如果我们在后续的代码中常常会将拆分后的数据集也默认为dataset对象,那么该如何做到代码的一致性呢?

这里有一个trick,那就是以继承SubSet类的方式的方式定义一个新的CustomSubSet类,使新类在保持SubSet类的基本属性的基础上,拥有和原本数据集类相似的属性,如targetsclasses等:

from torch.utils.data import Subset
class CustomSubset(Subset):
'''A custom subset class'''
def __init__(self, dataset, indices):
super().__init__(dataset, indices)
self.targets = dataset.targets # 保留targets属性
self.classes = dataset.classes # 保留classes属性 def __getitem__(self, idx): #同时支持索引访问操作
x, y = self.dataset[self.indices[idx]]
return x, y def __len__(self): # 同时支持取长度操作
return len(self.indices)

然后就引出了第二种划分方法,即通过初始化CustomSubset对象的方式直接对数据集进行划分(这里为了简化省略了shuffle的步骤):

import numpy as np
from copy import deepcopy
origin_data = deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)

注意,CustomSubset类的初始化方法的第二个参数indices为样本索引,我们可以通过np.arange()的方法来创建。

然后,我们再访问valid_data对应的classestarges属性:

print(valid_data.classes)
print(valid_data.targets)

此时,我们发现可以成功访问这些属性了:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0, ..., 3, 0, 5])

当然,CustomSubset的作用并不只是添加数据集的属性,我们还可以自定义一些数据预处理操作。我们将类的结构修改如下:

class CustomSubset(Subset):
'''A custom subset class with customizable data transformation'''
def __init__(self, dataset, indices, subset_transform=None):
super().__init__(dataset, indices)
self.targets = dataset.targets
self.classes = dataset.classes
self.subset_transform = subset_transform def __getitem__(self, idx):
x, y = self.dataset[self.indices[idx]] if self.subset_transform:
x = self.subset_transform(x) return x, y def __len__(self):
return len(self.indices)

我们可以在使用样本前设置好数据预处理算子:

from torchvision import transforms
valid_data.subset_transform = transforms.Compose(\
[transforms.RandomRotation((180,180))])

这样,我们再像下列这样用索引访问取出数据集样本时,就会自动调用算子完成预处理操作:

print(valid_data[0])

打印结果缩略如下:


(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)

最新文章

  1. C#调用vbs脚本实现Windows版Siri
  2. 归并求逆序数(逆序对数) &amp;&amp; 线段树求逆序数
  3. 企业信息系统——SCM
  4. SAP 常用函数
  5. OS X 在Cisco无线环境下丢包分析 part 1
  6. XML中如何使用schema
  7. 分布式日志收集系统--Chukwa
  8. OSG Win7 + VS2015 编译
  9. BTREE与HASH的区别
  10. 【Beta】Phylab2.0: Postmortem
  11. CTabCtrl - 如何使用TabCtrl控件
  12. Android 在一个程序中启动另一个程序
  13. 基于visual Studio2013解决C语言竞赛题之0404循环求和
  14. Android常用开源项目
  15. SSH是什么?Linux如何修改SSH端口号?
  16. atnodes命令+sort+uniq统计特征信息到结果文件
  17. noip2013Day2T3-华容道【一个蒟蒻的详细题解】
  18. 【Android Developers Training】 106. 创建并检测地理围栏
  19. angular4.0快速import依赖路径
  20. 2.计算机组成-数字逻辑电路 门电路与半加器 异或运算半加器 全加器组成 全加器结构 反馈电路 振荡器 存储 D T 触发器 循环移位 计数器 寄存器 传输门电路 译码器 晶体管 sram rom 微处理 计算机

热门文章

  1. Linux系统使用crt登录之后如何显示横幅消息
  2. css 垂直居中技巧
  3. 解压安装Cacti在apache中的补充
  4. 【Azure 应用服务】一个 App Service 同时部署运行两个及多个 Java 应用程序(Jar包)
  5. test_1 计算字符串最后一个单词的长度,单词以空格隔开
  6. kafka学习笔记(六)kafka的controller模块
  7. Pytorch之Spatial-Shift-Operation的5种实现策略
  8. 开发 IDEA Plugin 引入探针,基于字节码插桩获取执行SQL
  9. 【C++】字符串处理
  10. deepin20体验