MENU

K-Means 算法实现(Matlab)

November 30, 2018 • Read: 5992 • 数据挖掘与机器学习阅读设置

K-Means 算法具体内容可以参考我博客的相关文章,这里只使用 Matlab 对其进行实现,其他内容不多赘述

K-Means 算法展开目录

1. 生成随机样本点展开目录

首先利用 mvnrnd 函数生成 3 组满足高斯分布的数据,每组数据都是 100*2 的矩阵,也就相当于生成 300 个在坐标轴上的样本点

  • %% 第一组数据
  • mu1=[0 0]; %均值
  • S1=[0.1 0 ; 0 0.1]; %协方差
  • data1=mvnrnd(mu1,S1,100); %产生高斯分布数据
  • %% 第二组数据
  • mu2=[-1.25 1.25];
  • S2=[0.1 0 ; 0 0.1];
  • data2=mvnrnd(mu2,S2,100);
  • %% 第三组数据
  • mu3=[1.25 1.25];
  • S3=[0.1 0 ; 0 0.1];
  • data3=mvnrnd(mu3,S3,100);

mu1mu2mu3 是数据的均值,也就是你将每组点画在坐标轴上,其大致的中心位置坐标,例如对于上面的三组数据,中心点就分别为 (0,0),(-1.25,1.25),(1.25,1.25),画在图上效果如下图
作图代码如下:

  • %% 显示数据
  • plot(data1(:,1),data1(:,2),'b+');
  • hold on;
  • plot(data2(:,1),data2(:,2),'b+');
  • plot(data3(:,1),data3(:,2),'b+');

2. 初始化各矩阵展开目录

首先我们要将三个 100*2 的矩阵合并为一个 300*2 的矩阵 data = [data1;data2;data2]

然后初始化聚类中心,生成 N 行 2 列的零矩阵,这里的 N 是用户输入的想要聚为几类

还有就是要把 data 矩阵拷贝一份,尽量在算法执行过程中执行拷贝矩阵,而不去动 data

  • %% 初始化变量
  • %% 初始化工作
  • data = [data1;data2;data3];
  • [m,n] = size(data); % m = 300,n = 2
  • center = zeros(N,n);% 初始化聚类中心,生成N行n列的零矩阵
  • pattern = data; % 将整个数据拷贝到pattern矩阵中

3. 算法核心展开目录

一开始随机选取 300 个点中的 N 个点作为聚类中心(N 是用户输入的聚类个数)。300 个点分别计算到这 N 个中心点那一个最短,就将该点分为第几号。举个例子:

设有一个点的坐标是 (0,0),分别有 3 个中心点 (2,2),(1,1),(3,3),经过计算,(0,0) 到 (1,1) 的距离是最短的,因此将 (0,0) 这个点划分为第 2 类

300 个点全部划分完以后,假设用户输入的 N 是 3,划分成 60,90,150,然后计算 60 个点的中心点坐标(只要将 60 个点的 x 坐标加起来除以 60,然后将 y 坐标加起来除以 60,就能得到中心点),70 个点的中心坐标,150 个点的中心坐标,设这三个中心坐标为 $(x_a,y_a)$,$(x_b,y_b)$,$(x_c,y_c)$,计算这三个中心点与之前随机选的三个中心点的距离是否小于一个阈值,如果都小于,则说明分类成功;只要有一个不满足,首先将这些新的中心坐标替换原来的中心坐标,然后重新分类

  • for x = 1 : N
  • center(x,:) = data(randi(300,1),:); % 第一次随机产生聚类中心 randi返回1*1的(1,300)的数
  • end
  • while true
  • distence = zeros(1,N); % 产生1行N列的零矩阵
  • num = zeros(1,N); % 产生1行N列的零矩阵
  • new_center = zeros(N,n); % 产生N行n列的零矩阵
  • %% 将所有的点打上标签1 2 3...N
  • for x = 1 : m
  • for y = 1 : N
  • distence(y) = norm(data(x,:) - center(y,:)); % norm函数计算到每个类的距离
  • end
  • [~,temp] = min(distence); %求最小的距离 ~是距离值,temp是第几个
  • pattern(x,n + 1) = temp;
  • end
  • k = 0;
  • %% 将所有在同一类里的点坐标全部相加,计算新的中心坐标
  • for y = 1 : N
  • for x = 1 : m
  • if pattern(x,n + 1) == y
  • new_center(y,:) = new_center(y,:) + pattern(x,1:n);
  • num(y) = num(y) + 1;
  • end
  • end
  • new_center(y,:) = new_center(y,:) / num(y);
  • if norm(new_center(y,:) - center(y,:)) < 0.1
  • k = k + 1;
  • end
  • end
  • if k == N
  • break;
  • else
  • center = new_center;
  • end
  • end
  • [m, n] = size(pattern); %[m,n] = [300,3]

