scipy.cluster.vq.

kmeans2#

scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, rng=None)[source]#

使用 k-means 演算法將一組觀察值分類到 k 個群集中。

此演算法嘗試最小化觀察值和質心之間的歐幾里德距離。包含幾種初始化方法。

參數:
datandarray

一個 'M' 乘 'N' 的陣列,包含 'N' 維度的 'M' 個觀察值,或一個長度為 'M' 的陣列,包含 'M' 個 1 維觀察值。

kint 或 ndarray

要形成的群集數量以及要產生的質心數量。如果 minit 初始化字串為 'matrix',或者如果改為提供 ndarray,則將其解釋為要使用的初始群集。

iterint,選填

要執行的 k-means 演算法的迭代次數。請注意,這與 kmeans 函數的 iters 參數的含義不同。

threshfloat,選填

(尚未使用)

minitstr,選填

初始化方法。可用的方法有 'random'、'points'、'++' 和 'matrix'

'random':從高斯分佈產生 k 個質心,其平均值和變異數是從資料估計而來。

'points':從資料中隨機選擇 k 個觀察值(列)作為初始質心。

'++':根據 kmeans++ 方法選擇 k 個觀察值(小心播種)

'matrix':將 k 參數解釋為 k 乘 M(或 1 維資料的長度 k 陣列)的初始質心陣列。

missingstr,選填

處理空群集的方法。可用的方法有 'warn' 和 'raise'

'warn':發出警告並繼續。

'raise':引發 ClusterError 並終止演算法。

check_finitebool,選填

是否檢查輸入矩陣是否僅包含有限數字。停用可能會提高效能,但如果輸入包含無限值或 NaN,則可能會導致問題(崩潰、非終止)。預設值:True

rng{None, int, numpy.random.Generator},選填

如果 rng 是透過關鍵字傳遞,則 numpy.random.Generator 以外的類型會傳遞給 numpy.random.default_rng 以實例化 Generator。如果 rng 已經是 Generator 實例,則會使用提供的實例。指定 rng 以獲得可重複的函數行為。

如果此引數是透過位置傳遞,或 seed 是透過關鍵字傳遞,則適用於引數 seed 的舊版行為

  • 如果 seed 為 None(或 numpy.random),則會使用 numpy.random.RandomState 單例。

  • 如果 seed 是整數,則會使用新的 RandomState 實例,並以 seed 作為種子。

  • 如果 seed 已經是 GeneratorRandomState 實例,則會使用該實例。

在版本 1.15.0 中變更: 作為從使用 numpy.random.RandomState 過渡到 numpy.random.GeneratorSPEC-007 一部分,此關鍵字已從 seed 變更為 rng。在過渡期間,這兩個關鍵字將繼續運作,但一次只能指定一個。在過渡期之後,使用 seed 關鍵字的函數呼叫將發出警告。seedrng 的行為如上所述,但新程式碼中應僅使用 rng 關鍵字。

回傳值:
centroidndarray

一個 'k' 乘 'N' 的質心陣列,在 k-means 的最後一次迭代中找到。

labelndarray

label[i] 是第 i 個觀察值最接近的質心的代碼或索引。

另請參閱

kmeans

參考文獻

[1]

D. Arthur 和 S. Vassilvitskii,“k-means++: careful seeding 的優勢”,第十八屆 ACM-SIAM 離散演算法研討會論文集,2007 年。

範例

>>> from scipy.cluster.vq import kmeans2
>>> import matplotlib.pyplot as plt
>>> import numpy as np

建立 z,一個形狀為 (100, 2) 的陣列,其中包含來自三個多變量常態分佈的樣本混合。

>>> rng = np.random.default_rng()
>>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
>>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
>>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
>>> z = np.concatenate((a, b, c))
>>> rng.shuffle(z)

計算三個群集。

>>> centroid, label = kmeans2(z, 3, minit='points')
>>> centroid
array([[ 2.22274463, -0.61666946],  # may vary
       [ 0.54069047,  5.86541444],
       [ 6.73846769,  4.01991898]])

每個群集中有多少個點?

>>> counts = np.bincount(label)
>>> counts
array([29, 51, 20])  # may vary

繪製群集。

>>> w0 = z[label == 0]
>>> w1 = z[label == 1]
>>> w2 = z[label == 2]
>>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
>>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
>>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
>>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
>>> plt.axis('equal')
>>> plt.legend(shadow=True)
>>> plt.show()
../../_images/scipy-cluster-vq-kmeans2-1.png