Using K-Means Clustering for Image Segmentation

Including Python tutorial using cv2

Cierra Andaur
5 min read · Jan 20, 2021
Photo by Sander Boot on Unsplash

Have you ever wondered how a self-driving car can “see” well enough to drive itself? Part of the answer is image segmentation, which splits a camera image into meaningful regions, and clustering is one of the simplest ways to perform it. Let’s break down some of the concepts that make this possible.


First, some terminology:

Clustering is a technique for grouping data points with similar characteristics so that natural groups in the data can be identified. This can be useful for data analysis, recommender systems, search engines, spam filters, and image segmentation, just to name a few.

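A centroid is the point at the center of a cluster, typically the mean of the observations assigned to it. It does not have to be one of the data points itself.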

K-Means is a clustering method that aims to partition observations into k clusters, in which each observation belongs to the cluster with the nearest mean (centroid). The diagram below portrays the process:

  1. Randomly initialize k centroids (in this example, k = 4).
  2. Assign each observation to its closest centroid to form a cluster.
  3. Update the centroids using the mean of each cluster (i.e. each centroid moves to the center of its cluster).
  4. Repeat steps 2 and 3 until the centroids stabilize (stop moving).

You can also interact and play around with this process here.
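To make the four steps concrete, here is a minimal NumPy sketch of K-Means (often called Lloyd's algorithm). It is an illustration under simple assumptions rather than a replacement for a library implementation: the function name `lloyd_k_means`, the random initialization from existing data points, and the `max_iters` cap are all choices made for this example.

```python
import numpy as np

def lloyd_k_means(X, k, max_iters=100, seed=0):
    """Minimal sketch of the four K-Means steps (illustrative, not production code)."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct observations at random as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(max_iters):
        # Step 2: assign each observation to its closest centroid (Euclidean distance)
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned observations
        # (a centroid that ends up with no observations stays where it is)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 4: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Hypothetical toy data: two loose groups of 2-D points
X = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
              [8.0, 8.0], [8.2, 7.9], [7.8, 8.1]])
labels, centroids = lloyd_k_means(X, k=2)
print(labels)
print(centroids)
```

Because step 1 is random, different seeds can converge to different final clusters. Libraries such as scikit-learn reduce this risk with smarter initialization (k-means++) and several restarts.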

To implement K-Means in Python, we can use scikit-learn’s KMeans class and specify the number of clusters with the n_clusters parameter.

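from sklearn.cluster import KMeans

# your_dataframe is a placeholder for your own numeric data,
# e.g. a DataFrame or array of shape (n_samples, n_features)
# n_init = number of random restarts; random_state makes the result reproducible
k_means = KMeans(n_clusters=3, n_init=10, random_state=42)

k_means.fit(your_dataframe)

# index of the nearest centroid for every row
cluster_assignments = k_means.predict(your_dataframe)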
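A small end-to-end run may help. The toy points below are made up for illustration; `fit_predict` is scikit-learn’s shortcut for calling `fit` and then `predict` on the same data, and `cluster_centers_` holds the learned centroids.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical toy data: two obvious groups of 2-D points
X = np.array([[1.0, 1.0], [1.5, 2.0], [1.0, 1.5],
              [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])

k_means = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = k_means.fit_predict(X)   # equivalent to fit(X) followed by predict(X)

print(labels)                      # cluster index of each row
print(k_means.cluster_centers_)    # one centroid (mean point) per cluster

# New points are assigned to whichever learned centroid is nearest
print(k_means.predict([[0.5, 1.0], [9.0, 9.0]]))
```

Note that the cluster indices themselves (0 or 1) are arbitrary; only which rows share an index is meaningful.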

Practical applications of K-Means Clustering

One of the most relatable everyday examples of grouping data is what our email service does with spam. Gmail has an algorithm that decides whether an email belongs in your inbox or should go straight to the spam folder. Voilà, our two groups: probably-spam and probably-not-spam. The algorithm analyzes the characteristics of each email, and if it ticks enough of the “spam” boxes, it is filtered into your spam folder. You can even train your personal filter by marking emails it missed as spam yourself, which updates the criteria for what gets classified as spam. (Because it learns from labels like these, a spam filter is usually described as supervised classification rather than clustering, but the idea of sorting items into groups is similar.)

We can also use K-Means for quality inspection on an assembly line. Let’s say a manufacturing company makes a product that must be inspected to determine whether it is defective (again we have two groups: defective and normal). Humans could do this inspection and labeling, but that route can be costly and slow. An alternative is to have an algorithm look at images of the products and label each one as defective or normal. Think of each photo as unlabeled data that the model will then label as either defective or normal. This task is called anomaly detection, and clustering is one unsupervised way to approach it.

Unsupervised learning examines a dataset with no labels and searches for patterns. In our manufacturing example, the model processes the characteristics of each photo. When a photo comes along with noticeably different characteristics, it gets clustered into a different group: defective. In other words, the algorithm works on unlabeled data without needing a human to label each photo.
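One common way to turn clustering into a defect detector is to fit K-Means on examples of normal products and flag anything that lands unusually far from every centroid. The sketch below assumes each photo has already been reduced to a small feature vector (the random numbers stand in for real features), and the threshold rule, the largest distance seen among normal photos, is an assumption chosen for simplicity; a real inspection system would need to tune it.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Hypothetical features for 200 photos of normal products (two numbers per photo)
normal = rng.normal(loc=[0.0, 0.0], scale=0.1, size=(200, 2))

model = KMeans(n_clusters=3, n_init=10, random_state=0).fit(normal)

# transform() gives the distance from each sample to every centroid;
# keep only the distance to the nearest one
train_dist = model.transform(normal).min(axis=1)
threshold = train_dist.max()  # assumed rule: farther than any normal photo => suspicious

new_photos = np.array([[0.0, 0.0],    # resembles the normal products
                       [2.0, 2.0]])   # very different characteristics
new_dist = model.transform(new_photos).min(axis=1)
is_defective = new_dist > threshold
print(is_defective)
```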

Image Segmentation

One application of clustering is image segmentation, which may be used in object detection and tracking systems. Segmentation partitions an image into regions (segments) of pixels that share some property, turning it into a more meaningful representation that a machine can interpret. There are several types of image segmentation.

In semantic segmentation, pixels that are part of the same object type are assigned to the same segment. For example, when we talk about AI in a self-driving car such as a Tesla, one of the techniques at work is semantic segmentation. Pedestrians all fall into one segment, traffic lights into another, cars into another, and so on. The car’s software detects road layout, static infrastructure, and 3D objects using neural networks trained to perform semantic segmentation. Cool!

Another type is instance segmentation, which assigns each individual object its own segment. For example, each pedestrian would be a separate segment.
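The difference between the two is easy to see on a toy mask. The sketch below does not use a neural network the way a production system would; it simply applies connected-component labeling (`scipy.ndimage.label`) to a hypothetical semantic mask, so that every separate blob of “pedestrian” pixels receives its own instance ID.

```python
import numpy as np
from scipy import ndimage

# Hypothetical semantic mask: 1 = "pedestrian" pixels, 0 = everything else.
# Semantically, both blobs are simply "pedestrian".
semantic = np.array([
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 1],
])

# Connected-component labeling gives each separate blob its own ID (1, 2, ...)
instances, num_instances = ndimage.label(semantic)
print(num_instances)  # 2
print(instances)
```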

Let’s start with something a little simpler: color segmentation, which groups pixels of similar color into the same segment. A practical application of color segmentation is analyzing satellite images. For example, we may want to measure how much forest or desert there is in an area, or how large a body of water is.
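Measuring area comes down to counting pixels per segment. The sketch below uses a hypothetical 4x4 tile with made-up “water” and “forest” colors, clusters the pixel colors with scikit-learn’s `KMeans`, and reports the share of the tile each segment covers. With real imagery you would also need to know the ground area each pixel represents.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical 4x4 RGB tile: first column is water, the rest is forest
tile = np.zeros((4, 4, 3), dtype=np.uint8)
tile[:, :1] = [30, 60, 200]    # water (blue-ish)
tile[:, 1:] = [40, 140, 50]    # forest (green-ish)
tile[0, 3] = [45, 135, 55]     # a slightly different shade of forest

pixels = tile.reshape(-1, 3).astype(np.float32)   # one row per pixel
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(pixels)

# Fraction of the tile covered by each color segment
fractions = np.bincount(labels, minlength=2) / labels.size
print(fractions)
```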

Let’s play with some color segmentation!

Read in your image and check out the shape.

import cv2
import numpy as np

img = cv2.imread("ladybug.jpg")  # OpenCV loads images in BGR channel order
img_convert = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB copy used for clustering and plotting
img.shape  # (height, width, 3)

Convert the 3D image array into a 2D matrix in which each row is one pixel’s RGB vector, then convert it to float32, the data type cv2.kmeans() expects.

vectorized = img_convert.reshape((-1,3))  # shape: (height * width, 3)
vectorized = np.float32(vectorized)
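A tiny made-up array shows what the reshape does: each pixel becomes one row of three channel values, and reshaping with the original shape puts the image back together, which is what the segmentation step relies on later. The array `tiny_img` is purely illustrative.

```python
import numpy as np

# A tiny hypothetical 2x3 RGB image (height 2, width 3, 3 color channels)
tiny_img = np.arange(18, dtype=np.uint8).reshape(2, 3, 3)

vectorized = np.float32(tiny_img.reshape((-1, 3)))  # one row per pixel
print(vectorized.shape)   # (6, 3)
print(vectorized[0])      # the top-left pixel's three channel values

# Reshaping back with the original shape restores the image layout
restored = vectorized.reshape(tiny_img.shape).astype(np.uint8)
print(np.array_equal(restored, tiny_img))  # True
```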

Experiment with clusters!

# stop after 10 iterations or once the centroids move less than epsilon = 1.0
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
attempts = 10  # run k-means 10 times from different starts and keep the most compact result

def segment(K):
    # cluster the pixel colors into K groups, seeding the centroids with k-means++
    ret, label, center = cv2.kmeans(vectorized, K, None, criteria, attempts, cv2.KMEANS_PP_CENTERS)
    center = np.uint8(center)          # centroid colors back to 8-bit values
    res = center[label.flatten()]      # repaint each pixel with its cluster's centroid color
    return res.reshape(img_convert.shape)

#image 2
result_image1 = segment(10)
#image 3
result_image2 = segment(4)
#image 4
result_image3 = segment(2)
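The same segmentation can be sketched with scikit-learn’s `KMeans`, introduced earlier in the post, instead of `cv2.kmeans()`. The helper name `segment_colors` and the random test image are illustrative assumptions; the logic mirrors the OpenCV version: cluster the pixel colors, then repaint every pixel with its centroid’s color.

```python
import numpy as np
from sklearn.cluster import KMeans

def segment_colors(image_rgb, K, seed=0):
    """Replace every pixel with the color of its K-Means centroid (scikit-learn version)."""
    pixels = image_rgb.reshape((-1, 3)).astype(np.float32)
    model = KMeans(n_clusters=K, n_init=10, random_state=seed).fit(pixels)
    # cluster_centers_ holds one mean color per cluster; labels_ holds each pixel's cluster
    centers = np.uint8(model.cluster_centers_)
    return centers[model.labels_].reshape(image_rgb.shape)

# Hypothetical random 8x8 RGB image standing in for a real photo
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)
segmented = segment_colors(image, K=4)

# The segmented image keeps its shape but uses at most K distinct colors
print(segmented.shape, len(np.unique(segmented.reshape(-1, 3), axis=0)))
```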

Plot it out

import matplotlib.pyplot as plt

figure_size = 10
plt.figure(figsize=(figure_size,figure_size))
#original image (RGB, so matplotlib shows the true colors)
plt.subplot(2,2,1),plt.imshow(img_convert)
plt.title('Original Image'), plt.xticks([]), plt.yticks([])
#image 2
plt.subplot(2,2,2),plt.imshow(result_image1)
plt.title('Segmented Image when K = 10'), plt.xticks([]), plt.yticks([])
#image 3
plt.subplot(2,2,3),plt.imshow(result_image2)
plt.title('Segmented Image when K = 4'), plt.xticks([]), plt.yticks([])
#image 4
plt.subplot(2,2,4),plt.imshow(result_image3)
plt.title('Segmented Image when K = 2'), plt.xticks([]), plt.yticks([])
plt.show()

You can see that 10 clusters works well for this image because there are only a few shades of each color.

For a more complicated image, the algorithm will need more clusters to pick out the ladybug. Take, for example, the image below, which has a lot more going on in the background.

This image needs closer to 20 clusters in order to pick out the ladybug. The number of clusters will depend on your business problem. How much of the detail do you need to pick out?
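If you want a number to go with that judgment call, one rough option (not part of the original walkthrough) is to watch how the K-Means inertia, the sum of squared distances from each pixel to its centroid, falls as K grows. The made-up pixels below contain three color groups, so the drop should flatten after K = 3; reading off that bend is often called the elbow heuristic. It complements the business question rather than replacing it.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical pixels: three noisy color groups (red, black, green)
rng = np.random.default_rng(0)
base_colors = np.array([[200, 30, 30], [20, 20, 20], [60, 160, 60]], dtype=np.float32)
pixels = np.vstack([c + rng.normal(0, 8, size=(300, 3)) for c in base_colors]).astype(np.float32)

# inertia_ = sum of squared distances from each pixel to its assigned centroid
inertias = {}
for K in [1, 2, 3, 4, 6, 10]:
    inertias[K] = KMeans(n_clusters=K, n_init=10, random_state=0).fit(pixels).inertia_
    print(K, round(inertias[K]))
```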

Cierra Andaur

Data Scientist | Analytics Nerd | Pythonista | Professional Question Asker |