scikit-learnはどのようにモデルがfit済みかを確認しているか?~check_is_fittedをおちょくる暇人~
以前投稿したDecisionTreeClassifier
の記事で,モデルがfit済みかを確認するcheck_is_fitted
関数について少し触れました。
tatamiya-practice.hatenablog.com
この関数は,決定木に限らずscikit-learnのあらゆる実装で使われています。
そこ本記事では,check_is_fitted
関数についてみていきたいと思います。
Take Home Message
check_is_fitted
関数は,以下をもとにモデルがfit済みかを判定する:
- クラスのインスタンスである
fit
メソッドを持っている- 末尾or文頭が
_
であるattributeを持っている- ただし,
__
で始まるものは除外
- ただし,
実装を辿る
check_is_fitted
関数は,/sklearn/utilsの中のvalidation.pyで定義されています。
以下実装です:
def check_is_fitted(estimator, msg=None): if isclass(estimator): raise TypeError("{} is a class, not an instance.".format(estimator)) if msg is None: msg = ("This %(name)s instance is not fitted yet. Call 'fit' with " "appropriate arguments before using this estimator.") if not hasattr(estimator, 'fit'): raise TypeError("%s is not an estimator instance." % (estimator)) attrs = [v for v in vars(estimator) if (v.endswith("_") or v.startswith("_")) and not v.startswith("__")] if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__})
(引用: scikit-learn/validation.py at fa4646749ce47cf4fe8d15575c448948b5625209 · scikit-learn/scikit-learn · GitHub,Docstring省略)
なにやら複雑ですが,fitされていなければエラーを表示するための関数です。
以下の要領で必要なライブラリをimportしたうえで,一つ一つ挙動を確かめていこうと思います。
from inspect import isclass import numpy as np from sklearn.exceptions import NotFittedError
入力について
引数をみると,estimator
, msg
となっています。
estimator
はその名の通り,予測クラスのインスタンスです。
例えば
clf = DecisionTreeClassifier()
のように分類木クラスのインスタンスを作成していた場合,引数はestimator=clf
のようになります。
一方のmsg
は,エラー時に表示するメッセージの文字列です。
なにも指定しなければdefault値msg=None
が入り,
以下のように通常のエラー文が表示されます。
if msg is None: msg = ("This %(name)s instance is not fitted yet. Call 'fit' with " "appropriate arguments before using this estimator.")
%(name)
の部分には後ほど取得するクラス名が入ります。
もしdefaultとは異なる特殊なエラーメッセージを表示させる必要があるときは,msg
引数を使えるようです。
インスタンスかどうかの判定
まず最初に,次のような処理が行われています:
if isclass(estimator): raise TypeError("{} is a class, not an instance.".format(estimator))
これは,入力した予測器estimator
が,classかinstanceかを判定します。
実際に,以下のような適当なクラスを定義して試してみます。
class Hoge(): def __init__(self): pass check_is_fitted(Hoge) # TypeError: <class '__main__.Hoge'> is a class, not an instance. hoge = Hoge() check_is_fitted(hope) # TypeError: <__main__.Hoge object at 0x12255ef10> is not an estimator instance.
このように,クラスHoge
そのものを入力すると,inspect.isclass
で判定されたのち弾かれます。
一方で,インスタンスhoge = Hoge()
を代入すれば,その次のステップの学習器かどうかを判定する部分に引っかかります。
学習器かどうかの判定
学習器かどうかは,以下の部分で判定されています。
if not hasattr(estimator, 'fit'): raise TypeError("%s is not an estimator instance." % (estimator))
hasattr
関数を使って,estimator
がfit
メソッドを持っているかを判定しています。
fit
メソッドがない場合,先ほど見たようなTypeError
が表示されます。
試しに,空のfit
メソッドを実装して判定してみましょう。
class HogeEstimator(): def fit(self): pass hoge_estimator = HogeEstimator() check_is_fitted(hoge_estimator) # NotFittedError: This HogeEstimator instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.
学習器としては認められましたが,fit済みでないとして弾かれました。
fit済みの判定
では,最終関門,fit済みかどうかはどう判定しているのでしょう?
注目すべき箇所は以下です。
attrs = [v for v in vars(estimator) if (v.endswith("_") or v.startswith("_")) and not v.startswith("__")] if not attrs: raise NotFittedError(msg % {'name': type(estimator).__name__})
まずvars
により,学習器インスタンスestimator
のもつattributeを取り出しています。
さらにその中から_
で始まるor終わり,かつ先頭が__
でないものを取り出しています。
もし条件に当てはまるものが存在しなければ,NotFittedError
が発動します。
では,無理やり_
で始まるattributeを作ってみましょう:
class HogeEstimator2(): def __init__(self): self._fuga = 'fuga' def fit(self): pass hoge_estimator2 = HogeEstimator2() vars(hoge_estimator2) # {'_fuga': 'fuga'} [v for v in vars(hoge_estimator2) if (v.endswith("_") or v.startswith("_")) and not v.startswith("__")] # ['_fuga'] check_is_fitted(hoge_estimator2)
エラーは表示されませんでした。
無事is_check_fitted
関数を騙し通し,fit済み学習器になりすますことができたようです。
おわりに
学習器がfit済みかどうかは,最終的には文頭か文末が_
となっているattribute名で判定していました。
当初はなにかしらのフラグをクラス変数として持っているのかと思っていましたが,意外にシンプルな手段でした。
具体的にDecisionTreeClassifier
を例に,fit前後でどのようなattributeが定義されているのか見てみました:
# Before vars(DecisionTreeClassifier()) ''' {'criterion': 'gini', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'min_weight_fraction_leaf': 0.0, 'max_features': None, 'random_state': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_impurity_split': None, 'class_weight': None, 'presort': 'deprecated', 'ccp_alpha': 0.0} ''' # After clf = DecisionTreeClassifier().fit(X_train, y_train) vars(clf) ''' {'criterion': 'gini', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 1, 'min_weight_fraction_leaf': 0.0, 'max_features': None, 'random_state': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_impurity_split': None, 'class_weight': None, 'presort': 'deprecated', 'ccp_alpha': 0.0, 'n_features_': 4, 'n_outputs_': 1, 'classes_': array([0, 1, 2]), 'n_classes_': 3, 'max_features_': 4, 'tree_': <sklearn.tree._tree.Tree at 0x120dd8a40>} '''
確かに,fit後には学習器の実体であるtree_
オブジェクトを含め,_
が文頭/文末に来るattributeができていることがわかります。
一方のfit前に持っているattributeですが,これらは全てクラス作成時に定義されるハイパーパラメータです。
たしかに_
で始まるor終わるハイパーパラメータって,見たことないな,って今更ながら思いました。