scikit-learnはどのようにモデルがfit済みかを確認しているか?~check_is_fittedをおちょくる暇人~
以前投稿したDecisionTreeClassifier
の記事で,モデルがfit済みかを確認するcheck_is_fitted
関数について少し触れました。
tatamiya-practice.hatenablog.com
この関数は,決定木に限らずscikit-learnのあらゆる実装で使われています。
そこ本記事では,check_is_fitted
関数についてみていきたいと思います。
Take Home Message
check_is_fitted
関数は,以下をもとにモデルがfit済みかを判定する:
- クラスのインスタンスである
fit
メソッドを持っている- 末尾or文頭が
_
であるattributeを持っている- ただし,
__
で始まるものは除外
- ただし,
そうだ,scikit-learnの決定木を写経しよう…!
はじめに
ブログタイトルの通り,scikit-learnの決定木の実装を読みながら,フルスクラッチで写経していこうと思います。
とはいっても,決定木の実装,めちゃくちゃムズイです。 主な理由としては,アルゴリズムのコアとなる部分の大半がCythonで実装されていることが挙げられます。 他にも,サブモジュール類が複数に分かれていて,従属関係を把握しづらい,という面もあります。
ということで,いきなりゼロからガリガリ書いたるで!っというやり方はしません。 代わりに,最初のうちはサブモジュール・基底クラスなどのかたまりはコピペ・importして使い,インターフェイス部分から徐々に掘り下げていく方針を取ろうと思います。
ということでまずは,分類木 DecisionTreeClassifierから見ていきます。
予測モデルを作る時って,だいたいいつもこんな感じでやりますよね?
from sklearn.tree import DecisionTreeClassifier from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report iris = datasets.load_iris() X = iris.data y = iris.target X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y) clf = DecisionTreeClassifier() clf.fit(X_train, y_train) y_pred = clf.predict(X_test)
ではいきなりですが,scikit-learnのGitHubからDecisionTreeClassifier
を拾って来て,動かしてみましょう!