支援陣列 API 標準#

注意

陣列 API 標準支援仍處於實驗階段,並隱藏在環境變數之後。目前僅涵蓋公共 API 的一小部分。

本指南說明如何使用新增對 Python 陣列 API 標準 的支援。此標準允許使用者直接將任何與陣列 API 相容的陣列函式庫與 SciPy 的部分功能一起使用。

RFC 定義了 SciPy 如何實作對該標準的支援,主要原則是「輸入陣列類型等於輸出陣列類型」。此外,實作還對允許的類陣列輸入進行更嚴格的驗證,例如拒絕 numpy 矩陣和遮罩陣列實例,以及具有 object dtype 的陣列。

在下文中,與陣列 API 相容的命名空間以 xp 表示。

使用陣列 API 標準支援#

若要啟用陣列 API 標準支援,必須在匯入 SciPy 之前設定環境變數

export SCIPY_ARRAY_API=1

這既啟用陣列 API 標準支援,也為類陣列引數啟用更嚴格的輸入驗證。請注意,此環境變數旨在作為臨時措施,以便進行漸進式變更並將其合併到 ``main`` 中,而不會立即影響向後相容性。我們不打算長期保留此環境變數。

此叢集範例展示了以 PyTorch 張量作為輸入和傳回值的用法

>>> import torch
>>> from scipy.cluster.vq import vq
>>> code_book = torch.tensor([[1., 1., 1.],
...                           [2., 2., 2.]])
>>> features  = torch.tensor([[1.9, 2.3, 1.7],
...                           [1.5, 2.5, 2.2],
...                           [0.8, 0.6, 1.7]])
>>> code, dist = vq(features, code_book)
>>> code
tensor([1, 1, 0], dtype=torch.int32)
>>> dist
tensor([0.4359, 0.7348, 0.8307])

請注意,以上範例適用於 PyTorch CPU 張量。對於 GPU 張量或 CuPy 陣列,vq 的預期結果是 TypeError,因為 vq 在其實作中使用編譯後的程式碼,這在 GPU 上無法運作。

