scipy.stats.

wasserstein_distance_nd#

scipy.stats.wasserstein_distance_nd(u_values, v_values, u_weights=None, v_weights=None)[原始碼]#

計算兩個 N 維離散分佈之間的 Wasserstein-1 距離。

Wasserstein 距離,也稱為推土機距離或最佳傳輸距離,是兩個機率分佈之間的相似性度量 [1]。在離散情況下,Wasserstein 距離可以理解為將一個分佈轉換為另一個分佈的最佳傳輸計畫的成本。成本計算為移動的機率質量的大小與其移動距離的乘積。簡短而直觀的介紹可以在 [2] 中找到。

在 1.13.0 版本中新增。

參數:
u_values2 維類陣列

來自機率分佈的樣本,或是機率分佈的支撐(所有可能值的集合)。軸 0 上的每個元素都是一個觀察值或可能的值,軸 1 代表分佈的維度;也就是說,每一列都是一個向量觀察值或可能的值。

v_values2 維類陣列

來自第二個分佈的樣本或其支撐。

u_weights, v_weights1 維類陣列,選用

與樣本對應的權重或計數,或與支撐值對應的機率質量。元素總和必須為正數且有限。如果未指定,則每個值都分配相同的權重。

回傳值:
distance浮點數

分佈之間計算出的距離。

另請參閱

wasserstein_distance

計算兩個 1 維離散分佈之間的 Wasserstein-1 距離。

註解

給定兩個機率質量函數 \(u\)\(v\),使用歐幾里得範數的分佈之間的第一個 Wasserstein 距離為

\[l_1 (u, v) = \inf_{\pi \in \Gamma (u, v)} \int \| x-y \|_2 \mathrm{d} \pi (x, y)\]

其中 \(\Gamma (u, v)\)\(\mathbb{R}^n \times \mathbb{R}^n\) 上(機率)分佈的集合,其邊際分佈分別為第一個和第二個因子的 \(u\)\(v\)。對於給定值 \(x\)\(u(x)\) 給出位置 \(x\)\(u\) 的機率,\(v(x)\) 也是如此。

這也稱為最佳傳輸問題或蒙日問題。令有限點集 \(\{x_i\}\)\(\{y_j\}\) 分別表示機率質量函數 \(u\)\(v\) 的支撐集。蒙日問題可以表示如下,

\(\Gamma\) 表示傳輸計畫,\(D\) 表示距離矩陣,以及,

\[\begin{split}x = \text{vec}(\Gamma) \\ c = \text{vec}(D) \\ b = \begin{bmatrix} u\\ v\\ \end{bmatrix}\end{split}\]

\(\text{vec}()\) 函數表示向量化函數,該函數通過垂直堆疊矩陣的列將矩陣轉換為列向量。傳輸計畫 \(\Gamma\) 是一個矩陣 \([\gamma_{ij}]\),其中 \(\gamma_{ij}\) 是一個正值,表示從 \(u(x_i)\) 傳輸到 \(v(y_i)\) 的機率質量的大小。對 \(\Gamma\) 的行求和應給出源分佈 \(u\) : \(\sum_j \gamma_{ij} = u(x_i)\) 對所有 \(i\) 成立,對 \(\Gamma\) 的列求和應給出目標分佈 \(v\): \(\sum_i \gamma_{ij} = v(y_j)\) 對所有 \(j\) 成立。距離矩陣 \(D\) 是一個矩陣 \([d_{ij}]\),其中 \(d_{ij} = d(x_i, y_j)\)

給定 \(\Gamma\)\(D\)\(b\),蒙日問題可以通過將 \(A x = b\) 作為約束條件,將 \(z = c^T x\) 作為最小化目標(成本總和)轉換為線性規劃問題,其中矩陣 \(A\) 的形式為

\[ \begin{align}\begin{aligned}\begin{array} {rrrr|rrrr|r|rrrr} 1 & 1 & \dots & 1 & 0 & 0 & \dots & 0 & \dots & 0 & 0 & \dots & 0 \cr 0 & 0 & \dots & 0 & 1 & 1 & \dots & 1 & \dots & 0 & 0 &\dots & 0 \cr \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \cr 0 & 0 & \dots & 0 & 0 & 0 & \dots & 0 & \dots & 1 & 1 & \dots & 1 \cr \hline\\ 1 & 0 & \dots & 0 & 1 & 0 & \dots & \dots & \dots & 1 & 0 & \dots & 0 \cr 0 & 1 & \dots & 0 & 0 & 1 & \dots & \dots & \dots & 0 & 1 & \dots & 0 \cr \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \cr 0 & 0 & \dots & 1 & 0 & 0 & \dots & 1 & \dots & 0 & 0 & \dots & 1 \end{array}\end{aligned}\end{align} \]

通過求解上述線性規劃問題的對偶形式(解為 \(y^*\)),Wasserstein 距離 \(l_1 (u, v)\) 可以計算為 \(b^T y^*\)

上述解決方案的靈感來自 Vincent Herrmann 的部落格 [3]。如需更詳盡的解釋,請參閱 [4]

輸入分佈可以是經驗性的,因此來自樣本,樣本的值實際上是函數的輸入,或者它們可以被視為廣義函數,在這種情況下,它們是位於指定值的狄拉克 delta 函數的加權和。

參考文獻

[2]

Lili Weng, “What is Wasserstein distance?”, Lil’log, https://lilianweng.github.io/posts/2017-08-20-gan/#what-is-wasserstein-distance.

[3]

Hermann, Vincent. “Wasserstein GAN and the Kantorovich-Rubinstein Duality”. https://vincentherrmann.github.io/blog/wasserstein/.

[4]

Peyré, Gabriel, 和 Marco Cuturi。“Computational optimal transport.” Center for Research in Economics and Statistics Working Papers 2017-86 (2017).

範例

計算兩個三維樣本(每個樣本有兩個觀察值)之間的 Wasserstein 距離。

>>> from scipy.stats import wasserstein_distance_nd
>>> wasserstein_distance_nd([[0, 2, 3], [1, 2, 5]], [[3, 2, 3], [4, 2, 5]])
3.0

分別計算具有三個和兩個加權觀察值的兩個二維分佈之間的 Wasserstein 距離。

>>> wasserstein_distance_nd([[0, 2.75], [2, 209.3], [0, 0]],
...                      [[0.2, 0.322], [4.5, 25.1808]],
...                      [0.4, 5.2, 0.114], [0.8, 1.5])
174.15840245217169