【Python】scikit-learnで直線の線形回帰

導入

機械学習の基本的な手法を色々試したいと思い、学習の記録として記事にすることにしました。 難しいことはほとんどせず、解説も少なめですが、ついでに学習中の誰かの助けになればと思います。
今回は、scikit-learnを使って、最も簡単な回帰である線形回帰を試していきます。
jupyter notebook(.ipynb)を使用して学習しているので、データの出力にprint()を明記していないことがあります。お手元で動作させる場合は、適宜読み替えていただくか、同じく.ipynbで実行してください。

データの読みこみ

まずは、今回使うデータの読み込みです。
scikit-learnから読み込める、カリフォルニアの住宅価格データを使用します。
(よく使われていたボストンの住宅価格データは、人権的な問題で避けた方が良いと警告が出るようになりましたね。)

import pandas as pd
from sklearn.datasets import fetch_california_housing

#カリフォルニアの住宅価格データ読み込み
housing = fetch_california_housing()
df = pd.DataFrame(housing.data, columns=housing.feature_names)
#表データの上5件を確認
df.head()

以下のようにデータの列に何があるかを確認できます。

head()でデータの冒頭5件を確認

散布図行列の確認

このままでは、どのデータとどのデータに相関がある(予想できそう)か分からないので、散布図行列を確認してみます。
pandasの機能で、各データ同士の散布図をまとめてチェックすることができます。

#scatter_matrixで、各パラメータ同士の散布図を確認
from pandas.plotting import scatter_matrix
fig = pd.plotting.scatter_matrix(df, figsize=(15,15))

結果は以下の通りです。

散布図行列

当然な感じがしますが、部屋数(AveRooms)と、寝室数(AveBedrms)に正の相関がありそうですね。
この2つを使って、線形回帰をしてみましょう。
これらのデータだけをプロットしたものが以下になります。

#dataFrame.plotで、pandasからプロットできる
df.plot(kind="scatter", x="AveRooms", y="AveBedrms")

実践

いきなりやっちゃいましょう。
線形回帰モデルをmodelに代入し、model.fitでデータに対してフィッティングするだけです。

#scikit-learnの線形回帰
import sklearn.linear_model
import matplotlib.pyplot as plt
import numpy as np

#線形回帰モデルの作成
model = sklearn.linear_model.LinearRegression()

#対象データ取り出し
#データ数*1の行列に変換しておく
x = np.c_[df["AveRooms"]]
y = np.c_[df["AveBedrms"]]

#モデルをデータにフィッティング
model.fit(x,y)

#回帰の式を出力
print(f"y = {model.coef_}x + {model.intercept_}")

#データをプロット
plt.scatter(x, y)

#回帰直線をプロット
plt.plot(x, model.predict(x), color="red")

実行結果

【PlatformIO】M5Atom S3 liteの内蔵LEDでLチカ

M5Atom S3 lite

導入

先日、M5Atom S3 liteを購入しました。
M5Atomの後継であるM5Atom S3の、ディスプレイなどがない小型版が、このM5Atom S3 liteです。

・公式ページ
AtomS3 Lite ESP32S3 Dev Kitshop.m5stack.com

・スイッチサイエンス
ATOMS3 Litewww.switch-science.com

M5Stackシリーズは、機種によって微妙にライブラリの使い方が違います。 M5Atom S3 liteは、M5Atom S3と同じっぽいですが、備忘録として、動作した基本的なプログラムを残しておきます。

サンプル

実際に使ったプロジェクトは、こちらのリポジトリに公開しています。

github.com

環境設定

PlatformIOのBoard設定では、「Espressif ESP32-S3-DevKitC-1-N8(8 MB QD, No PSRAM)」を使用しました。
プロジェクトのパスに日本語が入っていると、ビルド時にファームウェアがないみたいなエラーが出ることがあるので、注意しましょう。

AtomS3 lite使用時の設定

コード

実際のコードは以下です。

#include <M5AtomS3.h>
#include <FastLED.h>
#include "M5Unified.h"
#include <M5GFX.h>

void setup() {
    AtomS3.begin(true);  // Init AtomS3Lite.
    AtomS3.dis.setBrightness(100);
}

