
簡單易學的機器學習算法—Mean Shift聚類算法
一、Mean Shift算法概述
Mean Shift算法,又稱為均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后來由Yizong Cheng對其進行擴充,主要提出了兩點的改進:
定義了核函數(shù);
增加了權(quán)重系數(shù)。
核函數(shù)的定義使得偏移值對偏移向量的貢獻隨之樣本與被偏移點的距離的不同而不同。權(quán)重系數(shù)使得不同樣本的權(quán)重不同。Mean Shift算法在聚類,圖像平滑、分割以及視頻跟蹤等方面有廣泛的應用。
二、Mean Shift算法的核心原理
2.1、核函數(shù)
在Mean Shift算法中引入核函數(shù)的目的是使得隨著樣本與被偏移點的距離不同,其偏移量對均值偏移向量的貢獻也不同。核函數(shù)是機器學習中常用的一種方式。核函數(shù)的定義如下所示:
并且滿足:
(1)、k是非負的
(2)、k是非增的
(3)、k是分段連續(xù)的
那么,函數(shù)K(x)就稱為核函數(shù)。
常用的核函數(shù)有高斯核函數(shù)。高斯核函數(shù)如下所示:
其中,h稱為帶寬(bandwidth),不同帶寬的核函數(shù)如下圖所示:
上圖的畫圖腳本如下所示:
'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import math
def cal_Gaussian(x, h=1):
molecule = x * x
denominator = 2 * h * h
left = 1 / (math.sqrt(2 * math.pi) * h)
return left * math.exp(-molecule / denominator)
x = []
for i in xrange(-40,40):
x.append(i * 0.5);
score_1 = []
score_2 = []
score_3 = []
score_4 = []
for i in x:
score_1.append(cal_Gaussian(i,1))
score_2.append(cal_Gaussian(i,2))
score_3.append(cal_Gaussian(i,3))
score_4.append(cal_Gaussian(i,4))
plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
2.2、Mean Shift算法的核心思想
2.2.1、基本原理
對于Mean Shift算法,是一個迭代的步驟,即先算出當前點的偏移均值,將該點移動到此偏移均值,然后以此為新的起始點,繼續(xù)移動,直到滿足最終的條件。此過程可由下圖的過程進行說明(圖片來自參考文獻3):
步驟1:在指定的區(qū)域內(nèi)計算偏移均值(如下圖的黃色的圈)
步驟2:移動該點到偏移均值點處
步驟3: 重復上述的過程(計算新的偏移均值,移動)
步驟4:滿足了最終的條件,即退出
從上述過程可以看出,在Mean Shift算法中,最關(guān)鍵的就是計算每個點的偏移均值,然后根據(jù)新計算的偏移均值更新點的位置。
2.2.2、基本的Mean Shift向量形式
對于給定的d維空間Rd中的n個樣本點,則對于x點,其Mean Shift向量的基本形式為:
其中,Sh指的是一個半徑為h的高維球區(qū)域,如上圖中的藍色的圓形區(qū)域。Sh的定義為:
這樣的一種基本的Mean Shift形式存在一個問題:在Sh的區(qū)域內(nèi),每一個點對x的貢獻是一樣的。而實際上,這種貢獻與x到每一個點之間的距離是相關(guān)的。同時,對于每一個樣本,其重要程度也是不一樣的。
2.2.3、改進的Mean Shift向量形式
基于以上的考慮,對基本的Mean Shift向量形式中增加核函數(shù)和樣本權(quán)重,得到如下的改進的Mean Shift向量形式:
其中:
G(x)是一個單位的核函數(shù)。H是一個正定的對稱d×d矩陣,稱為帶寬矩陣,其是一個對角陣。w(xi)?0是每一個樣本的權(quán)重。對角陣H的形式為:
上述的Mean Shift向量可以改寫成:
Mean Shift向量Mh(x)是歸一化的概率密度梯度。
2.3、Mean Shift算法的解釋
在Mean Shift算法中,實際上是利用了概率密度,求得概率密度的局部最優(yōu)解。
2.3.1、概率密度梯度
對一個概率密度函數(shù)f(x),已知d維空間中n個采樣點xi,i=1,?,n,f(x)的核函數(shù)估計(也稱為Parzen窗估計)為:
其中
w(xi)?0是一個賦給采樣點xi的權(quán)重
K(x)是一個核函數(shù)
概率密度函數(shù)f(x)的梯度▽f(x)的估計為
令,則有:
其中,第二個方括號中的就是Mean Shift向量,其與概率密度梯度成正比。
2.3.2、Mean Shift向量的修正
Mh(x)=∑ni=1G(∥∥xi?xh∥∥2)w(xi)xi∑ni=1G(xi?xh)w(xi)?x
記:,則上式變成:
Mh(x)=mh(x)+x
這與梯度上升的過程一致。
2.4、Mean Shift算法流程
Mean Shift算法的算法流程如下:
計算mh(x)
令x=mh(x)
如果∥mh(x)?x∥<ε,結(jié)束循環(huán),否則,重復上述步驟
三、實驗
3.1、實驗數(shù)據(jù)
實驗數(shù)據(jù)如下圖所示(來自參考文獻1):
畫圖的代碼如下:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data")
x = []
y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
x.append(float(lines[0]))
y.append(float(lines[1]))
f.close()
plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
3.2、實驗的源碼
#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''
import math
import sys
import numpy as np
MIN_DISTANCE = 0.000001#mini error
def load_data(path, feature_num=2):
f = open(path)
data = []
for line in f.readlines():
lines = line.strip().split("\t")
data_tmp = []
if len(lines) != feature_num:
continue
for i in xrange(feature_num):
data_tmp.append(float(lines[i]))
data.append(data_tmp)
f.close()
return data
def gaussian_kernel(distance, bandwidth):
m = np.shape(distance)[0]
right = np.mat(np.zeros((m, 1)))
for i in xrange(m):
right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
right[i, 0] = np.exp(right[i, 0])
left = 1 / (bandwidth * math.sqrt(2 * math.pi))
gaussian_val = left * right
return gaussian_val
def shift_point(point, points, kernel_bandwidth):
points = np.mat(points)
m,n = np.shape(points)
#計算距離
point_distances = np.mat(np.zeros((m,1)))
for i in xrange(m):
point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)
#計算高斯核
point_weights = gaussian_kernel(point_distances, kernel_bandwidth)
#計算分母
all = 0.0
for i in xrange(m):
all += point_weights[i, 0]
#均值偏移
point_shifted = point_weights.T * points / all
return point_shifted
def euclidean_dist(pointA, pointB):
#計算pointA和pointB之間的歐式距離
total = (pointA - pointB) * (pointA - pointB).T
return math.sqrt(total)
def distance_to_group(point, group):
min_distance = 10000.0
for pt in group:
dist = euclidean_dist(point, pt)
if dist < min_distance:
min_distance = dist
return min_distance
def group_points(mean_shift_points):
group_assignment = []
m,n = np.shape(mean_shift_points)
index = 0
index_dict = {}
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
print item_1
if item_1 not in index_dict:
index_dict[item_1] = index
index += 1
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
group_assignment.append(index_dict[item_1])
return group_assignment
def train_mean_shift(points, kenel_bandwidth=2):
#shift_points = np.array(points)
mean_shift_points = np.mat(points)
max_min_dist = 1
iter = 0
m, n = np.shape(mean_shift_points)
need_shift = [True] * m
#cal the mean shift vector
while max_min_dist > MIN_DISTANCE:
max_min_dist = 0
iter += 1
print "iter : " + str(iter)
for i in range(0, m):
#判斷每一個樣本點是否需要計算偏置均值
if not need_shift[i]:
continue
p_new = mean_shift_points[i]
p_new_start = p_new
p_new = shift_point(p_new, points, kenel_bandwidth)
dist = euclidean_dist(p_new, p_new_start)
if dist > max_min_dist:#record the max in all points
max_min_dist = dist
if dist < MIN_DISTANCE:#no need to move
need_shift[i] = False
mean_shift_points[i] = p_new
#計算最終的group
group = group_points(mean_shift_points)
return np.mat(points), mean_shift_points, group
if __name__ == "__main__":
#導入數(shù)據(jù)集
path = "./data"
data = load_data(path, 2)
#訓練,h=2
points, shift_points, cluster = train_mean_shift(data, 2)
for i in xrange(len(cluster)):
print "%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])
3.3、實驗的結(jié)果
經(jīng)過Mean Shift算法聚類后的數(shù)據(jù)如下所示:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 3:
label = int(lines[2])
if label == 0:
data_1 = lines[0].strip().split(",")
cluster_x_0.append(float(data_1[0]))
cluster_y_0.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
elif label == 1:
data_1 = lines[0].strip().split(",")
cluster_x_1.append(float(data_1[0]))
cluster_y_1.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
else:
data_1 = lines[0].strip().split(",")
cluster_x_2.append(float(data_1[0]))
cluster_y_2.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
f.close()
plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')數(shù)據(jù)分析師培訓
#plt.legend(loc="best")
plt.show()
數(shù)據(jù)分析咨詢請掃描二維碼
若不方便掃碼,搜微信號:CDAshujufenxi
LSTM 模型輸入長度選擇技巧:提升序列建模效能的關(guān)鍵? 在循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)家族中,長短期記憶網(wǎng)絡(luò)(LSTM)憑借其解決長序列 ...
2025-07-11CDA 數(shù)據(jù)分析師報考條件詳解與準備指南? ? 在數(shù)據(jù)驅(qū)動決策的時代浪潮下,CDA 數(shù)據(jù)分析師認證愈發(fā)受到矚目,成為眾多有志投身數(shù) ...
2025-07-11數(shù)據(jù)透視表中兩列相乘合計的實用指南? 在數(shù)據(jù)分析的日常工作中,數(shù)據(jù)透視表憑借其強大的數(shù)據(jù)匯總和分析功能,成為了 Excel 用戶 ...
2025-07-11尊敬的考生: 您好! 我們誠摯通知您,CDA Level I和 Level II考試大綱將于 2025年7月25日 實施重大更新。 此次更新旨在確保認 ...
2025-07-10BI 大數(shù)據(jù)分析師:連接數(shù)據(jù)與業(yè)務(wù)的價值轉(zhuǎn)化者? ? 在大數(shù)據(jù)與商業(yè)智能(Business Intelligence,簡稱 BI)深度融合的時代,BI ...
2025-07-10SQL 在預測分析中的應用:從數(shù)據(jù)查詢到趨勢預判? ? 在數(shù)據(jù)驅(qū)動決策的時代,預測分析作為挖掘數(shù)據(jù)潛在價值的核心手段,正被廣泛 ...
2025-07-10數(shù)據(jù)查詢結(jié)束后:分析師的收尾工作與價值深化? ? 在數(shù)據(jù)分析的全流程中,“query end”(查詢結(jié)束)并非工作的終點,而是將數(shù) ...
2025-07-10CDA 數(shù)據(jù)分析師考試:從報考到取證的全攻略? 在數(shù)字經(jīng)濟蓬勃發(fā)展的今天,數(shù)據(jù)分析師已成為各行業(yè)爭搶的核心人才,而 CDA(Certi ...
2025-07-09【CDA干貨】單樣本趨勢性檢驗:捕捉數(shù)據(jù)背后的時間軌跡? 在數(shù)據(jù)分析的版圖中,單樣本趨勢性檢驗如同一位耐心的偵探,專注于從單 ...
2025-07-09year_month數(shù)據(jù)類型:時間維度的精準切片? ? 在數(shù)據(jù)的世界里,時間是最不可或缺的維度之一,而year_month數(shù)據(jù)類型就像一把精準 ...
2025-07-09CDA 備考干貨:Python 在數(shù)據(jù)分析中的核心應用與實戰(zhàn)技巧? ? 在 CDA 數(shù)據(jù)分析師認證考試中,Python 作為數(shù)據(jù)處理與分析的核心 ...
2025-07-08SPSS 中的 Mann-Kendall 檢驗:數(shù)據(jù)趨勢與突變分析的有力工具? ? ? 在數(shù)據(jù)分析的廣袤領(lǐng)域中,準確捕捉數(shù)據(jù)的趨勢變化以及識別 ...
2025-07-08備戰(zhàn) CDA 數(shù)據(jù)分析師考試:需要多久?如何規(guī)劃? CDA(Certified Data Analyst)數(shù)據(jù)分析師認證作為國內(nèi)權(quán)威的數(shù)據(jù)分析能力認證 ...
2025-07-08LSTM 輸出不確定的成因、影響與應對策略? 長短期記憶網(wǎng)絡(luò)(LSTM)作為循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的一種變體,憑借獨特的門控機制,在 ...
2025-07-07統(tǒng)計學方法在市場調(diào)研數(shù)據(jù)中的深度應用? 市場調(diào)研是企業(yè)洞察市場動態(tài)、了解消費者需求的重要途徑,而統(tǒng)計學方法則是市場調(diào)研數(shù) ...
2025-07-07CDA數(shù)據(jù)分析師證書考試全攻略? 在數(shù)字化浪潮席卷全球的當下,數(shù)據(jù)已成為企業(yè)決策、行業(yè)發(fā)展的核心驅(qū)動力,數(shù)據(jù)分析師也因此成為 ...
2025-07-07剖析 CDA 數(shù)據(jù)分析師考試題型:解鎖高效備考與答題策略? CDA(Certified Data Analyst)數(shù)據(jù)分析師考試作為衡量數(shù)據(jù)專業(yè)能力的 ...
2025-07-04SQL Server 字符串截取轉(zhuǎn)日期:解鎖數(shù)據(jù)處理的關(guān)鍵技能? 在數(shù)據(jù)處理與分析工作中,數(shù)據(jù)格式的規(guī)范性是保證后續(xù)分析準確性的基礎(chǔ) ...
2025-07-04CDA 數(shù)據(jù)分析師視角:從數(shù)據(jù)迷霧中探尋商業(yè)真相? 在數(shù)字化浪潮席卷全球的今天,數(shù)據(jù)已成為企業(yè)決策的核心驅(qū)動力,CDA(Certifie ...
2025-07-04CDA 數(shù)據(jù)分析師:開啟數(shù)據(jù)職業(yè)發(fā)展新征程? ? 在數(shù)據(jù)成為核心生產(chǎn)要素的今天,數(shù)據(jù)分析師的職業(yè)價值愈發(fā)凸顯。CDA(Certified D ...
2025-07-03