NumPy 編寫自定義數(shù)組容器

2021-09-01 13:47 更新

在 numpy v1.16 版本中引入的 Numpy 調(diào)度機(jī)制是編寫與 numpy API 兼容并提供 numpy 功能的自定義實現(xiàn)的自定義 N 維數(shù)組容器的推薦方法。應(yīng)用程序包括dask數(shù)組(分布在多個節(jié)點上的 N 維數(shù)組)和cupy數(shù)組(GPU 上的 N 維數(shù)組)。

為了感受如何編寫自定義數(shù)組容器,我們將從一個簡單的示例開始,該示例具有相當(dāng)狹窄的實用性,但說明了所涉及的概念。

  1. >>> import numpy as np
  2. >>> class DiagonalArray:
  3. ... def __init__(self, N, value):
  4. ... self._N = N
  5. ... self._i = value
  6. ... def __repr__(self):
  7. ... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
  8. ... def __array__(self, dtype=None):
  9. ... return self._i * np.eye(self._N, dtype=dtype)

我們的自定義數(shù)組可以像這樣實例化:

  1. >>> arr = DiagonalArray(5, 1)
  2. >>> arr
  3. DiagonalArray(N=5, value=1)

我們可以使用numpy.arrayor?轉(zhuǎn)換成一個 numpy 數(shù)組numpy.asarray,它會調(diào)用它的__array__方法來獲取一個標(biāo)準(zhǔn)的numpy.ndarray.

  1. >>> np.asarray(arr)
  2. array([[1., 0., 0., 0., 0.],
  3. [0., 1., 0., 0., 0.],
  4. [0., 0., 1., 0., 0.],
  5. [0., 0., 0., 1., 0.],
  6. [0., 0., 0., 0., 1.]])

如果我們使用arrnumpy 函數(shù)進(jìn)行操作,numpy 將再次使用該?__array__接口將其轉(zhuǎn)換為數(shù)組,然后以通常的方式應(yīng)用該函數(shù)。

  1. >>> np.multiply(arr, 2)
  2. array([[2., 0., 0., 0., 0.],
  3. [0., 2., 0., 0., 0.],
  4. [0., 0., 2., 0., 0.],
  5. [0., 0., 0., 2., 0.],
  6. [0., 0., 0., 0., 2.]])

請注意,返回類型是標(biāo)準(zhǔn)的numpy.ndarray.

  1. >>> type(np.multiply(arr, 2))
  2. numpy.ndarray

我們?nèi)绾瓮ㄟ^這個函數(shù)傳遞我們的自定義數(shù)組類型?Numpy 允許一個類通過接口__array_ufunc____array_function__.?讓我們一次一個,從_array_ufunc__.?此方法涵蓋?通用函數(shù) (ufunc),這是一類函數(shù),例如包括?numpy.multiplynumpy.sin。

__array_ufunc__接收:

  • ufunc,函數(shù)如?numpy.multiply
  • method, 一個字符串,區(qū)分numpy.multiply(...)和 變體,如numpy.multiply.outer、numpy.multiply.accumulate等。對于常見情況,numpy.multiply(...),?。method?==?'__call__'
  • inputs,這可能是不同類型的混合
  • kwargs, 傳遞給函數(shù)的關(guān)鍵字參數(shù)

對于這個例子,我們將只處理方法?__call__

  1. >>> from numbers import Number
  2. >>> class DiagonalArray:
  3. ... def __init__(self, N, value):
  4. ... self._N = N
  5. ... self._i = value
  6. ... def __repr__(self):
  7. ... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
  8. ... def __array__(self, dtype=None):
  9. ... return self._i * np.eye(self._N, dtype=dtype)
  10. ... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
  11. ... if method == '__call__':
  12. ... N = None
  13. ... scalars = []
  14. ... for input in inputs:
  15. ... if isinstance(input, Number):
  16. ... scalars.append(input)
  17. ... elif isinstance(input, self.__class__):
  18. ... scalars.append(input._i)
  19. ... if N is not None:
  20. ... if N != self._N:
  21. ... raise TypeError("inconsistent sizes")
  22. ... else:
  23. ... N = self._N
  24. ... else:
  25. ... return NotImplemented
  26. ... return self.__class__(N, ufunc(*scalars, **kwargs))
  27. ... else:
  28. ... return NotImplemented

現(xiàn)在我們的自定義數(shù)組類型通過 numpy 函數(shù)。

  1. >>> arr = DiagonalArray(5, 1)
  2. >>> np.multiply(arr, 3)
  3. DiagonalArray(N=5, value=3)
  4. >>> np.add(arr, 3)
  5. DiagonalArray(N=5, value=4)
  6. >>> np.sin(arr)
  7. DiagonalArray(N=5, value=0.8414709848078965)

此時不起作用。arr?+?3

  1. >>> arr + 3
  2. TypeError: unsupported operand type(s) for *: 'DiagonalArray' and 'int'

