Image segmentation based on K-means (detailed explanation of Python code)

This article draws mainly on the article below; some of the code that follows is explained in detail here.

K-means (image segmentation, Python), from a series on the ten classic machine learning algorithms: https://blog.csdn.net/jizhidexiaoming/article/details/89214614?spm=1001.2014.3001.5506

1. "K" and "means"

        K: There are k centroids (clusters).

        means: The centroid is the mean of all points in a cluster.

        K-means is a hard clustering method: each data point belongs to exactly one cluster. In soft clustering, by contrast, a data point can belong to several clusters to varying degrees.

2. Algorithm steps

        S1: Select the initial centroids: randomly select K points from the sample points as the centroids.

        S2: Classify all sample points: calculate the distance from every sample point to each of the K centroids, and assign each point to the cluster whose centroid is nearest.

        S3: Re-determine the centroids: new centroid = mean of all points in the cluster.

        S4: Loop update: repeat steps S2 and S3 until the centroids no longer change.
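Steps S1 to S4 can be sketched on a handful of 2D points before moving on to images. This is only a minimal illustration, not the article's image code: the function name `kmeans_points`, the `max_iter` and `seed` parameters, and the empty-cluster rule (keep the old centroid) are assumptions made for the example.

```python
import numpy as np

def kmeans_points(points, k, max_iter=100, seed=0):
    """Toy K-means following steps S1-S4; returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)
    # S1: pick k distinct sample points as the initial centroids
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(max_iter):
        # S2: distance from every point to every centroid, shape (n, k)
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # S3: new centroid = mean of the points assigned to it
        new_centroids = centroids.copy()
        for j in range(k):
            members = points[labels == j]
            if len(members) > 0:  # assumption: an empty cluster keeps its old centroid
                new_centroids[j] = members.mean(axis=0)
        # S4: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

points = [[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]]
labels, centroids = kmeans_points(points, k=2)
print(labels)
print(centroids)
```

The two groups of points are far apart, so the three points near the origin should end up with one label and the three points near (10, 10) with the other; which group gets label 0 depends on the random initial choice.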

3. Flowchart

4. Code

import numpy as np
import matplotlib.pyplot as plt
img = plt.imread('1.jpg')
row = img.shape[0]
col = img.shape[1]
plt.subplot(121)
plt.imshow(img)


def kmeans(data, n_iter, k):
    # Flatten the image to (row*col, 3), then append a 4th column for the cluster label
    data = data.reshape(-1, 3)
    data = np.column_stack((data, np.ones(row*col)))
    # 1. Pick k random pixels as the initial cluster centers
    cluster_center = data[np.random.choice(row*col, k)]
    # 2. Classify
    distance = [[] for i in range(k)]
    for i in range(n_iter):
        print("Iteration:", i)
        # 2.1 Distance from every pixel to each center (RGB columns only, not the label column)
        for j in range(k):
            distance[j] = np.sqrt(np.sum((data[:, :3] - cluster_center[j, :3])**2, axis=1))
        # 2.2 Assign every pixel to its nearest center
        data[:, 3] = np.argmin(distance, axis=0)
        # 3. Recompute each center as the mean of its cluster
        for j in range(k):
            cluster_center[j] = np.mean(data[data[:, 3] == j], axis=0)
    return data[:, 3]


if __name__ == "__main__":
    image_show = kmeans(img, 100, 2)
    image_show = image_show.reshape(row, col)
    plt.subplot(122)
    plt.imshow(image_show, cmap='gray')
    plt.show()

5. How to use the code?

        Step 1: put the image to be segmented into the project folder, like this:

        Step 2: in the third line of the code, 'img = plt.imread()', put the image path inside the parentheses in the form "image name.file format", for example 1.jpeg or 2.bmp.

6. Running results

1. Iterate 100 times, the total number of clusters = 5 

2. Iterate 10 times, the total number of clusters = 2

7. Core code explanation

7.1 Image information processing process 

7.1.1 Picture -> three-dimensional array (picture height, picture width, 3)

	img = plt.imread('path')

7.1.2 Three-dimensional array (image height, image width, 3) -> two-dimensional array (image width*height, 3)

	img = img.reshape(-1, 3)

7.1.3 (-1, 3) array -> (-1, 4) array

Add one column to store the classification information

	img = np.column_stack((img, np.ones(row*col)))

7.1.4 Return result of kmeans algorithm

Only the classification result (cluster label) of every pixel is returned

	return img[:, 3]
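The shape changes in 7.1.1 to 7.1.4 can be checked without an image file. The sketch below uses a small random array as a stand-in for the result of plt.imread(); the 4x6 size and the variable names `pixels`, `labels` and `label_image` are assumptions made for the example.

```python
import numpy as np

# Synthetic stand-in for plt.imread(): a 4x6 RGB image with random pixel values.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
row, col = img.shape[0], img.shape[1]

pixels = img.reshape(-1, 3)                            # (row*col, 3)
data = np.column_stack((pixels, np.ones(row * col)))   # (row*col, 4), now float64
labels = data[:, 3]                                    # (row*col,) label column
label_image = labels.reshape(row, col)                 # back to (row, col) for display

print(img.shape, pixels.shape, data.shape, label_image.shape)
# (4, 6, 3) (24, 3) (24, 4) (4, 6)
```

Note that np.column_stack converts the uint8 pixel values to float64, because the added column of ones is a float array.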

7.2 Representing pictures with arrays

7.2.1 Describing pixels

        A one-dimensional array [R, G, B] represents one pixel: its three elements are the values of the three primary colors, each in the range [0, 255].

7.2.2 Describing pictures

        A three-dimensional array can represent a picture. If the resolution of the picture is m*n (width*height), then the shape of the array is (n, m, 3). This three-dimensional array holds the RGB information of all pixels of the image (m*n in total).

(n, m, 3): n two-dimensional arrays of shape (m, 3).
n: the height of the picture
m: the width of the picture
3: the length of [R, G, B] = 3
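As a small illustration of the (n, m, 3) layout, here is a hand-written picture that is 4 pixels wide and 2 pixels high. The colors are arbitrary values chosen for the example.

```python
import numpy as np

# A picture with width m = 4 and height n = 2, so the shape is (n, m, 3).
img = np.array(
    [
        [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 255]],  # row 0
        [[0, 0, 0], [128, 128, 128], [255, 255, 0], [0, 255, 255]],  # row 1
    ],
    dtype=np.uint8,
)

print(img.shape)      # (2, 4, 3)
print(img[0].shape)   # (4, 3): one row of m pixels
print(img[1, 1])      # [128 128 128]: the pixel at row 1, column 1
```

The first index selects the row (height direction) and the second selects the column (width direction), which is why the shape is (n, m, 3) rather than (m, n, 3).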

7.3 plt.imread()

7.3.1 Syntax

img = plt.imread('image path')

7.3.2 Function

        Extract the information of the picture into an array (convert the picture into an array).

7.3.3 Return value type

        Return value type: numpy array

 <class 'numpy.ndarray'>

7.3.4 Return value content

        Suppose the image resolution is m*n (width*height):

  1. If it is a grayscale image, a two-dimensional array of shape (n, m) is returned.
  2. If it is an RGB image, a three-dimensional array of shape (n, m, 3) is returned.
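The two cases above can be told apart from the array's shape alone. The helper below is a hypothetical name introduced for this sketch, and it is tested on zero-filled arrays rather than on real files. It also covers a third case not listed above: images with an alpha channel come back with shape (n, m, 4), and those would need the alpha column dropped (for example with img[:, :, :3]) before the reshape(-1, 3) used in this article.

```python
import numpy as np

def describe_image(arr):
    """Classify an image array by its shape (hypothetical helper)."""
    if arr.ndim == 2:
        return "grayscale"
    if arr.ndim == 3 and arr.shape[2] == 3:
        return "RGB"
    if arr.ndim == 3 and arr.shape[2] == 4:
        return "RGBA"
    return "unknown"

print(describe_image(np.zeros((543, 371))))       # grayscale
print(describe_image(np.zeros((543, 371, 3))))    # RGB
rgba = np.zeros((543, 371, 4))
print(describe_image(rgba))                       # RGBA
print(describe_image(rgba[:, :, :3]))             # RGB, alpha channel dropped
```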

7.3.5 Examples

        Read a photo and get a 3D matrix containing all the information about that picture.

        The resolution of the photo is 371*543 (width*height).

import numpy as np
import matplotlib.pyplot as plt
img = plt.imread('1.jpg')   # read the image into an array
print(img.shape, type(img))
print(img)
plt.imshow(img)  # display the image from the array
plt.show()
Output: (543, 371, 3) <class 'numpy.ndarray'>
Output: img holds the RGB values of all pixels:

7.4 plt.imshow()

7.4.1 Function: Array -> Image

        Map the RGB information stored in the array into a color image.

7.4.2 Syntax

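plt.imshow(img, cmap='gray')  # 2D array drawn with the grayscale colormap
plt.imshow(img)  # RGB array drawn in its own colors; 2D array drawn with the default colormap
plt.show()

        img is the array to display: either a three-dimensional RGB array read from an image, or a two-dimensional array such as a segmentation result.

        The cmap parameter of plt.imshow() selects the colormap used for two-dimensional arrays, for example gray (grayscale) or hsv. It is ignored when img is already an RGB(A) array. If cmap is not given, Matplotlib uses its default colormap (viridis in current versions, unless changed in rcParams).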

7.4.3 Array shape

        (1) Three-dimensional: the array extracted from the image is three-dimensional and can be displayed directly by plt.imshow().

        (2) Two-dimensional: the result of image segmentation is a two-dimensional array in which each element is the cluster label of one pixel. If the image is split into 3 clusters, the elements take the values 0, 1 and 2, one value per cluster.

7.4.4 Features of pictures displayed by different array shapes

        When plt.imshow(img) is called without the cmap parameter, a three-dimensional RGB array is shown in its own colors, while a two-dimensional array is mapped through the default colormap.

        In particular, the number of colors in the picture displayed from a two-dimensional array equals the number of distinct values among its elements, so the segmented regions can be told apart. Below is the segmentation into two clusters.
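An alternative to relying on a colormap is to turn the two-dimensional label array back into a three-dimensional RGB array, painting every pixel with the mean color of its cluster. This is not part of the article's code. It is a minimal sketch, and the helper name `paint_clusters` and the tiny 2x2 test image are assumptions made for the example.

```python
import numpy as np

def paint_clusters(img, label_image, k):
    """Replace every pixel by the mean color of its cluster (illustrative helper)."""
    out = np.zeros_like(img)
    for j in range(k):
        mask = label_image == j          # 2D boolean mask of the pixels in cluster j
        if mask.any():
            out[mask] = img[mask].mean(axis=0)
    return out

img = np.array([[[10, 10, 10], [20, 20, 20]],
                [[200, 200, 200], [220, 220, 220]]], dtype=np.uint8)
label_image = np.array([[0, 0], [1, 1]])
painted = paint_clusters(img, label_image, k=2)
print(painted[0, 0], painted[1, 1])   # [15 15 15] [210 210 210]
```

The result has the same shape and dtype as the input image, so it can be passed to plt.imshow() without a cmap argument.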

7.5 plt.show()

7.5.1 Function

        Displays the image.

7.5.2 Features: Blocking

        When plt.show() is executed, the program blocks at that line and continues only after the image window has been closed.

7.6 Randomly generate cluster centers

7.6.1 Two-dimensional array indexing

        In general, a two-dimensional array is indexed with one number to get a row, or with a pair of numbers to get a single element.

        It can also be indexed with a vector (a list or array of row indices) to get an array made up of the specified rows.

        The example below shows all three kinds of indexing on a two-dimensional array:

import numpy as np
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
print(x[0])            # row 0: [1 2]
print(x[0, 0])         # single element: 1
print(x[[0, 1, 2]])    # rows 0, 1 and 2: [[1 2] [3 4] [5 6]]

7.6.2 Generate k cluster centers randomly

        First randomly draw a vector of k indices from [0, row*col-1], then use that vector to index data, which gives k random pixels as the initial cluster centers.

cluster_center = data[np.random.choice(row*col, k)]
# row = height of the original image (img.shape[0])
# col = width of the original image (img.shape[1])
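One detail worth knowing: np.random.choice(row*col, k) samples with replacement by default, so the same pixel can be picked twice and two initial centers can coincide. A possible variation, shown here only as a sketch, draws distinct indices with replace=False. The random `data` array, the seed and the use of np.random.default_rng are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(42)   # seeded only to make the sketch repeatable
# Stand-in for the (row*col, 3) pixel array: 12 random RGB rows.
data = rng.integers(0, 256, size=(12, 3)).astype(float)
k = 3

idx = rng.choice(data.shape[0], size=k, replace=False)   # k distinct row indices
cluster_center = data[idx]                                # shape (k, 3), a copy of those rows
print(idx, cluster_center.shape)
```

Distinct indices still do not guarantee distinct colors, since two different pixels may share the same RGB value.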

7.7 Euclidean distance

7.7.1 Definition of Euclidean distance
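For reference, the standard Euclidean distance between two pixels can be written as follows, where p = (R_1, G_1, B_1) and q = (R_2, G_2, B_2) are symbols chosen for this note rather than taken from the original figure:

```latex
d(p, q) = \sqrt{(R_1 - R_2)^2 + (G_1 - G_2)^2 + (B_1 - B_2)^2}
```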

7.7.2 Implementation

distance = np.sqrt(np.sum((x - y)**2, axis=1))
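A short numeric check of this one-liner, using three made-up pixels and a center at the origin. One caveat is assumed here rather than stated in the article: if the pixel array is still uint8, subtraction wraps around (3 - 4 becomes 255), so the sketch casts to float first. In the article's full code the np.column_stack call already converts the data to float.

```python
import numpy as np

pixels = np.array([[0, 0, 0], [3, 4, 0], [255, 255, 255]], dtype=np.uint8)
center = np.array([0.0, 0.0, 0.0])

# Cast to float before subtracting so that uint8 values cannot wrap around.
diff = pixels.astype(float) - center
distance = np.sqrt(np.sum(diff**2, axis=1))   # one distance per pixel
print(distance)
```

The first two distances are 0 and 5 (the 3-4-5 triangle); the third is 255 times the square root of 3.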

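7.7.3 Note: arrays with the same number of columns but different numbers of rows can be added and subtracted (NumPy broadcasting).

        To compute the distance between the image's RGB array and one cluster center, two arrays of shapes [n, 3] and [1, 3] are subtracted: the single row is applied to each of the n rows.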

import numpy as np
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([1, 2])
print(x-y)
Output:
[[0 0]
 [2 2]
 [4 4]
 [6 6]
 [8 8]]

7.8 Classification

7.8.1 Ideas

        Add a new column to the two-dimensional pixel array to hold the classification information of each pixel. If a pixel belongs to cluster j, the stored value is j.

        The final result returned by K-means holds, for every pixel, only the cluster it belongs to; the values range over [0, k-1].

7.8.2 Implementation

        np.argmin: returns the index of the smallest element along the given axis. Here distance has k rows (one per cluster center), so axis=0 gives, for every pixel, the index of its nearest center.

data[:, 3] = np.argmin(distance, axis=0)
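A small hand-made example shows what axis=0 does here. The distance values are invented for illustration, with k = 2 centers and 4 pixels.

```python
import numpy as np

# distance[j][i] = distance from pixel i to cluster center j
distance = [
    np.array([1.0, 8.0, 3.0, 9.0]),   # distances to center 0
    np.array([7.0, 2.0, 4.0, 1.0]),   # distances to center 1
]
labels = np.argmin(distance, axis=0)   # compare down each column, one result per pixel
print(labels)   # [0 1 0 1]
```

Pixels 0 and 2 are closer to center 0, and pixels 1 and 3 are closer to center 1.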

7.9 Calculating new cluster centers

7.9.1 Principle

        Cluster center = the mean of all points in the cluster. A cluster center is therefore not necessarily the RGB value of any actual pixel.

7.9.2 Implementation

for j in range(k):  # k is the number of clusters
    cluster_center[j] = np.mean(data[data[:, 3] == j], axis=0)

        Here data[:, 3] == j is a boolean mask, and data[data[:, 3] == j] returns all rows whose label column equals j.
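The boolean-mask step can be tried on a tiny hand-made array. The sketch below differs from the article's line in two ways, both assumptions made for the example: it averages only the three RGB columns, and it skips empty clusters, because the mean of an empty selection is NaN.

```python
import numpy as np

# Four pixels as rows [R, G, B, label]; cluster 2 is deliberately left empty.
data = np.array([
    [10.0, 10.0, 10.0, 0.0],
    [20.0, 20.0, 20.0, 0.0],
    [200.0, 200.0, 200.0, 1.0],
    [220.0, 220.0, 220.0, 1.0],
])
k = 3
cluster_center = np.zeros((k, 3))

for j in range(k):
    mask = data[:, 3] == j      # True where the pixel belongs to cluster j
    if mask.any():              # assumption: an empty cluster keeps its previous center
        cluster_center[j] = data[mask, :3].mean(axis=0)

print(cluster_center)
```

Clusters 0 and 1 get the centers (15, 15, 15) and (210, 210, 210), neither of which is the color of an actual pixel in the data; the empty cluster 2 keeps its initial zeros.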

Origin blog.csdn.net/marujie123/article/details/125721608