4. 绘制聚类后的数据点图展开目录

  • figure;
  • hold on;
  • for i = 1 : m
  • if pattern(i,n) == 1
  • plot(pattern(i,1),pattern(i,2),'r*');
  • plot(center(1,1),center(1,2),'ko');
  • elseif pattern(i,n) == 2
  • plot(pattern(i,1),pattern(i,2),'g*');
  • plot(center(2,1),center(2,2),'ko');
  • elseif pattern(i,n) == 3
  • plot(pattern(i,1),pattern(i,2),'b*');
  • plot(center(3,1),center(3,2),'ko');
  • elseif pattern(i,n) == 4
  • plot(pattern(i,1),pattern(i,2),'y*');
  • plot(center(4,1),center(4,2),'ko');
  • else
  • plot(pattern(i,1),pattern(i,2),'m*');
  • plot(center(5,1),center(5,2),'ko');
  • end
  • end

完整代码展开目录

  • clear;
  • clc;
  • N = input('请设置聚类数目:');%设置聚类数目
  • %% 第一组数据
  • mu1=[0 0]; %均值
  • S1=[0.1 0 ; 0 0.1]; %协方差
  • data1=mvnrnd(mu1,S1,100); %产生高斯分布数据
  • %% 第二组数据
  • mu2=[-1.25 1.25];
  • S2=[0.1 0 ; 0 0.1];
  • data2=mvnrnd(mu2,S2,100);
  • %% 第三组数据
  • mu3=[1.25 1.25];
  • S3=[0.1 0 ; 0 0.1];
  • data3=mvnrnd(mu3,S3,100);
  • %% 显示数据
  • plot(data1(:,1),data1(:,2),'b+');
  • hold on;
  • plot(data2(:,1),data2(:,2),'b+');
  • plot(data3(:,1),data3(:,2),'b+');
  • %% 初始化工作
  • data = [data1;data2;data3];
  • [m,n] = size(data); % m = 300,n = 2
  • center = zeros(N,n);% 初始化聚类中心,生成N行n列的零矩阵
  • pattern = data; % 将整个数据拷贝到pattern矩阵中
  • %% 算法
  • for x = 1 : N
  • center(x,:) = data(randi(300,1),:); % 第一次随机产生聚类中心 randi返回1*1的(1,300)的数
  • end
  • while true
  • distence = zeros(1,N); % 产生1行N列的零矩阵
  • num = zeros(1,N); % 产生1行N列的零矩阵
  • new_center = zeros(N,n); % 产生N行n列的零矩阵
  • %% 将所有的点打上标签1 2 3...N
  • for x = 1 : m
  • for y = 1 : N
  • distence(y) = norm(data(x,:) - center(y,:)); % norm函数计算到每个类的距离
  • end
  • [~,temp] = min(distence); %求最小的距离 ~是距离值,temp是第几个
  • pattern(x,n + 1) = temp;
  • end
  • k = 0;
  • %% 将所有在同一类里的点坐标全部相加,计算新的中心坐标
  • for y = 1 : N
  • for x = 1 : m
  • if pattern(x,n + 1) == y
  • new_center(y,:) = new_center(y,:) + pattern(x,1:n);
  • num(y) = num(y) + 1;
  • end
  • end
  • new_center(y,:) = new_center(y,:) / num(y);
  • if norm(new_center(y,:) - center(y,:)) < 0.1
  • k = k + 1;
  • end
  • end
  • if k == N
  • break;
  • else
  • center = new_center;
  • end
  • end
  • [m, n] = size(pattern); %[m,n] = [300,3]
  • %% 最后显示聚类后的数据
  • figure;
  • hold on;
  • for i = 1 : m
  • if pattern(i,n) == 1
  • plot(pattern(i,1),pattern(i,2),'r*');
  • plot(center(1,1),center(1,2),'ko');
  • elseif pattern(i,n) == 2
  • plot(pattern(i,1),pattern(i,2),'g*');
  • plot(center(2,1),center(2,2),'ko');
  • elseif pattern(i,n) == 3
  • plot(pattern(i,1),pattern(i,2),'b*');
  • plot(center(3,1),center(3,2),'ko');
  • elseif pattern(i,n) == 4
  • plot(pattern(i,1),pattern(i,2),'y*');
  • plot(center(4,1),center(4,2),'ko');
  • else
  • plot(pattern(i,1),pattern(i,2),'m*');
  • plot(center(5,1),center(5,2),'ko');
  • end
  • end

执行的 GIF 图如下:

存在的问题以及改进方法展开目录

这只是一个比较简单的 K-Means 聚类代码,其中可能存在两个问题:

  1. 死循环
  2. 聚类不准确

第一个问题产生的原因很简单,如果用笔算过 K-Means 就会知道,对于一个数据集,可能的聚类方式不止一种,并且存在确实无法达到所有的聚类中心差都小于阈值的情况。解决办法是加一个变量 times 用于记录执行了多少次 while 循环,当 times 达到一个很大的值而依旧没有停止程序,可以判断出现了死循环,干脆直接输出结果,不再计算。

第二个问题产生的效果图如下
对于右边的样本集,我们用肉眼观察很明显聚类应该如红框所示,但是使用 K-Means 聚类后得到的结果与预期差异较大,究其原因有很多种,具体解决办法是将阈值减小,以达到更加精确的聚类

Archives Tip
QR Code for this page
Tipping QR Code