数据导入与处理

来自这里

在解决任何机器学习问题时,都需要在处理数据上花费大量的努力。PyTorch提供了很多工具来简化数据加载,希望使代码更具可读性。在本教程中,我们将学习如何从繁琐的数据中加载、预处理数据或增强数据。

开始本教程之前,请确认你已安装如下Python包:

  • scikit-image:图像IO操作和格式转换
  • pandas:更方便解析CSV

我们接下来要处理的数据集是人脸姿态。这意味着人脸的注释如下:

总之,每个面部都有68个不同标记点。

可以从这里下载数据集,并将其解压后存放到目录‘data/faces/’。

数据集来自带有面部注释的CSV文件,文件内容类似以下格式:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

接下来我们快速读取CSV文件,并从(N,2)数组中获取注释,N表示标记数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]

现在我们写一个简单的帮助函数:展示图片和它的标记,用它来展示样本。

def show_landmarks(image,landmarks):
'''
展示带标记点的图像
'''
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(10) plt.figure()
show_landmarks(io.imread(os.path.join('data/faces',img_name)),landmarks)
plt.show()

数据集类(Dataset class)

torch.utils.data.Dataset是一个表示数据集的抽象类。你自定义的数据集应该继承Dataset并重写以下方法:

  • len 这样len(dataset)是可以返回数据集的大小
  • getitem 支持索引操作,比如dataset[i]来获取第i个样本。

现在我们来实现我们的面部标记数据集类。我们将在__init__中读取CSV,然后再__getitem__中读取图像。这样可以高效利用内存,因为所有的图像并不是都存在在内存中,而是按需读取。

我们数据集的样本是字典格式的:{'image':image,'landmarks':landmarks}。我们的数据集将采用可选参数transform,以便任何必要的处理都可以被应用在样本上。在下一节中我们会看到transform的用途。

class FaceLandmarksDataset(Dataset):
'''
Face Landmarks Dataset
''' def __init__(self,csv_file,root_dir,transform=None):
'''
param csv_file(string): 带注释的CSV文件路径
param root_dit(string): 存储图像的路径
param transform(callable,optional): 被应用到样本的可选transform操作
'''
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform def __len__(self):
return len(self.landmarks_frame) def __getitem__(self,idx):
img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx,1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1,2)
sample = {'image':image,'landmarks':landmarks} if self.transform:
sample = self.transform(sample) return sample

现在我们实例化这个类,并且迭代输出部分样本。我们打印输出前4个样本并展示它们的标记。

face_dataset = FaceLandmarksDataset(
csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/') fig = plt.figure() for i in range(len(face_dataset)):
sample = face_dataset[i] print(i, sample['image'].shape, sample['landmarks'].shape) ax = plt.subplot(1, 4, i+1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample) if i == 3:
plt.show()
break

输出:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transforms(转换)

从上面的例子可以看出这些样本的尺寸并不一致。大多数神经网络都期望图像的尺寸是固定的。这样的话,我们就需要一些处理代码。接下来我们创建三个变换函数:

  • Rescale:缩放图像
  • RandomCrop:随机裁剪图像。这是数据扩充。
  • ToTensor:将numpy图像转为torch图像(我们需要交换轴)。

我们将以类而不是简单的函数的方式来实现它们,这样就不需要在每次调用时都传递转换需要的参数。这样我们只需要实现__call__方法,需要的话还可以实现__init__方法。然后我们可以按如下的方式使用:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

下面展示如何将这些转换同时应用在图像和标记点。

class Rescale(object):
'''
按给定的尺寸缩放图像 param output_size (tuple or int): 目标输出尺寸。如果是tuple,输出为匹配的输出尺寸;如果是int,则匹配较小的图像边缘,保证相同的长宽比例。
''' def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size*h/w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size*w/h
else:
new_h, new_w = self.output_size new_h, new_w = int(new_h), int(new_w) img = transform.resize(image, (new_h, new_w)) landmarks = landmarks*[new_w/w, new_h/h]
return {'image': img, 'landmarks': landmarks} class RandomCrop(object):
'''
随机裁剪图像 param output_size (tuple or int): 目标输出尺寸。如果是int,正方形裁剪
''' def __init__(self, output_size):
assert isinstance(output_size, (int, tuple)) if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2]
new_h, new_w = self.output_size top = np.random.randint(0, h-new_h)
left = np.random.randint(0, w-new_w) image = image[top:top+new_h, left:left+new_w] landmarks = landmarks - [left, top] return {'image': image, 'landmarks': landmarks} class ToTensor(object):
'''
将ndarrays格式样本转换为Tensors
''' def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}

