Python Tips: スーパークラスのメソッドをオーバーライドできているか確認したい

今回は スーパークラスのメソッドをオーバーライドをできているか確認する方法 についてです。

Python には言語そのものの機能として「定義したメソッドが祖先クラスのメソッドをオーバーライドしていること」を確約する方法が(私が知るかぎり)ありません。

そのため、他人が作ったクラス(ライブラリやフレームワークが提供するクラス)を継承してメソッドをオーバーライドするときは、メソッド名のタイポに注意が必要です。オーバーライドしたつもりでできていないと思わぬバグに悩まされることがあります。

また、ライブラリを長く使っていると、バージョンアップ等によって継承元クラスのメソッドがいつの間にかなくなってしまい、オーバーライドがうまく行かなくなっていたということも起こりえます。

そんなときには、メソッド定義時に「スーパープラスのメソッドを正しくオーバーライドできていること」を確認する仕組みを入れておくと便利です。

具体的には、次のようなデコレータを定義します。

# f-string を使用しているため、動かすには Python 3.6 以上が必要です。

def overrides(klass):
    def check_super(method):
        method_name = method.__name__
        msg = f'`{method_name}()` is not defined in `{klass.__name__}`.'
        assert method_name in dir(klass), msg

    def wrapper(method):
        check_super(method)
        return method

    return wrapper

このデコレータ overrides は次のように使用することができます。

class A:
    def method1(self):
        pass


class B(A):
    @overrides(A)
    def method1(self):
        pass

クラス B はクラス A を継承しています。そして、クラス B のメソッド method1() はクラス Amethod1() をオーバーライドしています。このコードはエラーなく動きます。

一方、正しくオーバーライドができていない次のクラス C の場合は、定義時( compile time )にエラーが発生します。

class A:
    def method1(self):
        pass


class C(A):
    # `method1` をオーバーライドしたつもりがタイポでできていない場合
    @overrides(A)
    def method2(self):
        pass

この C のコードを実行すると次のようなエラーが発生し間違いを早期に知らせてくれます。

Traceback (most recent call last):
  File "xxx.py", line n, in

<module>
    class C(A):
  File "xxx.py", line n, in C
    @overrides(A)
  File "xxx.py", line n, in wrapper
    check_super(method)
  File "xxx.py", line 14, in check_super
    assert method_name in dir(klass), msg
AssertionError: `method2()` is not defined in `A`.
</module>

以上です。

ちなみに、今回のものはデコレータを使ったアプローチでしたが、これと同じことを実現する他の方法としては「メタクラスを使った方法」や「 __init_subclass__() を使った方法」等があるようです。

参考: