MNISTとFashion-MNISTのPyTorchのデータローダ
概要
MNISTとFashion-MINISTを組み合わせて学習してみたかったが、既存のデータセットだと多分無理っぽかったので、自作してみた。(車輪のなんとやらではないという言い訳)
ここでは組み合わせるところまでは書かず、一般的なMNISTのローダをどのように書けばよいかが分かるようなものを記載したいと思う。
もし同じような動機に駆られたら参考にしてほしい。
ちなみに、単純にMNISTをPytorchで読み込むだけだったら、以下でいける。とても簡単。
train = torch.utils.data.DataLoader(MNIST('data', train=False))
準備
MNISTおよびFashion-MNISTのバイナリデータ(gzipで圧縮済み)をダウンロードしてきて、 適当なディレクトリに入れておく。
自分の場合は/share/datasets/mnist/data/
の下に入れておいた。
MNISTは以下のページにダウンロード元が記載されている。
http://yann.lecun.com/exdb/mnist/
Fashion-MNISTは以下のページにダウロード元が記載されている。
https://github.com/zalandoresearch/fashion-mnist
データの中身
データの中身はバイナリになっている。
ラベルは1バイトにつき、1つのラベル(0-10)が入っている。(たしか)
画像の場合は1バイトが1ピクセルに対応して、0-255までの値が入っている。(多分)
ソースコードと解説
import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms import gzip import os import numpy as np from PIL import Image class MNIST(Dataset): def __init__(self, kind="MNIST", train=True): assert(kind in ["MNIST", "Fashion"]) self.train = train if kind == "MNIST": self.data_dir="/share/datasets/mnist/data/mnist/" else: self.data_dir="/share/datasets/mnist/data/fashion/" if self.train: self.files = { "images": "train-images-idx3-ubyte.gz", "labels": "train-labels-idx1-ubyte.gz" } else: self.files = { "images": "t10k-images-idx3-ubyte.gz", "labels": "t10k-labels-idx1-ubyte.gz" } self.init_images_labels() self.augmentor = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomRotation(degrees=30), transforms.RandomVerticalFlip()]) self.loader = transforms.Compose([transforms.ToTensor()]) def __len__(self): return len(self.X) def __getitem__(self, idx): image = self.X[idx] label = self.y[idx] image = self._load_image(image) label = np.expand_dims(label, axis=0) label = torch.from_numpy(np.array(label)) return image, label def _load_image(self, image): image = Image.fromarray(image) # Gray scale to RGB # image = image.convert('RGB') if self.train: image = self.augmentor(image) image = self.loader(image).float() return image def init_images_labels(self): self.y = self._read_labels_from_binary(os.path.join(self.data_dir, self.files["labels"])) self.X = self._read_images_from_binary(os.path.join(self.data_dir, self.files["images"])) def _read_labels_from_binary(self, filepath): with gzip.open(filepath, mode="rb") as f: data = f.read() array = np.fromstring(data, dtype='<u1') # first 8 bytes are metadata array = array[8:] return array def _read_images_from_binary(self, filepath): with gzip.open(filepath, mode="rb") as f: data = f.read() array = np.fromstring(data, dtype='<u1') # first 16 bytes are metadata array = array[16:] len_array = len(array) // 28 // 28 array = array.reshape((len_array, 28, 28)) return array def main(): mnist = MNIST(kind="Fashion") if __name__ == "__main__": main()
多くはPyTorchに共通する処理なので、説明を割愛し、MNISTに特化した部分について説明する。
init
MNISTとFashion-MNISTはひとつのデータは同じ大きさになっているし、データの数も全く一緒である。
したがってファイルパスの設定を変えれば、MNISTとFashion-MNISTをスイッチできる。
それが冒頭の__init__()
の中で定義されたkind
の役割である。
_read_labels_from_binary
ラベルをバイナリから読み取るために、ファイルをgzipライブラリで解凍し、numpyのfromstring
で配列にする。
読み込み元はバイナリなのだが、まぁそれは置いておいて、<u1
で1バイトごとに符号なし整数として読み込んでいる。
冒頭の8バイトはこのデータ全体の説明になっている。詳しくは MNISTの公式ページにある「TRAINING SET LABEL FILE (train-labels-idx1-ubyte)」を読むこと。
データ数等が書いてあるが、正直要らないので、読み捨てにする。
_read_images_from_binary
ラベルとほぼ同じ処理を行ったうえで、冒頭の16バイトをラベルと同じ理由で捨てる。
そしてひとつのデータは28x28の長さの1次元のベクトルなのだが、後で画像として扱うためにreshapeして2次元のベクトルに直している。
reshapeの処理が不安だったので、書いている途中でPIL.Image.fromarray(image).show()
として表示される画像をみて確認はしておいた。
以上、何か問題など見つけたら教えてください。