畳庵〜tatamiya practice〜

機械学習・統計モデリング練習帳。読書・source code読みの記録や,実装など。

scikit-learnはどのようにモデルがfit済みかを確認しているか?~check_is_fittedをおちょくる暇人~

以前投稿したDecisionTreeClassifierの記事で,モデルがfit済みかを確認するcheck_is_fitted関数について少し触れました。

tatamiya-practice.hatenablog.com

この関数は,決定木に限らずscikit-learnのあらゆる実装で使われています。

そこ本記事では,check_is_fitted関数についてみていきたいと思います。

Take Home Message

check_is_fitted関数は,以下をもとにモデルがfit済みかを判定する:

  • クラスのインスタンスである
  • fitメソッドを持っている
  • 末尾or文頭が_であるattributeを持っている
    • ただし,__で始まるものは除外
続きを読む

そうだ,scikit-learnの決定木を写経しよう…!

はじめに

ブログタイトルの通り,scikit-learnの決定木の実装を読みながら,フルスクラッチで写経していこうと思います。

とはいっても,決定木の実装,めちゃくちゃムズイです。 主な理由としては,アルゴリズムのコアとなる部分の大半がCythonで実装されていることが挙げられます。 他にも,サブモジュール類が複数に分かれていて,従属関係を把握しづらい,という面もあります。

ということで,いきなりゼロからガリガリ書いたるで!っというやり方はしません。 代わりに,最初のうちはサブモジュール・基底クラスなどのかたまりはコピペ・importして使い,インターフェイス部分から徐々に掘り下げていく方針を取ろうと思います。


ということでまずは,分類木 DecisionTreeClassifierから見ていきます。

予測モデルを作る時って,だいたいいつもこんな感じでやりますよね?

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

ではいきなりですが,scikit-learnのGitHubからDecisionTreeClassifierを拾って来て,動かしてみましょう!

続きを読む