そうだ,scikit-learnの決定木を写経しよう…!
はじめに
ブログタイトルの通り,scikit-learnの決定木の実装を読みながら,フルスクラッチで写経していこうと思います。
とはいっても,決定木の実装,めちゃくちゃムズイです。 主な理由としては,アルゴリズムのコアとなる部分の大半がCythonで実装されていることが挙げられます。 他にも,サブモジュール類が複数に分かれていて,従属関係を把握しづらい,という面もあります。
ということで,いきなりゼロからガリガリ書いたるで!っというやり方はしません。 代わりに,最初のうちはサブモジュール・基底クラスなどのかたまりはコピペ・importして使い,インターフェイス部分から徐々に掘り下げていく方針を取ろうと思います。
ということでまずは,分類木 DecisionTreeClassifierから見ていきます。
予測モデルを作る時って,だいたいいつもこんな感じでやりますよね?
from sklearn.tree import DecisionTreeClassifier from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report iris = datasets.load_iris() X = iris.data y = iris.target X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y) clf = DecisionTreeClassifier() clf.fit(X_train, y_train) y_pred = clf.predict(X_test)
ではいきなりですが,scikit-learnのGitHubからDecisionTreeClassifier
を拾って来て,動かしてみましょう!
DecisionTreeClassifierを動かす
DecisionTreeClassifierクラスは,scikit-learn/sklearn/tree/内の,_classes.pyというモジュールの中で定義されています。
とりあえずコピペします。Docstringは長いので省きました。
class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): def __init__(self, criterion="gini", splitter="best", max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0., max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, class_weight=None, presort='deprecated', ccp_alpha=0.0): super().__init__( criterion=criterion, splitter=splitter, max_depth=max_depth, min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, min_weight_fraction_leaf=min_weight_fraction_leaf, max_features=max_features, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, random_state=random_state, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, presort=presort, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): super().fit( X, y, sample_weight=sample_weight, check_input=check_input, X_idx_sorted=X_idx_sorted) return self def predict_proba(self, X, check_input=True): check_is_fitted(self) X = self._validate_X_predict(X, check_input) proba = self.tree_.predict(X) if self.n_outputs_ == 1: proba = proba[:, :self.n_classes_] normalizer = proba.sum(axis=1)[:, np.newaxis] normalizer[normalizer == 0.0] = 1.0 proba /= normalizer return proba else: all_proba = [] for k in range(self.n_outputs_): proba_k = proba[:, k, :self.n_classes_[k]] normalizer = proba_k.sum(axis=1)[:, np.newaxis] normalizer[normalizer == 0.0] = 1.0 proba_k /= normalizer all_proba.append(proba_k) return all_proba def predict_log_proba(self, X): proba = self.predict_proba(X) if self.n_outputs_ == 1: return np.log(proba) else: for k in range(self.n_outputs_): proba[k] = np.log(proba[k]) return proba
あれ?意外と短いじゃん? って思うかもしれませんけど,このままでは動きません。
なぜなら,本体部分の大半が,基底クラスBaseDecisionTree
で定義されているからです。
とりあえずBaseDecisionTree
とClassifierMixin
の二つをimportして,コピペしたDecisionTreeClassifier
を動かしてみましょう。
from sklearn.base import ClassifierMixin from sklearn.tree import BaseDecisionTree class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): # ~~ 省略 ~~ clf = DecisionTreeClassifier() clf.fit(X_train, y_train) y_pred = clf.predict(X_test)
これなら無事実行できます。
以下,動作を一つ一つ辿ってみましょう。
fitの実装
DecisionTreeClassifier
の実装に戻ると,fit
メソッドは以下のように実装されています。
def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): super().fit( X, y, sample_weight=sample_weight, check_input=check_input, X_idx_sorted=X_idx_sorted) return self
かなり短いことがわかります。
理由は簡単,fitの処理は基底クラスBaseDecisionTree
を継承して実行しているためです。
こちらの詳細はおいおい追っていくことにしましょう。
predictはどこ?
DecisionTreeClassifier
のどこを見回しても,predict
メソッドは実装されません。
しかし,上記の実行例ではきちんと動いています。
なぜでしょう?
こちらも基底クラスBaseDecisionTree
を継承しているためです。
BaseDecisionTree
の方の実装を覗いてみると,こちらでpredict
メソッドが定義されていることがわかります。
predict_probaが動かない!?
上記の実行例ではpredict
メソッドで予測クラスを返しましたが,
今度はpredict_proba
メソッドで各クラスごとの予測率を出してみましょう。
y_pred_proba = clf.predict_proba(X_test)
すると今度は実行できません。
エラー表示を読めばわかりますが,check_is_fitted
関数が定義されていないことによります。
この関数は,/sklearn/utilsの中のvalidation.pyで定義されていますので,以下の一文を加えて再実行しましょう:
from sklearn.utils.validation import check_is_fitted
check_is_fitted
関数の挙動については,以下の記事で紹介しています。
tatamiya-practice.hatenablog.com
おわりに
この記事では分類決定木DecisionTreeClassifier
の実装のうち,表面の部分を簡単に辿ってみました。
次回以降ではBaseDecisionTree
クラスに少しずつ深入りして,各メソッドの挙動を一つ一つみていきたいと思います。