t-SNEを使ったMNISTデータの次元削減・可視化

Posted on 2021/01/11 in 機械学習 , Updated on: 2021/01/11

はじめに

t-SNE は、高次元データの情報を保持したまま、低次元データへ変換する方法である。元の空間での点同士の近さが、圧縮後の点同士の近さとできるだけ同じになるように次元を圧縮する方法で、近さの計算に Student-t 分布を使用している。
ここでは、scikit-learnTSNE 関数を使って、mnistデータを2次元データへ変換し、可視化する。

インポート

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import fetch_openml
from sklearn.manifold import TSNE

データ準備

手書き文字の画像データ集である mnistデータを準備する。データは、6万個の 0~9 までの手書き文字画像 (28x28 = 784pixel) があるが、今回はこのうち 1万個をランダムに抽出する。

In [2]:
mnist = fetch_openml('mnist_784', version=1)
mnist.target = mnist.target.astype(int)

idx = np.random.permutation(60000)[:10000] # 60000の数字をランダムに置き換えて、10000個抽出してインデックスを作成

X = mnist['data'][idx]
y = mnist['target'][idx]

X.shape
Out[2]:
(10000, 784)

これで、X には、10000個の784ピクセルのデータが、y にはそれに付随する答え(数字)が準備できた。

t-SNE

784ピクセルのデータは、784次元のデータであり、これを t-SNE を使って、2次元データに圧縮する。

In [3]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, random_state=41)
X_reduced = tsne.fit_transform(X)

可視化

圧縮処理によって、生成された 2次元データを可視化する。

In [4]:
plt.figure(figsize=(13, 7))
plt.scatter(X_reduced[:, 0], X_reduced[:, 1],
            c=y, cmap='jet',
            s=15, alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.show()

非常にわかりやすくクラスター化されているのがわかる。t-SNE は、784次元という高次元データから、そのデータ空間から同じラベル(数字)の物を近くに配置するようにして、2次元データへ圧縮している。ここで、茶色(ラベル:9) と緑(ラベル:4) が重なっているように見えるが、これは下記のように人間の目から見ても数字の 4 と 9 が似ているというのは理解できる。

In [5]:
plt.figure(figsize=(10, 6))
for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(X[y==4][i].reshape(28, 28), cmap=plt.cm.binary)
    plt.axis('off')
    if i == 4: plt.title('Pictures of 4', fontsize=20, color='blue')
plt.show()

plt.figure(figsize=(10, 6))
for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(X[y==9][i].reshape(28, 28), cmap=plt.cm.binary)
    plt.axis('off')
    if i == 4: plt.title('Pictures of 9', fontsize=20, color='blue')
plt.show()

最後に、各クラスター上にいくつか、手書き画像データを表示したグラフを作成する。

In [6]:
from sklearn.preprocessing import MinMaxScaler
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

plt.figure(figsize=(25, 20))
neighbors = np.array([[10., 10.]])
X_normalized = MinMaxScaler().fit_transform(X_reduced)
for digit in np.unique(y):
    plt.scatter(X_normalized[y==digit, 0], X_normalized[y==digit, 1],
                c=[plt.cm.get_cmap("jet")(digit/9)])

ax = plt.gcf().gca()
for index, image_coord in enumerate(X_normalized):
    closest_distance = np.linalg.norm(np.array(neighbors) - image_coord, axis=1).min()
    if closest_distance > 0.05:
        neighbors = np.r_[neighbors, [image_coord]]
        image = X[index].reshape(28, 28)
        imagebox = AnnotationBbox(OffsetImage(image, cmap="binary"), image_coord)
        ax.add_artist(imagebox)
plt.axis('off')
plt.show()