畳庵〜tatamiya practice〜

機械学習・統計モデリング練習帳。読書・source code読みの記録や,実装など。

scikit-learnはどのようにモデルがfit済みかを確認しているか?~check_is_fittedをおちょくる暇人~

以前投稿したDecisionTreeClassifierの記事で,モデルがfit済みかを確認するcheck_is_fitted関数について少し触れました。

tatamiya-practice.hatenablog.com

この関数は,決定木に限らずscikit-learnのあらゆる実装で使われています。

そこ本記事では,check_is_fitted関数についてみていきたいと思います。

Take Home Message

check_is_fitted関数は,以下をもとにモデルがfit済みかを判定する:

  • クラスのインスタンスである
  • fitメソッドを持っている
  • 末尾or文頭が_であるattributeを持っている
    • ただし,__で始まるものは除外

実装を辿る

check_is_fitted関数は,/sklearn/utilsの中のvalidation.pyで定義されています。

以下実装です:

def check_is_fitted(estimator, msg=None):

    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this estimator.")

    if not hasattr(estimator, 'fit'):
        raise TypeError("%s is not an estimator instance." % (estimator))

    attrs = [v for v in vars(estimator)
             if (v.endswith("_") or v.startswith("_"))
             and not v.startswith("__")]

    if not attrs:
        raise NotFittedError(msg % {'name': type(estimator).__name__})

(引用: scikit-learn/validation.py at fa4646749ce47cf4fe8d15575c448948b5625209 · scikit-learn/scikit-learn · GitHub,Docstring省略)

なにやら複雑ですが,fitされていなければエラーを表示するための関数です。

以下の要領で必要なライブラリをimportしたうえで,一つ一つ挙動を確かめていこうと思います。

from inspect import isclass
import numpy as np
from sklearn.exceptions import NotFittedError

入力について

引数をみると,estimator, msgとなっています。

estimatorはその名の通り,予測クラスのインスタンスです。 例えば

clf = DecisionTreeClassifier()

のように分類木クラスのインスタンスを作成していた場合,引数はestimator=clfのようになります。

一方のmsgは,エラー時に表示するメッセージの文字列です。

なにも指定しなければdefault値msg=Noneが入り, 以下のように通常のエラー文が表示されます。

    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this estimator.")

scikit-learn/validation.py at fa4646749ce47cf4fe8d15575c448948b5625209 · scikit-learn/scikit-learn · GitHub

%(name)の部分には後ほど取得するクラス名が入ります。

もしdefaultとは異なる特殊なエラーメッセージを表示させる必要があるときは,msg引数を使えるようです。

インスタンスかどうかの判定

まず最初に,次のような処理が行われています:

    if isclass(estimator):
        raise TypeError("{} is a class, not an instance.".format(estimator))

これは,入力した予測器estimatorが,classかinstanceかを判定します。

実際に,以下のような適当なクラスを定義して試してみます。

class Hoge():
        def __init__(self):
            pass

check_is_fitted(Hoge)
# TypeError: <class '__main__.Hoge'> is a class, not an instance.

hoge = Hoge()
check_is_fitted(hope)
# TypeError: <__main__.Hoge object at 0x12255ef10> is not an estimator instance.

このように,クラスHogeそのものを入力すると,inspect.isclassで判定されたのち弾かれます。

一方で,インスタンスhoge = Hoge()を代入すれば,その次のステップの学習器かどうかを判定する部分に引っかかります。

学習器かどうかの判定

学習器かどうかは,以下の部分で判定されています。

    if not hasattr(estimator, 'fit'):
        raise TypeError("%s is not an estimator instance." % (estimator))

hasattr関数を使って,estimatorfitメソッドを持っているかを判定しています。

fitメソッドがない場合,先ほど見たようなTypeErrorが表示されます。

試しに,空のfitメソッドを実装して判定してみましょう。

class HogeEstimator():
    
    def fit(self):
        pass

hoge_estimator = HogeEstimator()
check_is_fitted(hoge_estimator)
# NotFittedError: This HogeEstimator instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

学習器としては認められましたが,fit済みでないとして弾かれました。

fit済みの判定

では,最終関門,fit済みかどうかはどう判定しているのでしょう?

注目すべき箇所は以下です。

    attrs = [v for v in vars(estimator)
             if (v.endswith("_") or v.startswith("_"))
             and not v.startswith("__")]

    if not attrs:
        raise NotFittedError(msg % {'name': type(estimator).__name__})

まずvarsにより,学習器インスタンスestimatorのもつattributeを取り出しています。

さらにその中から_で始まるor終わり,かつ先頭が__でないものを取り出しています。

もし条件に当てはまるものが存在しなければ,NotFittedErrorが発動します。

では,無理やり_で始まるattributeを作ってみましょう:

class HogeEstimator2():
    
    def __init__(self):
        
        self._fuga = 'fuga'
    
    def fit(self):
        pass

hoge_estimator2 = HogeEstimator2()

vars(hoge_estimator2)
# {'_fuga': 'fuga'}

[v for v in vars(hoge_estimator2)
         if (v.endswith("_") or v.startswith("_"))
         and not v.startswith("__")]
# ['_fuga']

check_is_fitted(hoge_estimator2)

エラーは表示されませんでした。 無事is_check_fitted関数を騙し通し,fit済み学習器になりすますことができたようです。

おわりに

学習器がfit済みかどうかは,最終的には文頭か文末が_となっているattribute名で判定していました。

当初はなにかしらのフラグをクラス変数として持っているのかと思っていましたが,意外にシンプルな手段でした。

具体的にDecisionTreeClassifierを例に,fit前後でどのようなattributeが定義されているのか見てみました:

# Before
vars(DecisionTreeClassifier())
'''
{'criterion': 'gini',
 'splitter': 'best',
 'max_depth': None,
 'min_samples_split': 2,
 'min_samples_leaf': 1,
 'min_weight_fraction_leaf': 0.0,
 'max_features': None,
 'random_state': None,
 'max_leaf_nodes': None,
 'min_impurity_decrease': 0.0,
 'min_impurity_split': None,
 'class_weight': None,
 'presort': 'deprecated',
 'ccp_alpha': 0.0}
'''

# After
clf = DecisionTreeClassifier().fit(X_train, y_train)
vars(clf)
'''
{'criterion': 'gini',
 'splitter': 'best',
 'max_depth': None,
 'min_samples_split': 2,
 'min_samples_leaf': 1,
 'min_weight_fraction_leaf': 0.0,
 'max_features': None,
 'random_state': None,
 'max_leaf_nodes': None,
 'min_impurity_decrease': 0.0,
 'min_impurity_split': None,
 'class_weight': None,
 'presort': 'deprecated',
 'ccp_alpha': 0.0,
 'n_features_': 4,
 'n_outputs_': 1,
 'classes_': array([0, 1, 2]),
 'n_classes_': 3,
 'max_features_': 4,
 'tree_': <sklearn.tree._tree.Tree at 0x120dd8a40>}
'''

確かに,fit後には学習器の実体であるtree_オブジェクトを含め,_が文頭/文末に来るattributeができていることがわかります。

一方のfit前に持っているattributeですが,これらは全てクラス作成時に定義されるハイパーパラメータです。 たしかに_で始まるor終わるハイパーパラメータって,見たことないな,って今更ながら思いました。