更嚴格的陣列輸入驗證將拒絕 np.matrixnp.ma.MaskedArray 實例,以及具有 object dtype 的陣列

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
...                       [2., 2., 2.]])
>>> features  = np.array([[1.9, 2.3, 1.7],
...                       [1.5, 2.5, 2.2],
...                       [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))

>>> # The above uses numpy arrays; trying to use np.matrix instances or object
>>> # arrays instead will yield an exception with `SCIPY_ARRAY_API=1`:
>>> vq(np.asmatrix(features), code_book)
...
TypeError: 'numpy.matrix' are not supported

>>> vq(np.ma.asarray(features), code_book)
...
TypeError: 'numpy.ma.MaskedArray' are not supported

>>> vq(features.astype(np.object_), code_book)
...
TypeError: object arrays are not supported

目前支援的功能#

當設定環境變數時,以下模組提供陣列 API 標準支援

scipy.special 中的以下函式提供支援: scipy.special.log_ndtrscipy.special.ndtrscipy.special.ndtriscipy.special.erfscipy.special.erfcscipy.special.i0scipy.special.i0escipy.special.i1scipy.special.i1escipy.special.gammalnscipy.special.gammaincscipy.special.gammainccscipy.special.logitscipy.special.expitscipy.special.entrscipy.special.rel_entrscipy.special.rel_entrscipy.special.xlogyscipy.special.chdtrc

scipy.stats 中的以下函式提供支援: scipy.stats.describescipy.stats.momentscipy.stats.skewscipy.stats.kurtosisscipy.stats.kstatscipy.stats.kstatvarscipy.stats.circmeanscipy.stats.circvarscipy.stats.circstdscipy.stats.entropyscipy.stats.variationscipy.stats.semscipy.stats.ttest_1sampscipy.stats.pearsonrscipy.stats.chisquarescipy.stats.skewtestscipy.stats.kurtosistestscipy.stats.normaltestscipy.stats.jarque_berascipy.stats.bartlettscipy.stats.power_divergencescipy.stats.monte_carlo_test

請參閱 追蹤問題 以取得更新。

實作筆記#

對陣列 API 標準的支援以及 Numpy、CuPy 和 PyTorch 的特定相容性函式的關鍵部分,是透過 array-api-compat 提供的。此套件透過 git submodule(在 scipy/_lib 下)包含在 SciPy 程式碼庫中,因此不會引入新的依賴項。

array-api-compat 提供通用實用函式,並新增別名,例如 xp.concat(在 NumPy 2.0 中新增 np.concat 之前,對於 numpy,這會對應到 np.concatenate)。這允許在 NumPy、PyTorch、CuPy 和 JAX 之間使用統一的 API(其他函式庫,例如 Dask,正在開發中)。

當未設定環境變數,且 SciPy 中的陣列 API 標準支援已停用時,我們仍然使用 NumPy 命名空間的包裝版本,即 array_api_compat.numpy。這不應變更 SciPy 函式的行為,因為它實際上是現有的 numpy 命名空間,新增了一些別名,並為陣列 API 標準支援修改/新增了一些函式。啟用支援後,xp = array_namespace(input) 將是與標準相容的命名空間,將輸入陣列類型與函式匹配(例如,如果 cluster.vq.kmeans 的輸入是 PyTorch 張量,則 xparray_api_compat.torch)。

將陣列 API 標準支援新增至 SciPy 函式#

在可能的情況下,新增至 SciPy 的新程式碼應盡可能遵循陣列 API 標準(這些函式通常也是 NumPy 用法的最佳實務慣例)。透過遵循標準,有效新增對陣列 API 標準的支援通常很簡單,理想情況下,我們不需要維護任何自訂項目。

scipy._lib._array_api 中提供了各種輔助函式 - 請參閱該模組中的 __all__ 以取得目前輔助函式的列表,並參閱其 docstring 以取得更多資訊。

若要為在 .py 檔案中定義的 SciPy 函式新增支援,您必須變更的是

  1. 輸入陣列驗證,

  2. 使用 xp 而非 np 函式,

  3. 在呼叫編譯後的程式碼時,先將陣列轉換為 NumPy 陣列,然後再轉換回輸入陣列類型。

輸入陣列驗證使用以下模式

xp = array_namespace(arr) # where arr is the input array
# alternatively, if there are multiple array inputs, include them all:
xp = array_namespace(arr1, arr2)

# replace np.asarray with xp.asarray
arr = xp.asarray(arr)
# uses of non-standard parameters of np.asarray can be replaced with _asarray
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

請注意,如果一個輸入是非 NumPy 陣列類型,則所有類陣列輸入都必須是該類型;嘗試將非 NumPy 陣列與列表、Python 純量或其他任意 Python 物件混合將引發例外。對於 NumPy 陣列,出於向後相容性的原因,這些類型將繼續被接受。

如果函式僅呼叫編譯後的程式碼一次,請使用以下模式

x = np.asarray(x)  # convert to numpy right before compiled call(s)
y = _call_compiled_code(x)
y = xp.asarray(y)  # convert back to original array type

如果多次呼叫編譯後的程式碼,請確保僅執行一次轉換,以避免過多的額外負荷。

以下是假設的公共 SciPy 函式 toto 的範例

def toto(a, b):
    a = np.asarray(a)
    b = np.asarray(b, copy=True)

    c = np.sum(a) - np.prod(b)

    # this is some C or Cython call
    d = cdist(c)

    return d

您將像這樣轉換它

def toto(a, b):
    xp = array_namespace(a, b)
    a = xp.asarray(a)
    b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

    c = xp.sum(a) - xp.prod(b)

    # this is some C or Cython call
    c = np.asarray(c)
    d = cdist(c)
    d = xp.asarray(d)

    return d

通過編譯後的程式碼需要返回 NumPy 陣列,因為 SciPy 的擴充模組僅適用於 NumPy 陣列(或 Cython 情況下的 memoryview)。對於 CPU 上的陣列,轉換應為零複製,而在 GPU 和其他裝置上,轉換嘗試將引發例外。原因是裝置之間靜默資料傳輸被認為是不良實務,因為它很可能是一個大型且難以偵測的效能瓶頸。

新增測試#

以下 pytest 標記可用

  • array_api_compatible -> xp:使用參數化在多個陣列後端上執行測試。

  • skip_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, exceptions=None):跳過某些後端或後端類別。@pytest.mark.usefixtures("skip_xp_backends") 必須與此標記一起使用才能套用跳過。請參閱 scipy.conftest 中 fixture 的 docstring,以取得有關如何使用此標記來跳過測試的資訊。

  • xfail_xp_backends(backend=None, reason=None, np_only=False, cpu_only=False, exceptions=None):xfail 某些後端或後端類別。@pytest.mark.usefixtures("xfail_xp_backends") 必須與此標記一起使用才能套用 xfail。請參閱 scipy.conftest 中 fixture 的 docstring,以取得有關如何使用此標記來 xfail 測試的資訊。

  • skip_xp_invalid_arg 用於跳過在啟用 SCIPY_ARRAY_API 時使用無效引數的測試。例如,scipy.stats 函式的一些測試將遮罩陣列傳遞給正在測試的函式,但遮罩陣列與陣列 API 不相容。skip_xp_invalid_arg 裝飾器的使用允許這些測試在未使用 SCIPY_ARRAY_API 時防止回歸,而不會在使用 SCIPY_ARRAY_API 時導致失敗。隨著時間的推移,我們希望這些函式在收到無效的陣列 API 輸入時發出棄用警告,並且此裝飾器將檢查是否發出了棄用警告,而不會導致測試失敗。當 SCIPY_ARRAY_API=1 行為成為預設且唯一的行為時,將移除這些測試(以及裝飾器本身)。

