【Python・scikit-learn】パーセプトロンで2値分類

導入

今回は、入力データを複数のクラスに分類するクラス分類を行います。
特に、最も簡単な2クラスの分類(2値分類)を、パーセプトロンで実装してみましょう。

クラス分類

クラス分類は、入力データを分類する処理のことを言います。
例えば有名なのは、衣服の写真を入力し、その服がシャツかズボンか、、、などの分類ですね。 この「シャツ」「ズボン」といった分類対象のことをクラスと言います。
(衣服の分類は、Fashion-MNISTという有名なサンプルデータがあるため、よく題材にされています。 ) クラス分類の中でも特に、「この画像は犬かどうか」のような、Yes/Noで出力されるような2クラスの分類を、2値分類と言います。

パーセプトロン

パーセプトロンは、最もシンプルなニューラルネットワークと言えます。ニューラルネットワークについての詳細な解説は省きます。
パーセプトロンは、活性化関数などを持たない単純なニューラルネットワークで、線形な分離しか行えないという特徴があります。

実践

まずは、以下のように、scikit-learnの機能で2値分類用のデータを生成しましょう。

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_classification

# 乱数のシードを設定
np.random.seed(3)

#2クラスの分類(2値分類)データを作成
#xには2次元座標、yにはクラスラベル(0か1)が格納される
x,y = make_classification(n_samples=100, 
                          n_features=2, 
                          n_redundant=0, 
                          n_informative=2, 
                          n_clusters_per_class=1, 
                          n_classes=2)

#クラスラベルが0の点を青、クラスラベルが1の点を赤でプロット
plt.scatter(x[y == 0][:, 0], x[y == 0][:, 1], c='blue', label='Class 0')
plt.scatter(x[y == 1][:, 0], x[y == 1][:, 1], c='red', label='Class 1')
plt.legend()
plt.show()

実行結果は以下のようになります。このデータに対して、パーセプトロンで2値分類を行います。

今回分類するデータ

続いて、パーセプトロンの作成と学習です。

from sklearn.linear_model import  Perceptron

#パーセプトロンのインスタンスを生成
perceptron = Perceptron()
#データに対してパーセプトロンをフィッティング
perceptron.fit(x, y)

#学習後の係数と切片を取得
coef = perceptron.coef_[0]
intercept = perceptron.intercept_

#データをプロット
plt.scatter(x[y == 0][:, 0], x[y == 0][:, 1], c='blue', label='Class 0')
plt.scatter(x[y == 1][:, 0], x[y == 1][:, 1], c='red', label='Class 1')

#学習後の境界をプロット
line = np.linspace(-15,15)
plt.plot(line, -(line * coef[0] + intercept) / coef[1], c="g", label="perceptron")
plt.xlim(-2, 4.5)
plt.ylim(-3, 2.5)
plt.legend()

実行結果は以下のようになります。

パーセプトロンによる2値分類の視覚化

このように境界を定めることで、今後新たなデータが得られたときに、そのデータが境界のどちら側にあるかによって、クラスを分類(予測)することができるようになります。
例えば、点(3, 1)という点がどちらのクラスに属するかを予測するには、以下のようなコードで予測できます。

perceptron.predict([[3, 1]])

predict()は、引数のデータが分類されるクラスを返します。 今回の2値分類モデルでは[0]または[1]が返ってきます。点(3,1)に対しては、[0]が結果として返ってきました。

predict()の実行結果