//ボタンを押すたび、消灯>赤>緑>青を繰り返す
int count = 0;
int color = 0;
void loop() {
    if(AtomS3.BtnA.wasPressed()){
        count = (count+1)%4;
    }
    switch (count % 4)
    {
    case 0:
        color = 0x000000;
        break;
    case 1:
        color = 0xff0000;
        break;
    case 2:
        color = 0x00ff00;
        break;
    case 3:
        color = 0x0000ff;
        break;
    default:
        break;
    }
    AtomS3.dis.drawpix(color);
    AtomS3.update();
    delay(100);
}

詳しく説明

使用しているライブラリですが、基本はM5AtomS3を使用します。 それだけでビルドしようとすると、コンパイル時に、FastLED, M5Unified, M5GFXも追加する必要があるという旨のエラーが出たので、追加しています。 (本当はなぜ必要かをちゃんと知るべきですが、ひとまず動作確認なので飛ばします。)

多くのM5Stack製品では、M5.hogehogeと書くことが多かったように思いますが、M5AtomS3では、AtomS3.hogehogeで各種M5の機能を呼び出しています。

今回使っているのは、以下の4つですね。

#初期化
AtomS3.begin(true);

#内蔵LEDの明るさ設定
AtomS3.dis.setBrightness(100);

#内蔵LEDの色設定
AtomS3.dis.drawpix(color);

#内蔵ボタンへの入力などを更新
AtomS3.update();

動作確認

動きました!

動作の様子

【Unity】VirtualMotionTrackerで手のトラッキングデータを取得する

先にまとめ

VMTで両手の姿勢をUnity上で取得する方法です。以下のスクリプトを任意のGameObjectにアタッチし、Inspectorで、姿勢確認用のTransform変数に任意のオブジェクトを代入し、左右のコントローラのシリアル番号を代入します。シリアル番号は、VMT ManagerのDevice画面で確認できます。

using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

[RequireComponent(typeof(uOSC.uOscClient))]
[RequireComponent(typeof(uOSC.uOscServer))]
public class WatchHands : MonoBehaviour
{
    uOSC.uOscClient client;
    uOSC.uOscServer server;

    [SerializeField] Transform rightHand;
    [SerializeField] Transform leftHand;

    [SerializeField] string rightHandSerial = "";
    [SerializeField] string leftHandSerial = "";
    void Start()
    {
        client = GetComponent<uOSC.uOscClient>();
        server = GetComponent<uOSC.uOscServer>();

        //アドレス固定
        client.address = "127.0.0.1";
        client.port = 39570;
        server.port = 39571;
        //イベント追加
        server.onDataReceived.AddListener(OnDataReceived);
        server.onServerStopped.AddListener(OnServerStopped);
        //購読開始
        Subscribe();
    }

    public void OnDataReceived(uOSC.Message message)
    {
        if (message.address == "/VMT/Out/SubscribedDevice")
        {
            //受け取った情報がどのデバイスのものか、シリアル番号を取得
            string serial = (string)message.values[0];

            //シリアル番号を識別
            //受け取った姿勢情報を反映させるTransformをtargetに代入
            //どちらのコントローラでもなければreturn
            Transform target = null;
            if (serial == rightHandSerial)
                target = rightHand;
            else if (serial == leftHandSerial)
                target = leftHand;
            else
                return;

            //反映
            float x = (float)message.values[1];
            float y = (float)message.values[2];
            float z = (float)message.values[3];
            float qx = (float)message.values[4];
            float qy = (float)message.values[5];
            float qz = (float)message.values[6];
            float qw = (float)message.values[7];
            target.position = new Vector3(x, y, -z);
            target.rotation = new Quaternion(qx, qy, -qz, -qw);
        }
    }

    public void OnServerStopped(int x)
    {
        Unsubscribe();
    }

    private void OnDestroy()
    {
        Unsubscribe();
    }

    void Unsubscribe()
    {
        Debug.Log($"Unsubscribed:{rightHandSerial}");
        client.Send("/VMT/Unsubscribe/Device", rightHandSerial);
        Debug.Log($"Unsubscribed:{leftHandSerial}");
        client.Send("/VMT/Unsubscribe/Device", leftHandSerial);
    }

    void Subscribe()
    {
        Debug.Log($"Subscribed:{rightHandSerial}");
        client.Send("/VMT/Subscribe/Device", rightHandSerial);
        Debug.Log($"Subscribed:{leftHandSerial}");
        client.Send("/VMT/Subscribe/Device", leftHandSerial);
    }
}

詳しく説明

SteamVRのドライバ(OpenVR)とのデータのやり取りを簡単に行える、Virtual Motion Tracker(VMT)というアプリケーションがあります。

gpsnmeajp.github.io

VMTを使う中で、HMD位置の取得は、以下に示す公式Sampleの、WatchHMD.csを使えば簡単にできます。 github.com

しかし、コントローラなどのHMD以外の位置取得で、デバイスのシリアル番号の確認に少し迷ったので、メモとして残しておきます。

WatchHMD.csの21行目では、デバイス情報の購読を要求しています。

client.Send("/VMT/Subscribe/Device", "HMD");

つまり、デバイス情報を定期的に送信してくれるように、OpenVRへ(VMT経由で)要求しています。この時の第2引数で、購読したいデバイスを指定します。

上記のように、HMDの場合は、"HMD"という文字列を渡せばよいです。しかし、HMD以外の、コントローラやトラッカーの場合は、各デバイスのシリアル番号を渡してやる必要があります。

このシリアル番号は、VMT Manager上で確認できます。VMT ManagerのDeviceタブを開き、Reloadを押すと、現在使用しているデバイスの一覧が出ます。このリストから、取得したいデバイスを選択し、Copy to Clipboardでシリアルを取得できます。

本記事の最初に示したコードを再掲します。これをUnity上で使用し、Inspectorに、コピーしたシリアル番号を張り付ければ、SteamVR内でのコントローラ位置が、Unity上のオブジェクトに反映されるはずです。

using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

[RequireComponent(typeof(uOSC.uOscClient))]
[RequireComponent(typeof(uOSC.uOscServer))]
public class WatchHands : MonoBehaviour
{
    uOSC.uOscClient client;
    uOSC.uOscServer server;

    [SerializeField] Transform rightHand;
    [SerializeField] Transform leftHand;

    [SerializeField] string rightHandSerial = "";
    [SerializeField] string leftHandSerial = "";
    void Start()
    {
        client = GetComponent<uOSC.uOscClient>();
        server = GetComponent<uOSC.uOscServer>();

        //アドレス固定
        client.address = "127.0.0.1";
        client.port = 39570;
        server.port = 39571;
        //イベント追加
        server.onDataReceived.AddListener(OnDataReceived);
        server.onServerStopped.AddListener(OnServerStopped);
        //購読開始
        Subscribe();
    }

    public void OnDataReceived(uOSC.Message message)
    {
        if (message.address == "/VMT/Out/SubscribedDevice")
        {
            //受け取った情報がどのデバイスのものか、シリアル番号を取得
            string serial = (string)message.values[0];

            //シリアル番号を識別
            //受け取った姿勢情報を反映させるTransformをtargetに代入
            //どちらのコントローラでもなければreturn
            Transform target = null;
            if (serial == rightHandSerial)
                target = rightHand;
            else if (serial == leftHandSerial)
                target = leftHand;
            else
                return;

            //反映
            float x = (float)message.values[1];
            float y = (float)message.values[2];
            float z = (float)message.values[3];
            float qx = (float)message.values[4];
            float qy = (float)message.values[5];
            float qz = (float)message.values[6];
            float qw = (float)message.values[7];
            target.position = new Vector3(x, y, -z);
            target.rotation = new Quaternion(qx, qy, -qz, -qw);
        }
    }

    public void OnServerStopped(int x)
    {
        Unsubscribe();
    }

    private void OnDestroy()
    {
        Unsubscribe();
    }

    void Unsubscribe()
    {
        Debug.Log($"Unsubscribed:{rightHandSerial}");
        client.Send("/VMT/Unsubscribe/Device", rightHandSerial);
        Debug.Log($"Unsubscribed:{leftHandSerial}");
        client.Send("/VMT/Unsubscribe/Device", leftHandSerial);
    }

    void Subscribe()
    {
        Debug.Log($"Subscribed:{rightHandSerial}");
        client.Send("/VMT/Subscribe/Device", rightHandSerial);
        Debug.Log($"Subscribed:{leftHandSerial}");
        client.Send("/VMT/Subscribe/Device", leftHandSerial);
    }
}