scipy._lib._array_api 包含與陣列無關的斷言,例如 xp_assert_close,可用於取代來自 numpy.testing 的斷言。

以下範例示範如何使用標記

from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
from scipy._lib._array_api import xp_assert_close
...
@pytest.mark.skip_xp_backends(np_only=True, reason='skip reason')
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto1(self, xp):
    a = xp.asarray([1, 2, 3])
    b = xp.asarray([0, 2, 5])
    xp_assert_close(toto(a, b), a)
...
@pytest.mark.skip_xp_backends('array_api_strict',
                              reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy',
                              reason='skip reason 2')
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto2(self, xp):
    ...
...
# Do not run when SCIPY_ARRAY_API is used
@skip_xp_invalid_arg
def test_toto_masked_array(self):
    ...

cpu_only=True 時,將自訂原因傳遞給 reason 不受支援,因為 cpu_only=True 可以與傳遞 backends 一起使用。此外,使用 cpu_only 的原因可能只是在正在測試的函式中使用了編譯後的程式碼。

將後端的名稱傳遞到 exceptions 表示它們不會被 cpu_only=True 跳過。當委派針對某些但非全部非 CPU 後端實作時,這非常有用,並且 CPU 程式碼路徑需要轉換為 NumPy 以用於編譯後的程式碼

# array-api-strict and CuPy will always be skipped, for the given reasons.
# All libraries using a non-CPU device will also be skipped, apart from
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
@pytest.mark.skip_xp_backends(cpu_only, exceptions=['jax.numpy'])
@pytest.mark.skip_xp_backends('array_api_strict', reason='skip reason 1')
@pytest.mark.skip_xp_backends('cupy', reason='skip reason 2')
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto(self, xp):
    ...

當檔案中的每個測試函式都已更新以實現陣列 API 相容性時,可以透過告知 pytest 使用 pytestmark 將標記套用至每個測試函式來降低冗長性

from scipy.conftest import array_api_compatible

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends
...
@skip_xp_backends(np_only=True, reason='skip reason')
def test_toto1(self, xp):
    ...

套用這些標記後,dev.py test 可以與新的選項 -b--array-api-backend 一起使用

python dev.py test -b numpy -b torch -s cluster

這會自動適當地設定 SCIPY_ARRAY_API。若要測試具有多個裝置且使用非預設裝置的函式庫,可以設定第二個環境變數(SCIPY_DEVICE,僅在測試套件中使用)。有效值取決於正在測試的陣列函式庫,例如,對於 PyTorch,有效值為 "cpu"、"cuda"、"mps"。若要使用 PyTorch MPS 後端執行測試套件,請使用:SCIPY_DEVICE=mps python dev.py test -b torch

請注意,有一個 GitHub Actions 工作流程使用 array-api-strict、PyTorch 和 JAX 在 CPU 上進行測試。

其他資訊#

以下是一些其他資源,這些資源激發了一些設計決策,並在開發階段提供了幫助

  • 包含一些討論的初始 PR

  • 從此 PR 快速開始,並從 scikit-learn 中獲得一些啟發。

  • PR 為 scikit-learn 新增陣列 API 支援

  • 其他一些相關的 scikit-learn PR:#22554#25956