```python
%run kmeans  # IPython magic: executes kmeans.py in the notebook namespace
import pandas as pd
import numpy as np
```
When I started pursuing a career in data science, I went online and read up on a number of the most popular models. One that came up consistently was Kmeans, so after researching for a while I decided to implement a simple Kmeans model using sklearn. After completing this project I believed I understood Kmeans pretty well, but that confidence was quickly dashed during one of my early interviews when I was asked to describe the model in detail. At the time, I understood Kmeans to be an unsupervised clustering method where you used the mean of data points to determine clusters. While this explanation might be adequate for a data analyst, it leaves a lot to be desired for a data scientist. So to help everyone out there who is starting down the path of becoming a data scientist, here is a more detailed explanation of Kmeans and Kmeans++.
First, we need to understand what clustering is, since Kmeans belongs to this family of machine learning models. The goal of a clustering algorithm is to group or "cluster" similar observations into buckets. To the untrained eye this might sound the same as classification, but there are a number of key distinctions. The most important difference is that classification is commonly a supervised learning problem, whereas clustering is an unsupervised learning problem. Without labels to learn from, clustering doesn't try to predict a label for each observation; it only identifies which observations "look similar" to one another. A good example is the difference between Kmeans, which is clustering, and K nearest neighbors (KNN), which is classification. Both algorithms use a distance metric to group or classify observations, but KNN is given the target variable (the labels), while Kmeans groups observations using only their independent variables (features).
P.S.
As a side note, the other major difference is what "k" refers to. In Kmeans, "k" refers to the number of clusters. In KNN, "k" is the number of closest observations used to vote on what the label should be.
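To make the two meanings of "k" concrete, here is a minimal NumPy sketch. The helper names `knn_predict` and `assign_clusters` are hypothetical, not from the original notebook: KNN needs the labels `y` and lets the k nearest labelled observations vote, while the Kmeans assignment step only needs the features and k centroids.

```python
import numpy as np

def knn_predict(X_train, y_train, x, k):
    """KNN: k = number of labelled neighbours that vote on x's label."""
    dists = np.linalg.norm(X_train - x, axis=1)
    nearest = np.argsort(dists)[:k]  # indices of the k closest observations
    labels, counts = np.unique(y_train[nearest], return_counts=True)
    return labels[np.argmax(counts)]  # majority vote

def assign_clusters(X, centroids):
    """Kmeans assignment step: k = number of centroids; no labels involved."""
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)  # index of the closest centroid for each row

X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]])
y = np.array(["a", "a", "b", "b"])

print(knn_predict(X, y, np.array([0.5, 0.5]), k=3))            # → a
print(assign_clusters(X, np.array([[0.0, 0.5], [5.5, 5.0]])))  # → [0 0 1 1]
```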
Kmeans starts by picking K centroids, which at first are just random points. Each observation is assigned to the centroid closest to it, and then each centroid is recomputed as the average, or "mean", of the observations assigned to it. We repeat these two steps until the clusters stop changing when we recompute them. Ok, let's break that process down step by step.
First, we need K starting points, or centroids. One option is to generate random starting centroids, drawing one value for each independent variable from the distribution of that column; another common option, and the approach the code in this post takes, is to pick K random observations from the dataset. Once we have K centroids, we iterate over every observation in the dataset and calculate the distance between that observation and each of the centroids. Distance can be measured in several ways depending on preference and on what your data looks like, though standard Kmeans uses Euclidean distance. We need these distances to determine which centroid is closest to the observation, and we then put the observation in the cluster associated with that centroid. Once every observation has been assigned to a cluster, we recalculate each centroid as the mean of all the points in its cluster. Armed with the recalculated centroids, we repeat the entire process again and again. The algorithm finishes when the recalculated centroids are the same as (or, in practice, within a small tolerance of) the old centroids.
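As a small illustration of why the choice of distance matters, this sketch (with hypothetical helpers `euclidean`, `manhattan` and `closest_centroid`) finds the nearest centroid for one observation under two metrics, and the two answers differ. One caveat: the mean is the point that minimizes the sum of squared Euclidean distances, which is why Euclidean distance is the standard pairing for Kmeans; with Manhattan distance the coordinate-wise median is the natural center instead (the k-medians variant).

```python
import numpy as np

def euclidean(a, b):
    return np.sqrt(np.sum((a - b) ** 2))

def manhattan(a, b):
    return np.sum(np.abs(a - b))

def closest_centroid(x, centroids, dist):
    """Index of the centroid nearest to x under the given distance function."""
    return int(np.argmin([dist(x, c) for c in centroids]))

x = np.array([0.0, 0.0])
centroids = np.array([[3.0, 0.0],   # Euclidean 3.0,   Manhattan 3.0
                      [2.0, 2.0]])  # Euclidean ~2.83, Manhattan 4.0

print(closest_centroid(x, centroids, euclidean))  # → 1
print(closest_centroid(x, centroids, manhattan))  # → 0
```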
```python
import numpy as np

def kmeans(X: np.ndarray, k: int, tolerance=1e-2, max_iter=300):
    X = np.asarray(X, dtype=float)  # shape (n_samples, n_features)
    # start from k distinct observations chosen at random
    centroids = X[np.random.choice(X.shape[0], k, replace=False), :]
    for _ in range(max_iter):
        centroid_dict = {count: [] for count in range(k)}
        for point_idx, x in enumerate(X):
            closest_dist = np.inf  # any real distance is smaller
            closest_idx = 0
            for centroid_idx, c in enumerate(centroids):  # which centroid point is closest to x?
                dist = np.linalg.norm(x - c)
                if dist < closest_dist:
                    closest_dist = dist
                    closest_idx = centroid_idx
            centroid_dict[closest_idx].append(point_idx)  # assign x to the closest centroid
        # recalculate each centroid as the mean of its cluster; an empty cluster keeps its old centroid
        new_centroids = np.array([X[centroid_dict[c]].mean(axis=0) if centroid_dict[c] else centroids[c]
                                  for c in centroid_dict])
        # tolerance is how far the centroids may still move for us to call the loop converged
        centroid_shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if centroid_shift <= tolerance:
            break
    clusters = [centroid_dict[key] for key in centroid_dict]
    return centroids, clusters
```
There are a couple of ways I could optimize this function, particularly by using numpy matrix operations to calculate the distances and centroids instead of nested for loops, but I thought it would be easier to see how the algorithm works with a more naive approach.
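For comparison, here is a sketch of what a vectorized version might look like, assuming `X` is a 2-D array of shape `(n_samples, n_features)`. The name `kmeans_vectorized` and its `max_iter`, `tol` and `seed` parameters are illustrative additions, not part of the original code. Broadcasting `X[:, None, :] - centroids[None, :, :]` builds every point-to-centroid difference at once, replacing the two inner loops.

```python
import numpy as np

def kmeans_vectorized(X, k, max_iter=300, tol=1e-6, seed=None):
    """Kmeans with NumPy broadcasting instead of nested Python loops.

    X: (n_samples, n_features). Returns (centroids, labels).
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    # start from k distinct observations chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # (n, k) matrix of squared distances from every point to every centroid
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # mean of each cluster; an empty cluster keeps its previous centroid
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift <= tol:
            break
    return centroids, labels

X = np.array([[0.0], [1.0], [10.0], [11.0]])
centroids, labels = kmeans_vectorized(X, k=2, seed=0)
```

On these four points it should settle on centers 0.5 and 10.5 whichever two points it starts from (their order depends on the random start). The trade-off is memory: the intermediate difference array has n_samples × k × n_features entries, which is usually fine but worth keeping in mind for very large datasets.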
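Kmeans++ is a more expensive but improved way of choosing the starting centroids for Kmeans. The idea is that centroids which start far away from each other tend to converge to better clusters. To do this, we choose one random observation as the first centroid and then look for observations that are far from every centroid chosen so far. Strictly speaking, Kmeans++ picks each new centroid at random, with probability proportional to the squared distance from an observation to its nearest existing centroid, so distant observations are the most likely (but not certain) to be picked. The `kmeans_pp` function below uses a simpler deterministic version of this rule: it always takes the observation whose distance to its nearest existing centroid is largest (often called farthest-first initialization), which is easy to follow but sensitive to outliers. This loop continues until we have K centroids, and then we start the normal Kmeans algorithm.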
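```python
import numpy as np

def kmeans_pp(X, k):
    X = np.asarray(X, dtype=float)
    # first centroid: one observation chosen uniformly at random
    centroid = X[np.random.choice(X.shape[0], 1, replace=False), :]
    while len(centroid) < k:
        all_dist = []
        for c in centroid:  # distance from every observation to this centroid
            obs_dist = []
            for x in X:
                dist = np.linalg.norm(x - c)
                obs_dist.append(dist)
            all_dist.append(obs_dist)
        # for each observation, the distance to its nearest chosen centroid
        all_dist = np.array(all_dist).min(axis=0)
        # greedy (farthest-first) choice: the observation farthest from every chosen centroid
        new_centroid = np.argmax(all_dist)
        centroid = np.vstack([centroid, X[new_centroid]])
    return centroid
```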
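Because `kmeans_pp` above always takes the single farthest observation, it can be pulled toward outliers. Kmeans++ as originally described samples instead. Here is a minimal sketch of that D² sampling rule; the name `kmeans_plusplus_init` and the `seed` parameter are illustrative, not from the original notebook.

```python
import numpy as np

def kmeans_plusplus_init(X, k, seed=None):
    """Kmeans++ seeding (D^2 sampling).

    Each new centroid is drawn at random with probability proportional to the
    squared distance from an observation to its nearest already-chosen centroid.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    centroids = [X[rng.integers(len(X))]]  # first centroid: uniform at random
    for _ in range(1, k):
        chosen = np.array(centroids)
        # squared distance from every point to its closest chosen centroid
        d2 = ((X[:, None, :] - chosen[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        total = d2.sum()
        if total == 0:  # every point already coincides with a centroid
            idx = rng.integers(len(X))
        else:
            idx = rng.choice(len(X), p=d2 / total)
        centroids.append(X[idx])
    return np.array(centroids)

X = np.array([[0.0], [1.0], [10.0], [11.0]])
init = kmeans_plusplus_init(X, 2, seed=0)
```

It returns a `(k, n_features)` array of observations, the same shape as `kmeans_pp(X, k)`, so it could be swapped in at the same call site. An already-chosen observation has zero distance to itself, so it cannot be drawn twice unless every point is already a centroid.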
```python
import numpy as np

def kmeans_with_kmeansPP(X: np.ndarray, k: int, centroids=None, tolerance=1e-2, max_iter=300):
    X = np.asarray(X, dtype=float)  # shape (n_samples, n_features)
    if isinstance(centroids, str) and centroids == 'kmeans++':
        centroids = kmeans_pp(X, k)  # spread-out starting centroids
    else:
        centroids = X[np.random.choice(X.shape[0], k, replace=False), :]  # k random observations
    for _ in range(max_iter):
        centroid_dict = {count: [] for count in range(k)}
        for point_idx, x in enumerate(X):
            closest_dist = np.inf  # any real distance is smaller
            closest_idx = 0
            for centroid_idx, c in enumerate(centroids):  # which centroid point is closest to x?
                dist = np.linalg.norm(x - c)
                if dist < closest_dist:
                    closest_dist = dist
                    closest_idx = centroid_idx
            centroid_dict[closest_idx].append(point_idx)  # assign x to the closest centroid
        # recalculate centroids; an empty cluster keeps its previous centroid
        new_centroids = np.array([X[centroid_dict[c]].mean(axis=0) if centroid_dict[c] else centroids[c]
                                  for c in centroid_dict])
        # stop once the centroids have (almost) stopped moving
        centroid_shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if centroid_shift <= tolerance:
            break
    clusters = [centroid_dict[key] for key in centroid_dict]
    return centroids, clusters
```
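To check the claim that better starting centroids lead to better clusters, you need a score. Kmeans minimizes the within-cluster sum of squared distances (often called inertia), so a small helper like the hypothetical `inertia` below can compare the centroids from different runs, for example random starts versus `centroids='kmeans++'`. For a fixed k, lower is better.

```python
import numpy as np

def inertia(X, centroids):
    """Sum of squared distances from each observation to its nearest centroid."""
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return float(d2.min(axis=1).sum())

X = np.array([[0.0], [1.0], [10.0], [11.0]])
print(inertia(X, [[0.5], [10.5]]))  # → 1.0
print(inertia(X, [[0.0], [7.0]]))   # → 26.0
```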
One application of clustering is compressing information by keeping only the clusters instead of all the raw observations. This does mean losing some of the fine detail, but sometimes it is necessary in order to process huge amounts of disorganized data. A nice way to show how clustering compresses information is to write some code that compresses an image by clustering its colors; here the image is grayscale, so each pixel's "color" is a single brightness value.
```python
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np

img = Image.open('../cat_1.jpg')
img = img.convert("L")  # grayscale: one 8-bit brightness value per pixel
img.show()
# 2-D array of pixel brightness values, shape (height, width)
my_img_arr = np.asarray(img, dtype=float)
A_new = my_img_arr.copy()
print(A_new)
```
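```python
def is_closest(pixels, centroids):
    # replace each pixel (a scalar or an array) with the nearest centroid grey level
    levels = np.asarray(centroids, dtype=float).ravel()
    idx = np.abs(np.asarray(pixels, dtype=float)[..., None] - levels).argmin(axis=-1)
    return levels[idx]

def compress_img(k, img_matrix, original_image, sample_size=2000):
    # cluster grey levels rather than image rows: each pixel is a 1-D observation
    pixels = np.asarray(img_matrix, dtype=float).reshape(-1, 1)
    # sample_size is an added knob: fit on a random sample, since the naive loops are slow on every pixel
    sample = pixels[np.random.choice(len(pixels), min(sample_size, len(pixels)), replace=False)]
    centroids, clusters = kmeans_with_kmeansPP(sample, k=k, centroids='kmeans++', tolerance=.01)
    A_old = is_closest(original_image, centroids)  # map every pixel to its cluster's grey level
    plt.imshow(A_old, cmap='gray')
    plt.title(f'k = {k}')
    plt.show()

compress_img(32, A_new, my_img_arr)
compress_img(16, A_new, my_img_arr)
compress_img(8, A_new, my_img_arr)
compress_img(4, A_new, my_img_arr)
compress_img(2, A_new, my_img_arr)
```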
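To put a rough number on the compression: an 8-bit grayscale ("L" mode) image spends 8 bits per pixel, while a k-level image only has to record which of the k grey levels each pixel uses, about log2(k) bits. This back-of-the-envelope sketch uses hypothetical helpers and assumes fixed-width indices and k >= 2, ignoring the small cost of storing the k levels and any further compression a file format would add. It works the numbers out for the values of k used above.

```python
import math

def bits_per_pixel(k):
    """Bits needed to store one cluster index when there are k >= 2 grey levels."""
    return math.ceil(math.log2(k))

def compression_ratio(k, original_bits=8):
    # ignores the tiny cost of storing the k grey levels themselves
    return original_bits / bits_per_pixel(k)

for k in (32, 16, 8, 4, 2):
    print(k, bits_per_pixel(k), compression_ratio(k))
```

Under these assumptions k = 2 gives roughly an 8× reduction and k = 32 about 1.6×, which matches the visible trade-off: fewer grey levels, smaller image, less detail.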