翻译:pytorch数据加载和处理

在解决任何机器学习问题时,都需要花费大量的精力来准备数据。PyTorch提供了许多工具来简化数据加载,希望能使代码更具可读性。在这篇教程里,我们将看看如何从非平凡的数据集中加载和预处理/扩增数据。

为了运行这篇教程,请确保一下模块已经安装了:

  • scikit-image:为了图像的输入输出和转化
  • pandas:为了跟容易解析csv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# 忽略警告
import warnings
warnings.filterwarnings('ignore')

plt.ion() # 交互模式

我们将要处理的数据集是人脸姿态。这以为着一张脸被标注成这样:

@import “landmarked_face2.png”

总共有68个标注点标注在每张脸上。

  • 注意:
    这里下载数据集,图片在‘data/faces/’目录下。这个数据集是基于imagenet中被标记为‘face’的一些图片通过应用优秀的dlib的姿态估计生成。

数据集带有一个csv文件,其中带有类似于下面的注释:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们快速读取csv并且得到(N, 2)数组的标注,N是指的标注点的个数。

1
2
3
4
5
6
7
8
9
10
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]

让我们写一个简单的展示一张图片和它的标注点的帮助函数,使用它来显示一个样本。

1
2
3
4
5
6
7
8
9
10
def show_landmarks(image, landmarks):
"""显示一张带有标注点的函数"""

plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s = 10, marker = '.', c = 'r')
plt.pause(0.001) # 暂停一会等待更新

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)), landmarks)
plt.show()

@import “sphx_glr_data_loading_tutorial_001.png”

Dataset类

torch.utils.data.Dataset是表示数据集的抽象类。你的自定义数据集应该继承Dataset并覆盖以下方法:

  • __len__以便len(dataset)返回dataset的大小
  • __getitem__来支持索引操作,像是dataset[i]用来获得第i个样本

让我们创建一个我们脸部表注数据集的dataset吧。我们将在__init__中读取csv但是留在__getitem__中读取图片。这是为了内存效率因为所有的图片不是一次储存在内存中,而是按需要储存。

我们数据集的赝本将会是字典{'image':image, 'landmarks':landmarks}。我们数据集将获得一个选填参数transform以便对样本进行所有必要的处理。我们将在下一个章节看到transform的有效性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class FaceLandmarksDataset(Dataset):

def __init__(self, csv_file, root_dir, transform = None):
"""
参数:
csv_file(string):csv file的路径
root_dir(string):所有图片的目录
transform(callable, 选填):被应用到样本上的transforms

"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform

def __len__(self):

return len(self.landmarks_frame)

def __getitem__(self, idx):

if torch.is_tensor(idx):
idx = idx.tolist()

img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
images = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image':image, 'landmarks':landmarks}

if self.transform:
sample = self.transform(sample)

return sample

让我们举例使用这个类并且在迭代这个数据集。我们将打印出前四个样本的大小和它们的标注点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
face_dataset = FaceLandmarksDataset(csv_file = 'data/faces/face_landmarks.csv', root_dir = 'data/faces/)

fig = plt.figure()

for i in range(len(face_dataset)):
sample = face_dataset[i]

print(i, sample['image'].shape, sample['landmarks'].shape)

ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample # {}'.format(i))
ax.axis('off')
show_landmarks(**sample)

if i == 3:
plt.show()
break

———————————————感谢阅读———————————————

欢迎收藏访问我的博客 知乎 掘金 简书 知乎

贰三 wechat
欢迎扫描二维码订阅我的公众号!