scipy.special.huber#

scipy.special.huber(delta, r, out=None) = <ufunc 'huber'>#

Huber 損失函數。

\[\begin{split}\text{huber}(\delta, r) = \begin{cases} \infty & \delta < 0 \\ \frac{1}{2}r^2 & 0 \le \delta, | r | \le \delta \\ \delta ( |r| - \frac{1}{2}\delta ) & \text{otherwise} \end{cases}\end{split}\]
參數:
deltandarray

輸入陣列,表示二次損失與線性損失的變換點。

rndarray

輸入陣列,可能表示殘差。

outndarray, optional

函數值的可選輸出陣列

返回:
純量或 ndarray

計算出的 Huber 損失函數值。

參見

pseudo_huber

此函數的平滑近似

筆記

huber 作為穩健統計或機器學習中的損失函數很有用,可以減少離群值的影響,與常見的平方誤差損失相比,幅度高於 delta 的殘差不會被平方 [1]

通常,r 表示殘差,即模型預測值與資料之間的差異。那麼,對於 \(|r|\leq\delta\)huber 類似於平方誤差,而對於 \(|r|>\delta\),則類似於絕對誤差。這樣,Huber 損失通常在模型擬合中對於小殘差實現快速收斂,就像平方誤差損失函數一樣,並且仍然減少離群值的影響 (\(|r|>\delta\)),就像絕對誤差損失一樣。由於 \(\delta\) 是平方誤差和絕對誤差機制之間的截止點,因此必須針對每個問題仔細調整。huber 也是凸函數,使其適用於基於梯度的最佳化。

在 0.15.0 版本中新增。

參考文獻

[1]

Peter Huber. “Robust Estimation of a Location Parameter”, 1964. Annals of Statistics. 53 (1): 73 - 101.

範例

匯入所有必要的模組。

>>> import numpy as np
>>> from scipy.special import huber
>>> import matplotlib.pyplot as plt

計算 delta=1r=2 時的函數值

>>> huber(1., 2.)
1.5

通過為 delta 提供 NumPy 陣列或列表來計算不同 delta 的函數。

>>> huber([1., 3., 5.], 4.)
array([3.5, 7.5, 8. ])

通過為 r 提供 NumPy 陣列或列表來計算不同點的函數。

>>> huber(2., np.array([1., 1.5, 3.]))
array([0.5  , 1.125, 4.   ])

可以通過為 deltar 提供具有相容廣播形狀的陣列來計算不同 deltar 的函數。

>>> r = np.array([1., 2.5, 8., 10.])
>>> deltas = np.array([[1.], [5.], [9.]])
>>> print(r.shape, deltas.shape)
(4,) (3, 1)
>>> huber(deltas, r)
array([[ 0.5  ,  2.   ,  7.5  ,  9.5  ],
       [ 0.5  ,  3.125, 27.5  , 37.5  ],
       [ 0.5  ,  3.125, 32.   , 49.5  ]])

繪製不同 delta 的函數圖。

>>> x = np.linspace(-4, 4, 500)
>>> deltas = [1, 2, 3]
>>> linestyles = ["dashed", "dotted", "dashdot"]
>>> fig, ax = plt.subplots()
>>> combined_plot_parameters = list(zip(deltas, linestyles))
>>> for delta, style in combined_plot_parameters:
...     ax.plot(x, huber(delta, x), label=fr"$\delta={delta}$", ls=style)
>>> ax.legend(loc="upper center")
>>> ax.set_xlabel("$x$")
>>> ax.set_title(r"Huber loss function $h_{\delta}(x)$")
>>> ax.set_xlim(-4, 4)
>>> ax.set_ylim(0, 8)
>>> plt.show()
../../_images/scipy-special-huber-1.png