1)前言

虽然torchvision.datasets中已经封装了好多通用的数据集,但是我们在使用Pytorch做深度学习任务的时候,会面临着自定义数据库来满足自己的任务需要。如我们要训练一个人脸关键点检测算法,提供的训练数据标注如下形式,存在CSV文件中:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

在本次教程中,我们需要用到两个额外的包:

  • scikit-image: 用于图片io和转换
  • pandas: 用于解析csv文件

首先学习如何使用pandas库解析csv文件

import pandas as pd
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

2)自定义数据库

torch.utils.data.Dataset是一个表示数据库的抽象类,自定义数据库需要继承这个类,并且重写其以下方法:

__len__ :返回数据库的大小.
__getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本

以下创建人脸特征点检测的数据库。我们将在__init__中解析csv文件,而在__getitem__中读取图片。这样可以在需要图片是才加载,内存效率高。此外,我们还可以先将数据集封装成lmdb数据库,读取速度更快。

import torch.utils.data.Dataset as Dataset
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset.""" def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): 到达标注文件cvs的路径.
root_dir (string): 所有图片的根目录.
transform (callable, optional): (可选参数)对每一个样本进行转换.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform def __len__(self):
return len(self.landmarks_frame) def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0]) #第idx条数据的第一个字段,即文件名称
image = io.imread(img_name) #读取图像数据
landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() #读取第idx条数据的第二个字段及其之后的所有字段,即所有关键点的坐标。然后转成矩阵形式
landmarks = landmarks.astype('float').reshape(-1, 2) #将矩阵reshape成n行两列矩阵
sample = {'image': image, 'landmarks': landmarks} #封装数据 if self.transform:
sample = self.transform(sample) #数据转换 return sample #返回数据

注:__getitem__每次只返回一个条数据,至于batch的封装可以在DataLoader中设置batchsize,至于读取速度可以设置num_worker。

最新文章

  1. Webpack--自学笔记
  2. Twitter Bootstrap
  3. Java多线程系列--“基础篇”06之 线程让步
  4. [MAC] Mac OS X下快速复制文件路径的方法
  5. Linux权限问题
  6. js基础之动画(二)
  7. js(jQuery)获取时间的方法及常用时间类搜集
  8. Comet、SSE、Web Socket
  9. AWR报告导出的过程报ORA-06550异常
  10. Linux中的模式转换
  11. NSMakeRange基础函数应用
  12. VK Cup 2012 Qualification Round 1---C. Cd and pwd commands
  13. javascript的insertBefore、insertAfter和appendChild简单介绍
  14. 51单片机I/O口直接输入输出实例(附调试及分析过程)
  15. Mysql net start mysql启动,提示发生系统错误 5 拒绝访问,原因所在以及解决办法
  16. lbp特征提取(等价模式)
  17. Webdriver之API详解(3)
  18. vue-cli卸载旧版,再重新安装后还显示的是旧的版本
  19. MAVEN项目环境搭建
  20. echarts 滚动条 缩放

热门文章

  1. AtCoder Grand Contest 003
  2. FourAndSix: 2.01靶机入侵
  3. Windows + Ubuntu下JDK与adb/android环境变量配置完整教程
  4. 解题:SCOI 2011 糖果
  5. k最近邻算法(kNN)
  6. ASP.NET MVC 3 常用
  7. vi的一些使用技巧
  8. 安装JDK、Tomcat、Maven’详细步骤
  9. c++操作mysql入门详解
  10. HOJ 13102 Super Shuttle (圆的反演变换)