MNISTのデータセットの簡易可視化
MNISTのデータセットの簡易可視化
はじめに
機械学習のサンプルはMNISTの手書き文字認識をやることが多いんだけど、「機械学習する人はMNISTのデータがどういうものかくらい知ってるでしょ?」って感じで、データ構造の説明が無いことが多い。
ってなわけで、機械学習のド素人がChainerのサンプルで使われるMNISTの手書きデータセットがどういうものかよく見てみる。Chainerはインストール済みであること前提。
データ構造
まずはデータをロードしてみる。chainer.datasets.get_mnist()でデータセットが読み込まれる。そのtestをあれこれしてみる。
>>> import chainer
>>> train, test = chainer.datasets.get_mnist()
>>> test
<chainer.datasets.tuple_dataset.TupleDataset object at 0x2aaab303c910>
>>> len(test)
10000
>>> len(test[0])
2
>>> len(test[0][0])
784
>>> test[0][0]
(snip)
0. , 0. , 0. , 0. ], dtype=float32)
>>> test[0][1]
7
上記から以下のようなことがわかる。
testは10000個のサンプルデータからなるTupleDatasetクラスtest[i]はi番目のサンプルデータtest[i][0]はi番目のサンプルの784個のfloat32のデータtest[i][1]は整数で、i番目の手書きサンプルの「正解」を表す
784 = 28**2 なので、28×28のグリッドのグレースケールのデータなんだろう。
可視化
というわけで、これを二値化して可視化してみる。とりあえず0番のデータ決め打ちで。
import chainer
n = 0
train, test = chainer.datasets.get_mnist()
s = 28
print test[n][1]
for i in range(1,28):
for j in range(1,28):
v = test[n][0][j + i*s]
if v > 0.1:
print "*",
else:
print " ",
print "\n"
最初に正解を、次に手書きデータを出力する。結果はこんな感じ。
$ python show_mnist.py
7
* * * * * *
* * * * * * * * * * * * * * * *
* * * * * * * * * * * * * * * *
* * * * * * * *
* * *
* * *
* * * *
* * * *
* * *
* * *
* * *
* * * *
* * *
* * * *
* * * *
* * * *
* * * *
* * * * *
* * * * *
* * *
0番のデータは数字の「7」ですね。世の中の機械学習のサンプルはこれを入力にして、どの数字の可能性が高いかを判定している模様。
A Robot’s Sigh