在解决任何机器学习问题时,都需要花费大量的精力来准备数据。PyTorch提供了许多工具来简化数据加载,希望能使代码更具可读性。在这篇教程里,我们将看看如何从非平凡的数据集中加载和预处理/扩增数据。
为了运行这篇教程,请确保一下模块已经安装了:
scikit-image
:为了图像的输入输出和转化pandas
:为了跟容易解析csv
1 | from __future__ import print_function, division |
我们将要处理的数据集是人脸姿态。这以为着一张脸被标注成这样:
@import “landmarked_face2.png”
总共有68个标注点标注在每张脸上。
数据集带有一个csv文件,其中带有类似于下面的注释:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
让我们快速读取csv并且得到(N, 2)数组的标注,N是指的标注点的个数。
1 | landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') |
输出:
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]
让我们写一个简单的展示一张图片和它的标注点的帮助函数,使用它来显示一个样本。
1 | def show_landmarks(image, landmarks): |
@import “sphx_glr_data_loading_tutorial_001.png”
Dataset类
torch.utils.data.Dataset
是表示数据集的抽象类。你的自定义数据集应该继承Dataset
并覆盖以下方法:
__len__
以便len(dataset)
返回dataset的大小__getitem__
来支持索引操作,像是dataset[i]
用来获得第i个样本
让我们创建一个我们脸部表注数据集的dataset吧。我们将在__init__
中读取csv但是留在__getitem__
中读取图片。这是为了内存效率因为所有的图片不是一次储存在内存中,而是按需要储存。
我们数据集的赝本将会是字典{'image':image, 'landmarks':landmarks}
。我们数据集将获得一个选填参数transform
以便对样本进行所有必要的处理。我们将在下一个章节看到transform
的有效性。
1 | class FaceLandmarksDataset(Dataset): |
让我们举例使用这个类并且在迭代这个数据集。我们将打印出前四个样本的大小和它们的标注点。
1 | face_dataset = FaceLandmarksDataset(csv_file = 'data/faces/face_landmarks.csv', root_dir = 'data/faces/) |