多重繼承與 Mixin

April 29, 2022

Python 可以進行多重繼承,也就是一次繼承兩個父類別的程式碼定義,父類別之間使用逗號作為區隔。

多重繼承

多個父類別繼承下來的方法名稱沒有衝突時,是最單純的情況,例如:

>>> class P1:
...     def mth1(self):
...         print('mth1')
...
>>> class P2:
...     def mth2(self):
...         print('mth2')
...
>>> class S(P1, P2):
...     pass
...
>>> s = S()
>>> s.mth1()
mth1
>>> s.mth2()
mth2
>>>

如果繼承時多個父類別中有相同的方法名稱,就要注意搜尋的順序,基本上是從子類別開始尋找名稱,接著是同一階層父類別由左至右搜尋,再至更上層同一階層父類別由左至右搜尋,直到達到頂層為止。例如:

>>> class P1:
...     def mth(self):
...         print('P1 mth')
...
>>> class P2:
...     def mth(self):
...         print('P2 mth')
...
>>> class S1(P1, P2):
...     pass
...
>>> class S2(P2, P1):
...     pass
...
>>> s1 = S1()
>>> s2 = S2()
>>> s1.mth()
P1 mth
>>> s2.mth()
P2 mth
>>>

在上面的例子中,S1 繼承父類別的順序是 P1P2,而 S2P2P1,因此在尋找 mth 方法時,S1 實例使用的是 P1 繼承而來方法,而 S2 使用的是 P2 繼承而來的方法。

具體來說,一個子類別在尋找指定的屬性或方法名稱時,會依據類別的 __mro__ 屬性的元素順序尋找(MRO 全名是Method Resolution Order),如果想知道直接父類別的話,可以透過類別的 __bases__ 來得知。

>>> S1.__mro__
(<class '__main__.S1'>, <class '__main__.P1'>, <class '__main__.P2'>, <class 'object'>)
>>> S1.__bases__
(<class '__main__.P1'>, <class '__main__.P2'>)
>>> S2.__mro__
(<class '__main__.S2'>, <class '__main__.P2'>, <class '__main__.P1'>, <class 'object'>)
>>> S2.__bases__
(<class '__main__.P2'>, <class '__main__.P1'>)
>>>

__mro__ 屬性的清單,也可以透過類別的 mro 方法來取得:

>>> S1.mro()
[<class '__main__.S1'>, <class '__main__.P1'>, <class '__main__.P2'>, <class 'object'>]
>>> S2.mro()
[<class '__main__.S2'>, <class '__main__.P2'>, <class '__main__.P1'>, <class 'object'>]
>>>

__mro__ 是唯讀屬性;雖然不建議,不過可以改變 __bases__ 來改變直接父類別,從而令 __mro__ 的內容也跟著變動。

Mixin

多重繼承的能力,通常建議只用來實現 Mixin,也就是抽離可重用流程,必要時混入另一個類別。

來考慮一個 Ball 類別,其中定義了一些比較大小的方法:

class Ball:
    def __init__(self, radius):
        self.radius = radius

    def __eq__(self, other):
        return hasattr(other, 'radius') and self.radius == other.radius

    def __gt__(self, other):
        return hasattr(other, 'radius') and self.radius > other.radius

    def __ge__(self, other):
        return self > other or self == other

    def __lt__(self, other):
        return not (self > other and self == other)

    def __le__(self, other):
        return (not self >= other) or self == other

    def __ne__(self, other):
        return not self == other

在上面看到的 __lt____le____eq____ne____gt____ge__ 等方法,定義了物件之間使用 <<===!=>>= 等比較時,應該要有的比較結果。

事實上「比較」這件任務,許多物件都會用的到,仔細觀察以上的程式碼,會發現一些可重用的方法,可以將之抽離出來:

from abc import ABC, abstractmethod

class Ordering(ABC):
    @abstractmethod
    def __eq__(self, other):
        ...

    @abstractmethod
    def __gt__(self, other):
        ...

    def __ge__(self, other):
        return self > other or self == other

    def __lt__(self, other):
        return not (self > other and self == other)

    def __le__(self, other):
        return (not self >= other) or self == other

    def __ne__(self, other):
        return not self == other

Ordering 這樣的類別,是一個抽象基礎類別,不會定義屬性,也不會有 __init__ 定義。

由於實際的物件 == 以及 > 的行為,必須依不同物件而有不同實作,在 Ordering 中不定義,必須由子類別繼承後實作,為了避免開發者在繼承後忘了實作必要的方法,使用了 @abstractmethod 標註。

至於 __ge____lt____le____ne__ 方法,只是從方才的 Ball 類別中抽取出來的可重用實作。

有了 Ordering 類別後,若有物件需要比較的行為,只要繼承 Ordering 並實作 __eq____gt__ 方法。例如,方才的 Ball 類別現在只需如下撰寫:

class Ball(Ordering):
    def __init__(self, radius: int) -> None:
        self.radius = radius

    def __eq__(self, other):
        return hasattr(other, 'radius') and self.radius == other.radius

    def __gt__(self, other):
        return hasattr(other, 'radius') and self.radius > other.radius

b1 = Ball(10)
b2 = Ball(20)

print(b1 > b2)
print(b1 <= b2)
print(b1 == b2)

在繼承了 Ordering 之後,Ball 類別只需要實作 __eq____gt__ 方法,就能具有比較的行為。

由於 Python 可以多重繼承,在必要時,可以同時混入多個類別,針對必要的方法進行實作,就可以擁有多個類別已定義的可重用實作。

rich comparison 方法

如果你想實作方才的比較功能,其實不用自行實現全部的方法,__lt____le____eq____ne____gt____ge__ 等方法,其實是 object 類別就定義了的方法,在定義類別時,若沒有指定父類別,就是繼承 object 類別,也就繼承了 __lt__ 等方法,因此方才的例子,你是重新定義了這組被稱為 rich comparison 的方法。

並不是每個物件,都要定義整組比較方法,然而,若真的需要定義這整組方法的行為,可以使用 functools.total_ordering。例如:

from functools import total_ordering

@total_ordering
class Ball:
    def __init__(self, radius: int) -> None:
        self.radius = radius

    def __eq__(self, other):
        return hasattr(other, 'radius') and self.radius == other.radius

    def __gt__(self, other):
        return hasattr(other, 'radius') and self.radius > other.radius

b1 = Ball(10)
b2 = Ball(20)

print(b1 > b2)
print(b1 <= b2)
print(b1 == b2)

當一個類別被標註了 @total_ordering 時,必須實作 __eq__ 方法,並選擇 __lt____le____gt____ge__ 其中一個方法實作,這樣就可以擁有整組的比較方法了,其背後基本的原理在於,只要定義了 __eq__ 以及 __lt____le____gt____ge__ 其中一個方法,假設是 __gt__ 的話,那麼剩下的 __ne____lt____le____ge__ 就可以各自呼叫這兩個方法來完成比較的行為。

分享到 LinkedIn 分享到 Facebook 分享到 Twitter