组合变换

现在,我们在样本上应用转换。

比如我们想将图片的短边缩放为256然后在随机裁剪出一个224的正方形,那么我们将用到RescaleRandomCroptorchvision.transforms.Compost可以帮助我们完成上述组合操作。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)]) fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample) ax = plt.subplot(1, 3, i+1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()

遍历数据集

接下来我们将上面的代码整合起来,创建一个带有组合变换的数据集。综上所述,每次采样该数据集时:

  • 从文件中动态读取图像
  • 转换应用到读取的图像上
  • 由于其中一种转换是随机的,因此数据在抽样时得到了扩充

我们可以是像之前一样用for i in range循环遍历创建的数据集:

transformed_dataset = FaceLandmarksDataset(
csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()])
) for i in range(len(transformed_dataset)):
sample = transformed_dataset[i] print(i, sample['image'].size(), sample['landmarks'].size()) if i == 3:
break

输出:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,只是使用建档的for训练遍历数据,我们将丢失很多特征。尤其是我们丢失了:

  • 批量处理数据
  • 移动数据
  • 使用multiprocessing并行加载数据

torch.utils.data.DataLoader是一个提供了所有这些功能的迭代器。接下来使用的参数是明朗的。一个有趣的参数是collate_fn。你可以使用collate_fn指定需要如何对样本进行批量处理。然而,默认的collate足够胜任大多数使用场景。

dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=4) def show_landmarks_batch(sample_batched):
'''
批量展示样本
'''
images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
batch_size = len(sample_batched)
im_size = images_batch.size(2)
grid_border_size = 2 grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0))) for i in range(batch_size):
plt.scatter(
landmarks_batch[i, :, 0].numpy() + i * im_size +
(i+1)*grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r'
)
plt.title('Batch from dataloader') for i_batch,sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size()) if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break

输出:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后续:torchvision

在本教程中,我们了解了如何实现并使用数据集、转换和数据导入。torchvision包提供了一些常用的数据集和转换。你甚至可能不需要编写自定义的类。在torchvision中最常用的数据集是ImageFolder。它假设图像的组织方式如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

‘ants’、‘bees’等等都是类的标签。在PIL.Image上操作的类似常用的转化,如RandomHorizontalFlipScale,都是可用的。你可以使用它们来编写想下面的数据导入:

import torch
from torchvision import transforms, datasets data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)

最新文章

  1. 剖析 HTTP 协议
  2. JavaScript学习链接
  3. realloc的使用误区
  4. JStrom的zk数据
  5. POJ 2823 Sliding Window
  6. Java常用锁机制简介
  7. OK335xS LAN8710 phy driver hacking
  8. const,readonly,static
  9. Android Service 生命周期和使用注意项
  10. utf8_general_ci 、utf8_general_cs和utf8_bin的区别
  11. JavaScript 比量 Chrome 核心 360 浏览器(关闭和技巧)
  12. eclipse配置tomcat及修改tomcat默认根目录
  13. MapReduce过程详解及其性能优化
  14. BZOJ1177 [Apio2009]Oil 二维前缀和 二维前缀最值
  15. linux --- 3 vim 网络 用户 权限 软连接 压缩 定时任务 yum源
  16. ayit-#41. 因数的个数-数论
  17. RAMOS_XP制作教程
  18. js 获取 this 的属性 obj[0].getAttribute
  19. js之function
  20. django的数据库操作

热门文章

  1. 洛谷 P3371 【模板】单源最短路径(弱化版)(dijkstra邻接链表)
  2. Linux应用编程之lseek详解
  3. SLAM领域资源链接
  4. 2017年3月16工作日志【mysql更改字段参数、java8 map()调用方法示例】
  5. RxJava的简单使用
  6. Apache Commons Lang之日期时间工具类
  7. vncserver
  8. 关于tomcat报错记录
  9. 17.3.13---sys.argv[]用法
  10. shell脚本中的条件测试if中的-z到-d的意思