畳庵〜tatamiya practice〜

機械学習・統計モデリング練習帳。読書・source code読みの記録や,実装など。

そうだ,scikit-learnの決定木を写経しよう…!

はじめに

ブログタイトルの通り,scikit-learnの決定木の実装を読みながら,フルスクラッチで写経していこうと思います。

とはいっても,決定木の実装,めちゃくちゃムズイです。 主な理由としては,アルゴリズムのコアとなる部分の大半がCythonで実装されていることが挙げられます。 他にも,サブモジュール類が複数に分かれていて,従属関係を把握しづらい,という面もあります。

ということで,いきなりゼロからガリガリ書いたるで!っというやり方はしません。 代わりに,最初のうちはサブモジュール・基底クラスなどのかたまりはコピペ・importして使い,インターフェイス部分から徐々に掘り下げていく方針を取ろうと思います。


ということでまずは,分類木 DecisionTreeClassifierから見ていきます。

予測モデルを作る時って,だいたいいつもこんな感じでやりますよね?

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

ではいきなりですが,scikit-learnのGitHubからDecisionTreeClassifierを拾って来て,動かしてみましょう!

DecisionTreeClassifierを動かす

DecisionTreeClassifierクラスは,scikit-learn/sklearn/tree/内の,_classes.pyというモジュールの中で定義されています。

とりあえずコピペします。Docstringは長いので省きました。

https://github.com/scikit-learn/scikit-learn/blob/fa4646749ce47cf4fe8d15575c448948b5625209/sklearn/tree/_classes.py#L585

class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):

    def __init__(self,
                 criterion="gini",
                 splitter="best",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features=None,
                 random_state=None,
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 class_weight=None,
                 presort='deprecated',
                 ccp_alpha=0.0):
        super().__init__(
            criterion=criterion,
            splitter=splitter,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            class_weight=class_weight,
            random_state=random_state,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            presort=presort,
            ccp_alpha=ccp_alpha)

    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):

        super().fit(
            X, y,
            sample_weight=sample_weight,
            check_input=check_input,
            X_idx_sorted=X_idx_sorted)
        return self

    def predict_proba(self, X, check_input=True):

        check_is_fitted(self)
        X = self._validate_X_predict(X, check_input)
        proba = self.tree_.predict(X)

        if self.n_outputs_ == 1:
            proba = proba[:, :self.n_classes_]
            normalizer = proba.sum(axis=1)[:, np.newaxis]
            normalizer[normalizer == 0.0] = 1.0
            proba /= normalizer

            return proba

        else:
            all_proba = []

            for k in range(self.n_outputs_):
                proba_k = proba[:, k, :self.n_classes_[k]]
                normalizer = proba_k.sum(axis=1)[:, np.newaxis]
                normalizer[normalizer == 0.0] = 1.0
                proba_k /= normalizer
                all_proba.append(proba_k)

            return all_proba

    def predict_log_proba(self, X):
   
        proba = self.predict_proba(X)

        if self.n_outputs_ == 1:
            return np.log(proba)

        else:
            for k in range(self.n_outputs_):
                proba[k] = np.log(proba[k])

            return proba

あれ?意外と短いじゃん? って思うかもしれませんけど,このままでは動きません。

なぜなら,本体部分の大半が,基底クラスBaseDecisionTreeで定義されているからです。

とりあえずBaseDecisionTreeClassifierMixinの二つをimportして,コピペしたDecisionTreeClassifierを動かしてみましょう。


from sklearn.base import ClassifierMixin
from sklearn.tree import BaseDecisionTree

class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
# ~~ 省略 ~~

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

これなら無事実行できます。

以下,動作を一つ一つ辿ってみましょう。

fitの実装

DecisionTreeClassifierの実装に戻ると,fitメソッドは以下のように実装されています。

https://github.com/scikit-learn/scikit-learn/blob/fa4646749ce47cf4fe8d15575c448948b5625209/sklearn/tree/_classes.py#L835

    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):

        super().fit(
            X, y,
            sample_weight=sample_weight,
            check_input=check_input,
            X_idx_sorted=X_idx_sorted)
        return self

かなり短いことがわかります。 理由は簡単,fitの処理は基底クラスBaseDecisionTreeを継承して実行しているためです。

こちらの詳細はおいおい追っていくことにしましょう。

predictはどこ?

DecisionTreeClassifierのどこを見回しても,predictメソッドは実装されません。

しかし,上記の実行例ではきちんと動いています。

なぜでしょう?

こちらも基底クラスBaseDecisionTreeを継承しているためです。

BaseDecisionTreeの方の実装を覗いてみると,こちらでpredictメソッドが定義されていることがわかります。

predict_probaが動かない!?

上記の実行例ではpredictメソッドで予測クラスを返しましたが, 今度はpredict_probaメソッドで各クラスごとの予測率を出してみましょう。

y_pred_proba = clf.predict_proba(X_test)

すると今度は実行できません。 エラー表示を読めばわかりますが,check_is_fitted関数が定義されていないことによります。

この関数は,/sklearn/utilsの中のvalidation.pyで定義されていますので,以下の一文を加えて再実行しましょう:

from sklearn.utils.validation import check_is_fitted

check_is_fitted関数の挙動については,以下の記事で紹介しています。

tatamiya-practice.hatenablog.com

おわりに

この記事では分類決定木DecisionTreeClassifierの実装のうち,表面の部分を簡単に辿ってみました。

次回以降ではBaseDecisionTreeクラスに少しずつ深入りして,各メソッドの挙動を一つ一つみていきたいと思います。