Machine Learning

K-Means Clustering - wcss 로 기준점 찾기

yugyeong 2022. 12. 5. 12:24

 

K-Means Clustering

 

k-평균 알고리즘(K-means clustering algorithm)은 주어진 데이터를 k개의 클러스터로 묶는 알고리즘이다.

 

 

 

아래의 데이터프레임을 가지고 kmeans 를 이용하여 비슷한 군집끼리 묶을 것이다.

 

수입과 지출 데이터가 들어있는'Annual Income (k$)' 과 'Spending Score (1-100)' 컬럼을 변수 X 로 두었다.

 

 

가장 적합한 기준점을 찾기 위해 wcss 라는 리스트를 만들고, 반복문을 이용하여서 인공지능 Kmeans 에 훈련한 값들을 wcss 리스트에 넣어줄 것이다.

 

from sklearn.cluster import KMeans

 

wcss = []
for k in np.arange(1,10+1):
    Kmeans = KMeans(n_clusters= k, random_state= 5)
    Kmeans.fit(X)
    wcss.append(Kmeans.inertia_)

 

 

wcss 리스트에는 아래의 값들이 담겼다.

이제 각 클러스터의 갯수마다 구한 wcss 값을 차트로 나타낸다. => 이것을 엘보우 메소드라고 한다.

 

 

 

기준점은 특정 점 전에 급격하게 감소되고 특정점 전에 완만하게 감소한 점을 찾으면 된다.

아래의 그래프에서는 5 이다.

5가 클러스트의 갯수이다.

 

x = np.arange(1,10+1)
plt.plot( x, wcss )
plt.title('The Elbow Method')
plt.xlabel('Number of Clusters')
plt.ylabel('WCSS')
plt.show()

 

 

Kmeans() 괄호 안에 위에서 찾아낸 클러스트 군짓 갯수를 넣어주고 random_state 값을 넣어주어서 kmeans 변수에 저장을 하였다.

kmeans 로 X 의 값을 예측하도록 하고 변수 y_pred 저장을 한 후 확인을 해보면 아래와같은 결과가 나온다.

 

데이터들은 총 다섯개의 군집으로 묶였다.

 

kmeans = KMeans(n_clusters= 5, random_state= 5)

y_pred = kmeans.fit_predict(X)

y_pred

 

 

군집으로 묶은 데이터들로 차트를 그려보면 아래처럼 결과가 나온다.

 

import seaborn as 

sb.scatterplot(data= df, x='Annual Income (k$)', y= 'Spending Score (1-100)' )
plt.show()

 

 

 

plt.figure(figsize=[12,8])
plt.scatter(X.values[y_pred == 0, 0], X.values[y_pred == 0, 1], s = 100, c = 'red', label = 'Cluster 1')
plt.scatter(X.values[y_pred == 1, 0], X.values[y_pred == 1, 1], s = 100, c = 'blue', label = 'Cluster 2')
plt.scatter(X.values[y_pred == 2, 0], X.values[y_pred == 2, 1], s = 100, c = 'green', label = 'Cluster 3')
plt.scatter(X.values[y_pred == 3, 0], X.values[y_pred == 3, 1], s = 100, c = 'cyan', label = 'Cluster 4')
plt.scatter(X.values[y_pred == 4, 0], X.values[y_pred == 4, 1], s = 100, c = 'magenta', label = 'Cluster 5')
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], s = 300, c = 'yellow', label = 'Centroids')
plt.title('Clusters of customers')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.legend()
plt.show()