matplotlib_venn を使ったベン図の作成

Posted on 2020/12/26 in 機械学習 , Updated on: 2020/12/26

はじめに

複数の集合の関係や範囲を視覚的に表す、ベン図を matplotlib_venn を使って表示する方法。

コンペなどでは、Train データと Test データの各特徴量の重なりを確認したいときに非常に便利。

インストール

pip でインストール可能。

In [ ]:
!pip install matplotlib-venn

Venn図作成

ここでは、titanic データを使ってテストする。

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn2
%matplotlib inline

# データを読み込む
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

# データをチェック
train.head()
Out[2]:
PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
0 1 0 3 Braund, Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
1 2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 0 PC 17599 71.2833 C85 C
2 3 1 3 Heikkinen, Miss. Laina female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S
3 4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 0 113803 53.1000 C123 S
4 5 0 3 Allen, Mr. William Henry male 35.0 0 0 373450 8.0500 NaN S
In [5]:
# Fare カラムの重なり具合をチェック
column = 'Fare'

plt.figure(figsize=(3,3))
venn2(subsets=(set(train[column].unique()), set(test[column].unique())),
      set_labels=('Train', 'Test'))
plt.title(column)
plt.tight_layout()

ベン図が作成できた。matplotlib_venn は、それぞれの集合に含まれる数字の大小関係から、自動的に円のサイズを調整してくれる。
この図から、Trainデータには、248(112+136)個のユニークな値があり、その内 136個が Test データにも含まれていることがわかる。また、Test データに含まれていて、Train データに含まれないデータが 34個ある。データを確認してみる。

In [24]:
# train/test のユニークな値の数
print('Number of unique Fare value in train :', train['Fare'].nunique())
print('Number of unique Fare value in test  :', test['Fare'].nunique())
Number of unique Fare value in train : 248
Number of unique Fare value in test  : 169
In [25]:
# train と test で共通する値の数
print('Number of common value among train and test :',
      len(set(list(train['Fare'].unique())) & set(list(test['Fare'].unique()))))
Number of common value among train and test : 136

全特徴量のベン図を表示

単純に for 文で回すだけ。

In [23]:
columns = test.columns # train には目的変数が含まれている
columns_num = len(columns)
n_cols = 4
n_rows = columns_num // n_cols + 1

fig, axes = plt.subplots(figsize=(n_cols*3, n_rows*3),
                         ncols=n_cols, nrows=n_rows)

for col, ax in zip(columns, axes.ravel()):
    venn2(
        subsets=(set(train[col].unique()), set(test[col].unique())),
        set_labels=('Train', 'Test'),
        ax=ax
    )
    ax.set_title(col)
    
fig.tight_layout()