kmeans2#
- scipy.cluster.vq.kmeans2(data, k, iter=10, thresh=1e-05, minit='random', missing='warn', check_finite=True, *, rng=None)[source]#
使用 k-means 演算法將一組觀察值分類到 k 個群集中。
此演算法嘗試最小化觀察值和質心之間的歐幾里德距離。包含幾種初始化方法。
- 參數:
- datandarray
一個 'M' 乘 'N' 的陣列,包含 'N' 維度的 'M' 個觀察值,或一個長度為 'M' 的陣列,包含 'M' 個 1 維觀察值。
- kint 或 ndarray
要形成的群集數量以及要產生的質心數量。如果 minit 初始化字串為 'matrix',或者如果改為提供 ndarray,則將其解釋為要使用的初始群集。
- iterint,選填
要執行的 k-means 演算法的迭代次數。請注意,這與 kmeans 函數的 iters 參數的含義不同。
- threshfloat,選填
(尚未使用)
- minitstr,選填
初始化方法。可用的方法有 'random'、'points'、'++' 和 'matrix'
'random':從高斯分佈產生 k 個質心,其平均值和變異數是從資料估計而來。
'points':從資料中隨機選擇 k 個觀察值(列)作為初始質心。
'++':根據 kmeans++ 方法選擇 k 個觀察值(小心播種)
'matrix':將 k 參數解釋為 k 乘 M(或 1 維資料的長度 k 陣列)的初始質心陣列。
- missingstr,選填
處理空群集的方法。可用的方法有 'warn' 和 'raise'
'warn':發出警告並繼續。
'raise':引發 ClusterError 並終止演算法。
- check_finitebool,選填
是否檢查輸入矩陣是否僅包含有限數字。停用可能會提高效能,但如果輸入包含無限值或 NaN,則可能會導致問題(崩潰、非終止)。預設值:True
- rng{None, int,
numpy.random.Generator
},選填 如果 rng 是透過關鍵字傳遞,則
numpy.random.Generator
以外的類型會傳遞給numpy.random.default_rng
以實例化Generator
。如果 rng 已經是Generator
實例,則會使用提供的實例。指定 rng 以獲得可重複的函數行為。如果此引數是透過位置傳遞,或 seed 是透過關鍵字傳遞,則適用於引數 seed 的舊版行為
如果 seed 為 None(或
numpy.random
),則會使用numpy.random.RandomState
單例。如果 seed 是整數,則會使用新的
RandomState
實例,並以 seed 作為種子。如果 seed 已經是
Generator
或RandomState
實例,則會使用該實例。
在版本 1.15.0 中變更: 作為從使用
numpy.random.RandomState
過渡到numpy.random.Generator
的 SPEC-007 一部分,此關鍵字已從 seed 變更為 rng。在過渡期間,這兩個關鍵字將繼續運作,但一次只能指定一個。在過渡期之後,使用 seed 關鍵字的函數呼叫將發出警告。seed 和 rng 的行為如上所述,但新程式碼中應僅使用 rng 關鍵字。
- 回傳值:
- centroidndarray
一個 'k' 乘 'N' 的質心陣列,在 k-means 的最後一次迭代中找到。
- labelndarray
label[i] 是第 i 個觀察值最接近的質心的代碼或索引。
另請參閱
參考文獻
[1]D. Arthur 和 S. Vassilvitskii,“k-means++: careful seeding 的優勢”,第十八屆 ACM-SIAM 離散演算法研討會論文集,2007 年。
範例
>>> from scipy.cluster.vq import kmeans2 >>> import matplotlib.pyplot as plt >>> import numpy as np
建立 z,一個形狀為 (100, 2) 的陣列,其中包含來自三個多變量常態分佈的樣本混合。
>>> rng = np.random.default_rng() >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45) >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30) >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25) >>> z = np.concatenate((a, b, c)) >>> rng.shuffle(z)
計算三個群集。
>>> centroid, label = kmeans2(z, 3, minit='points') >>> centroid array([[ 2.22274463, -0.61666946], # may vary [ 0.54069047, 5.86541444], [ 6.73846769, 4.01991898]])
每個群集中有多少個點?
>>> counts = np.bincount(label) >>> counts array([29, 51, 20]) # may vary
繪製群集。
>>> w0 = z[label == 0] >>> w1 = z[label == 1] >>> w2 = z[label == 2] >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0') >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1') >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2') >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids') >>> plt.axis('equal') >>> plt.legend(shadow=True) >>> plt.show()