MNISTとFashion-MNISTのPyTorchのデータローダ

概要

MNISTとFashion-MINISTを組み合わせて学習してみたかったが、既存のデータセットだと多分無理っぽかったので、自作してみた。(車輪のなんとやらではないという言い訳)

ここでは組み合わせるところまでは書かず、一般的なMNISTのローダをどのように書けばよいかが分かるようなものを記載したいと思う。

もし同じような動機に駆られたら参考にしてほしい。

ちなみに、単純にMNISTをPytorchで読み込むだけだったら、以下でいける。とても簡単。

train = torch.utils.data.DataLoader(MNIST('data', train=False))

準備

MNISTおよびFashion-MNISTのバイナリデータ(gzipで圧縮済み)をダウンロードしてきて、 適当なディレクトリに入れておく。 自分の場合は/share/datasets/mnist/data/の下に入れておいた。

MNISTは以下のページにダウンロード元が記載されている。

http://yann.lecun.com/exdb/mnist/

Fashion-MNISTは以下のページにダウロード元が記載されている。

https://github.com/zalandoresearch/fashion-mnist

データの中身

データの中身はバイナリになっている。

ラベルは1バイトにつき、1つのラベル(0-10)が入っている。(たしか)

画像の場合は1バイトが1ピクセルに対応して、0-255までの値が入っている。(多分)

ソースコードと解説

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import gzip
import os
import numpy as np
from PIL import Image


class MNIST(Dataset):
    def __init__(self, kind="MNIST", train=True):
        assert(kind in ["MNIST", "Fashion"])
        self.train = train
        if kind == "MNIST":
            self.data_dir="/share/datasets/mnist/data/mnist/"
        else:
            self.data_dir="/share/datasets/mnist/data/fashion/"

        if self.train:
            self.files = {
                "images": "train-images-idx3-ubyte.gz",
                "labels": "train-labels-idx1-ubyte.gz"
            }
        else:
            self.files = {
                "images": "t10k-images-idx3-ubyte.gz",
                "labels": "t10k-labels-idx1-ubyte.gz"
            }

        self.init_images_labels()

        self.augmentor = transforms.Compose([transforms.RandomHorizontalFlip(),
                                             transforms.RandomRotation(degrees=30),
                                             transforms.RandomVerticalFlip()])
        self.loader = transforms.Compose([transforms.ToTensor()])


    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx]
        label = self.y[idx]

        image = self._load_image(image)

        label = np.expand_dims(label, axis=0)
        label = torch.from_numpy(np.array(label))

        return image, label

    def _load_image(self, image):
        image = Image.fromarray(image)
        # Gray scale to RGB
        # image = image.convert('RGB')
        if self.train:
            image = self.augmentor(image)
        image = self.loader(image).float()
        return image

    def init_images_labels(self):
        self.y = self._read_labels_from_binary(os.path.join(self.data_dir, self.files["labels"]))
        self.X = self._read_images_from_binary(os.path.join(self.data_dir, self.files["images"]))

    def _read_labels_from_binary(self, filepath):
        with gzip.open(filepath, mode="rb") as f:
            data = f.read()
            array = np.fromstring(data, dtype='<u1')
            # first 8 bytes are metadata
            array = array[8:]

        return array

    def _read_images_from_binary(self, filepath):
        with gzip.open(filepath, mode="rb") as f:
            data = f.read()
            array = np.fromstring(data, dtype='<u1')
            # first 16 bytes are metadata
            array = array[16:]
            len_array = len(array) // 28 // 28
            array = array.reshape((len_array, 28, 28))

        return array

def main():
    mnist = MNIST(kind="Fashion")

if __name__ == "__main__":
    main()

多くはPyTorchに共通する処理なので、説明を割愛し、MNISTに特化した部分について説明する。

init

MNISTとFashion-MNISTはひとつのデータは同じ大きさになっているし、データの数も全く一緒である。 したがってファイルパスの設定を変えれば、MNISTとFashion-MNISTをスイッチできる。 それが冒頭の__init__()の中で定義されたkindの役割である。

_read_labels_from_binary

ラベルをバイナリから読み取るために、ファイルをgzipライブラリで解凍し、numpyのfromstringで配列にする。 読み込み元はバイナリなのだが、まぁそれは置いておいて、<u1で1バイトごとに符号なし整数として読み込んでいる。 冒頭の8バイトはこのデータ全体の説明になっている。詳しくは MNISTの公式ページにある「TRAINING SET LABEL FILE (train-labels-idx1-ubyte)」を読むこと。 データ数等が書いてあるが、正直要らないので、読み捨てにする。

_read_images_from_binary

ラベルとほぼ同じ処理を行ったうえで、冒頭の16バイトをラベルと同じ理由で捨てる。 そしてひとつのデータは28x28の長さの1次元のベクトルなのだが、後で画像として扱うためにreshapeして2次元のベクトルに直している。 reshapeの処理が不安だったので、書いている途中でPIL.Image.fromarray(image).show()として表示される画像をみて確認はしておいた。

以上、何か問題など見つけたら教えてください。