為了支持它,我們需要定義 Python 接口__add__、__lt__等以分派到相應(yīng)的 ufunc。我們可以通過從 mixin 繼承來方便地實現(xiàn)這一點?NDArrayOperatorsMixin。

  1. >>> import numpy.lib.mixins
  2. >>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
  3. ... def __init__(self, N, value):
  4. ... self._N = N
  5. ... self._i = value
  6. ... def __repr__(self):
  7. ... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
  8. ... def __array__(self, dtype=None):
  9. ... return self._i * np.eye(self._N, dtype=dtype)
  10. ... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
  11. ... if method == '__call__':
  12. ... N = None
  13. ... scalars = []
  14. ... for input in inputs:
  15. ... if isinstance(input, Number):
  16. ... scalars.append(input)
  17. ... elif isinstance(input, self.__class__):
  18. ... scalars.append(input._i)
  19. ... if N is not None:
  20. ... if N != self._N:
  21. ... raise TypeError("inconsistent sizes")
  22. ... else:
  23. ... N = self._N
  24. ... else:
  25. ... return NotImplemented
  26. ... return self.__class__(N, ufunc(*scalars, **kwargs))
  27. ... else:
  28. ... return NotImplemented
  1. >>> arr = DiagonalArray(5, 1)
  2. >>> arr + 3
  3. DiagonalArray(N=5, value=4)
  4. >>> arr > 0
  5. DiagonalArray(N=5, value=True)

現(xiàn)在讓我們解決__array_function__.?我們將創(chuàng)建 dict 將 numpy 函數(shù)映射到我們的自定義變體。

  1. >>> HANDLED_FUNCTIONS = {}
  2. >>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
  3. ... def __init__(self, N, value):
  4. ... self._N = N
  5. ... self._i = value
  6. ... def __repr__(self):
  7. ... return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
  8. ... def __array__(self, dtype=None):
  9. ... return self._i * np.eye(self._N, dtype=dtype)
  10. ... def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
  11. ... if method == '__call__':
  12. ... N = None
  13. ... scalars = []
  14. ... for input in inputs:
  15. ... # In this case we accept only scalar numbers or DiagonalArrays.
  16. ... if isinstance(input, Number):
  17. ... scalars.append(input)
  18. ... elif isinstance(input, self.__class__):
  19. ... scalars.append(input._i)
  20. ... if N is not None:
  21. ... if N != self._N:
  22. ... raise TypeError("inconsistent sizes")
  23. ... else:
  24. ... N = self._N
  25. ... else:
  26. ... return NotImplemented
  27. ... return self.__class__(N, ufunc(*scalars, **kwargs))
  28. ... else:
  29. ... return NotImplemented
  30. ... def __array_function__(self, func, types, args, kwargs):
  31. ... if func not in HANDLED_FUNCTIONS:
  32. ... return NotImplemented
  33. ... # Note: this allows subclasses that don't override
  34. ... # __array_function__ to handle DiagonalArray objects.
  35. ... if not all(issubclass(t, self.__class__) for t in types):
  36. ... return NotImplemented
  37. ... return HANDLED_FUNCTIONS[func](*args, **kwargs)
  38. ...

一個方便的模式是定義一個implements可用于向HANDLED_FUNCTIONS.

  1. >>> def implements(np_function):
  2. ... "Register an __array_function__ implementation for DiagonalArray objects."
  3. ... def decorator(func):
  4. ... HANDLED_FUNCTIONS[np_function] = func
  5. ... return func
  6. ... return decorator
  7. ...

現(xiàn)在我們編寫 numpy 函數(shù)的實現(xiàn)DiagonalArray。為了完整起見,為了支持用法,arr.sum()添加一個sum調(diào)用的方法,numpy.sum(self)對于mean.

  1. >>> @implements(np.sum)
  2. ... def sum(arr):
  3. ... "Implementation of np.sum for DiagonalArray objects"
  4. ... return arr._i * arr._N
  5. ...
  6. >>> @implements(np.mean)
  7. ... def mean(arr):
  8. ... "Implementation of np.mean for DiagonalArray objects"
  9. ... return arr._i / arr._N
  10. ...
  11. >>> arr = DiagonalArray(5, 1)
  12. >>> np.sum(arr)
  13. 5
  14. >>> np.mean(arr)
  15. 0.2

如果用戶嘗試使用 中未包含的任何 numpy 函數(shù)?HANDLED_FUNCTIONS,TypeError則 numpy 將引發(fā)a?,表示不支持此操作。例如,連接兩個?DiagonalArrays不會產(chǎn)生另一個對角數(shù)組,因此不支持。

  1. >>> np.concatenate([arr, arr])
  2. TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

此外,我們的summean實現(xiàn)不接受 numpy 的實現(xiàn)所做的可選參數(shù)。

  1. >>> np.sum(arr, axis=0)
  2. TypeError: sum() got an unexpected keyword argument 'axis'

用戶總是具有轉(zhuǎn)換為正常的選擇numpy.ndarray與?numpy.asarray和使用標(biāo)準(zhǔn)numpy的從那里。

  1. >>> np.concatenate([np.asarray(arr), np.asarray(arr)])
  2. array([[1., 0., 0., 0., 0.],
  3. [0., 1., 0., 0., 0.],
  4. [0., 0., 1., 0., 0.],
  5. [0., 0., 0., 1., 0.],
  6. [0., 0., 0., 0., 1.],
  7. [1., 0., 0., 0., 0.],
  8. [0., 1., 0., 0., 0.],
  9. [0., 0., 1., 0., 0.],
  10. [0., 0., 0., 1., 0.],
  11. [0., 0., 0., 0., 1.]])

有關(guān)自定義數(shù)組容器的更完整示例,請參閱dask 源代碼和?cupy 源代碼。 另見NEP 18。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號