Simple and Easy-to-Learn Machine Learning Algorithms: Mean Shift Clustering
2017-03-25

1. Overview of the Mean Shift Algorithm

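The Mean Shift algorithm, also called the mean drift algorithm, was first proposed by Fukunaga in 1975. Yizong Cheng later extended it, with two main improvements:
defining a kernel function;
adding weight coefficients.
The kernel function makes each sample's contribution to the shift vector depend on its distance from the point being shifted, and the weight coefficients let different samples carry different importance. Mean Shift is widely used in clustering, image smoothing, image segmentation, and video tracking.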
2. Core Principles of the Mean Shift Algorithm
2.1 Kernel Functions
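Mean Shift introduces a kernel function so that a sample's contribution to the mean shift vector varies with its distance from the point being shifted. Kernel functions are a common tool in machine learning. A kernel function is defined as follows. Let X be the d-dimensional Euclidean space R^d, and let K: X → R be a function for which there is a profile function k: [0, ∞) → R with

$$K(x) = k\left(\|x\|^2\right)$$

where k satisfies:
(1) k is non-negative;
(2) k is non-increasing;
(3) k is piecewise continuous.
Then K(x) is called a kernel function.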
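To make the definition concrete, here is a minimal sketch that assumes two example profiles the text does not name: the Gaussian profile k(t) = exp(-t/2) and a flat profile equal to 1 on [0, 1] and 0 beyond. The helper names (`gaussian_profile`, `flat_profile`, `kernel`) are hypothetical. The code builds K(x) = k(‖x‖²) and checks conditions (1) and (2) on a grid of t values; condition (3) holds by construction (the flat profile has a single jump at t = 1).

```python
import numpy as np


def gaussian_profile(t):
    # Example profile k(t) = exp(-t / 2): non-negative, non-increasing, continuous on [0, inf)
    return np.exp(-0.5 * np.asarray(t, dtype=float))


def flat_profile(t):
    # Example profile k(t) = 1 for t <= 1, else 0: piecewise continuous, one jump at t = 1
    return np.where(np.asarray(t, dtype=float) <= 1.0, 1.0, 0.0)


def kernel(x, profile):
    # K(x) = k(||x||^2) for a d-dimensional vector x
    x = np.asarray(x, dtype=float)
    return float(profile(np.dot(x, x)))


# Check conditions (1) and (2) on a grid of t = ||x||^2 values
t = np.linspace(0.0, 5.0, 501)
for k in (gaussian_profile, flat_profile):
    values = k(t)
    assert np.all(values >= 0)           # (1) non-negative
    assert np.all(np.diff(values) <= 0)  # (2) non-increasing

print(kernel([1.0, 1.0], gaussian_profile))  # ||x||^2 = 2, so exp(-1)
print(kernel([1.0, 1.0], flat_profile))      # outside the unit ball, so 0.0
```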

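A commonly used kernel is the Gaussian kernel:

$$N(x) = \frac{1}{\sqrt{2\pi}\,h} \exp\left(-\frac{x^2}{2h^2}\right)$$

where h is called the bandwidth. Gaussian kernels with different bandwidths are shown in the figure below:

The plotting script for the figure above is as follows: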

'''
Date: 20160426
@author: zhaozhiyong
'''
import math

import matplotlib.pyplot as plt


def cal_Gaussian(x, h=1):
    # 1-D Gaussian kernel N(x) with bandwidth h
    molecule = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-molecule / denominator)


# Sample x on [-20, 20) with a step of 0.5
x = [i * 0.5 for i in range(-40, 40)]

# Kernel values for bandwidths h = 1, 2, 3, 4
score_1 = [cal_Gaussian(i, 1) for i in x]
score_2 = [cal_Gaussian(i, 2) for i in x]
score_3 = [cal_Gaussian(i, 3) for i in x]
score_4 = [cal_Gaussian(i, 4) for i in x]

plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")

plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()

2.2 The Core Idea of the Mean Shift Algorithm

2.2.1 Basic Principle

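Mean Shift is an iterative procedure: first compute the shifted mean of the current point, move the point to that shifted mean, then take it as the new starting point and keep moving until a stopping condition is met. The process is illustrated by the figures below (images from Reference 3):

Step 1: compute the shifted mean within the specified region (the yellow circle in the figure below);

Step 2: move the point to the shifted mean;

Step 3: repeat the process (compute the new shifted mean, then move);

Step 4: stop once the termination condition is met.

As this process shows, the key operation in Mean Shift is computing each point's shifted mean and then updating the point's position to it.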
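The four steps can be sketched in Python as follows. This is a minimal illustration, not the author's code: it assumes the "specified region" is a ball of radius h around the current point and that the shifted mean is the plain average of the samples inside it (a flat window). `window_mean` and `shift_until_converged` are hypothetical names, and the threshold `eps` stands in for the termination condition.

```python
import numpy as np


def window_mean(x, points, h):
    # Step 1: average of the samples inside the ball of radius h around x
    dists = np.linalg.norm(points - x, axis=1)
    inside = points[dists <= h]
    if len(inside) == 0:  # empty window: leave the point where it is
        return x
    return inside.mean(axis=0)


def shift_until_converged(x, points, h, eps=1e-6, max_iter=100):
    points = np.asarray(points, dtype=float)
    x = np.asarray(x, dtype=float)
    for _ in range(max_iter):
        new_x = window_mean(x, points, h)  # Step 1: compute the shifted mean
        moved = np.linalg.norm(new_x - x)
        x = new_x                          # Step 2: move the point there
        if moved < eps:                    # Step 4: stop once it no longer moves
            break                          # otherwise Step 3: repeat
    return x


points = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5],
                   [5.0, 5.0], [5.5, 5.0], [5.0, 5.5]])
print(shift_until_converged([0.0, 0.0], points, h=1.5))
print(shift_until_converged([5.0, 5.0], points, h=1.5))
```

Starting points in different groups end at different window means, which is what later lets Mean Shift assign clusters.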

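2.2.2 The Basic Form of the Mean Shift Vector

Given n sample points x_i, i = 1, …, n, in the d-dimensional space R^d, the basic form of the Mean Shift vector at a point x is:

$$M_h(x) = \frac{1}{k}\sum_{x_i \in S_h} (x_i - x)$$

where k is the number of samples that fall inside S_h (here k is a count, not the profile of 2.1), and S_h is a high-dimensional ball of radius h, such as the blue circular region in the figure above. S_h is defined as:

$$S_h(x) = \left\{ y : (y - x)^{\top}(y - x) \le h^2 \right\}$$

This basic form has a problem: every point inside S_h contributes equally to x, whereas in practice a point's contribution should depend on its distance from x. Moreover, different samples may differ in importance.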

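2.2.3 The Improved Form of the Mean Shift Vector

To address these issues, a kernel function and sample weights are added to the basic form, giving the improved Mean Shift vector:

$$M_h(x) = \frac{\sum_{i=1}^{n} G_H(x_i - x)\, w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} G_H(x_i - x)\, w(x_i)}$$

where:

$$G_H(x_i - x) = |H|^{-1/2}\, G\left(H^{-1/2}(x_i - x)\right)$$

G(x) is a unit kernel function; H is a positive-definite symmetric d×d matrix called the bandwidth matrix, taken here to be diagonal; and w(x_i) ≥ 0 is the weight of each sample. The diagonal matrix H has the form:

$$H = \begin{pmatrix} h^2 & 0 & \cdots & 0 \\ 0 & h^2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & h^2 \end{pmatrix}$$

The Mean Shift vector above can therefore be rewritten as:

$$M_h(x) = \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}$$

The Mean Shift vector M_h(x) is a normalized estimate of the probability density gradient.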
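A minimal NumPy sketch of the rewritten vector, assuming H = h²I as above and G(u) = g(‖u‖²) for a profile g. The default Gaussian profile and the flat profile are illustrative choices, and `improved_mean_shift_vector`, `gaussian_g`, `flat_g` are hypothetical names. With the flat profile and equal weights the formula reduces to the basic form of 2.2.2; with a Gaussian profile nearer samples pull harder, and a larger w(x_i) drags the vector towards that sample.

```python
import numpy as np


def gaussian_g(t):
    # Illustrative Gaussian profile: g(t) = exp(-t / 2)
    return np.exp(-0.5 * t)


def flat_g(t):
    # Illustrative flat profile: 1 inside the unit ball, 0 outside (gives the basic form)
    return (t <= 1.0).astype(float)


def improved_mean_shift_vector(x, points, h, weights=None, g=gaussian_g):
    # M_h(x) = sum G((x_i - x)/h) w(x_i) (x_i - x) / sum G((x_i - x)/h) w(x_i),
    # with G(u) = g(||u||^2), i.e. bandwidth matrix H = h^2 * I
    points = np.asarray(points, dtype=float)
    x = np.asarray(x, dtype=float)
    if weights is None:
        weights = np.ones(len(points))
    diffs = points - x
    coeff = g(np.sum((diffs / h) ** 2, axis=1)) * weights
    return (coeff[:, None] * diffs).sum(axis=0) / coeff.sum()


points = np.array([[0.2, 0.0], [0.9, 0.0], [3.0, 0.0]])
x = np.zeros(2)

# Flat profile + equal weights: average of (x_i - x) over the two samples inside S_h
print(improved_mean_shift_vector(x, points, 1.0, g=flat_g))
# Gaussian profile: the sample at 0.2 now counts more than the one at 0.9
print(improved_mean_shift_vector(x, points, 1.0))
# A larger weight on the third sample pulls the vector towards it
print(improved_mean_shift_vector(x, points, 1.0, weights=np.array([1.0, 1.0, 5.0])))
```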
2.3 Interpreting the Mean Shift Algorithm

In essence, Mean Shift works on a probability density: it searches for a local optimum (a local maximum) of that density.

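2.3.1 The Probability Density Gradient

For a probability density function f(x), given n samples x_i, i = 1, …, n, in d-dimensional space, the kernel estimate of f(x) (also called the Parzen window estimate) is:

$$\hat{f}(x) = \frac{\sum_{i=1}^{n} K\left(\frac{x_i - x}{h}\right) w(x_i)}{h^d \sum_{i=1}^{n} w(x_i)}$$

where
w(x_i) ≥ 0 is the weight assigned to sample x_i;
K(x) is a kernel function.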
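As a sketch of the weighted Parzen estimate, assuming a normalized Gaussian kernel K(u) = (2π)^(-d/2) exp(-‖u‖²/2) (the text does not fix K at this point), the function below evaluates f̂(x); `weighted_kde` is a hypothetical name. A Riemann sum over a 1-D grid checks that the weighted estimate still integrates to about 1, which is what the normalization by h^d Σ w(x_i) provides.

```python
import numpy as np


def weighted_kde(x, samples, h, weights):
    # Weighted Parzen-window estimate with K(u) = (2*pi)^(-d/2) * exp(-||u||^2 / 2)
    samples = np.asarray(samples, dtype=float)
    weights = np.asarray(weights, dtype=float)
    d = samples.shape[1]
    u2 = np.sum(((samples - np.asarray(x, dtype=float)) / h) ** 2, axis=1)
    K = np.exp(-0.5 * u2) / (2.0 * np.pi) ** (d / 2.0)
    return np.sum(K * weights) / (h ** d * np.sum(weights))


# Three 1-D samples; the middle one carries the largest weight
samples = np.array([[-1.0], [0.0], [2.5]])
weights = np.array([1.0, 3.0, 0.5])
h = 0.7

grid = np.linspace(-10.0, 12.0, 4401)
dx = grid[1] - grid[0]
density = np.array([weighted_kde([g], samples, h, weights) for g in grid])
print(density.sum() * dx)  # Riemann sum of the estimate; should be close to 1
```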
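The gradient ∇f(x) of the density is then estimated by

$$\nabla \hat{f}(x) = \frac{2\sum_{i=1}^{n} (x - x_i)\, k'\left(\left\|\frac{x_i - x}{h}\right\|^2\right) w(x_i)}{h^{d+2} \sum_{i=1}^{n} w(x_i)}$$

Let g(x) = -k'(x) and G(x) = g(‖x‖²). Then:

$$\nabla \hat{f}(x) = \frac{2}{h^2} \left[\frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}{h^d \sum_{i=1}^{n} w(x_i)}\right] \left[\frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\,(x_i - x)}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}\right]$$

The first bracket is itself a non-negative density estimate at x, computed with the kernel G; the second bracket is exactly the Mean Shift vector M_h(x). The Mean Shift vector is therefore proportional to the estimated density gradient.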
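The proportionality can be checked numerically. For a Gaussian kernel, g = -k' is proportional to k, so the first bracket equals f̂(x)/2 and the relation reduces to M_h(x) = h²∇f̂(x)/f̂(x). The sketch below assumes a Gaussian kernel, uses hypothetical helper names and random test data, and compares M_h(x) with h²∇f̂(x)/f̂(x), where the gradient comes from central finite differences.

```python
import numpy as np


def kde(x, points, h, weights):
    # Weighted Parzen estimate with a normalized Gaussian kernel
    d = points.shape[1]
    u2 = np.sum(((points - x) / h) ** 2, axis=1)
    K = np.exp(-0.5 * u2) / (2.0 * np.pi) ** (d / 2.0)
    return np.sum(K * weights) / (h ** d * np.sum(weights))


def mean_shift_vector(x, points, h, weights):
    # M_h(x) with G proportional to the Gaussian kernel (constant factors cancel)
    coeff = np.exp(-0.5 * np.sum(((points - x) / h) ** 2, axis=1)) * weights
    return coeff @ (points - x) / coeff.sum()


def numerical_gradient(f, x, eps=1e-5):
    # Central finite differences, one coordinate at a time
    grad = np.zeros_like(x)
    for j in range(len(x)):
        step = np.zeros_like(x)
        step[j] = eps
        grad[j] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return grad


rng = np.random.default_rng(0)
points = rng.normal(size=(50, 2))
weights = rng.uniform(0.5, 2.0, size=50)
h = 0.8
x = np.array([0.3, -0.2])


def f_hat(z):
    return kde(z, points, h, weights)


lhs = mean_shift_vector(x, points, h, weights)
rhs = h ** 2 * numerical_gradient(f_hat, x) / f_hat(x)
print(lhs)
print(rhs)  # should match lhs up to finite-difference error
```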

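2.3.2 Revising the Mean Shift Vector

$$M_h(x) = \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\, x_i}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)} - x$$

Denote

$$m_h(x) = \frac{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)\, x_i}{\sum_{i=1}^{n} G\left(\frac{x_i - x}{h}\right) w(x_i)}$$

Then the expression above becomes:

$$M_h(x) = m_h(x) - x$$

Equivalently, m_h(x) = x + M_h(x): moving x to m_h(x) takes a step along the Mean Shift vector, which points in the direction of the density gradient. This is consistent with gradient ascent.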

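2.4 The Mean Shift Algorithm Procedure

The Mean Shift algorithm proceeds as follows:

Compute m_h(x);
Record the step length ‖m_h(x) − x‖, then set x = m_h(x);
If the step length is less than ε, stop; otherwise, repeat the steps above.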
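A compact sketch of this procedure for a single starting point, assuming a Gaussian G and equal weights unless others are passed in. `m_h` and `mean_shift_mode` are hypothetical names, and `max_iter` is an added safety cap that the procedure above does not mention. Each iteration moves the point uphill on the density estimate until it settles at a mode.

```python
import numpy as np


def m_h(x, points, h, weights):
    # m_h(x) with G taken as a Gaussian: exp(-||(x_i - x) / h||^2 / 2)
    u2 = np.sum(((points - x) / h) ** 2, axis=1)
    coeff = np.exp(-0.5 * u2) * weights
    return coeff @ points / coeff.sum()


def mean_shift_mode(x, points, h, weights=None, eps=1e-6, max_iter=500):
    points = np.asarray(points, dtype=float)
    x = np.asarray(x, dtype=float)
    if weights is None:
        weights = np.ones(len(points))
    for _ in range(max_iter):
        new_x = m_h(x, points, h, weights)  # compute m_h(x)
        step = np.linalg.norm(new_x - x)    # step length ||m_h(x) - x||
        x = new_x                           # x = m_h(x)
        if step < eps:                      # stop once the step falls below eps
            break
    return x


points = np.array([[-0.5, 0.0], [0.5, 0.0], [0.0, -0.5], [0.0, 0.5],
                   [5.5, 6.0], [6.5, 6.0], [6.0, 5.5], [6.0, 6.5]])
print(mean_shift_mode([0.4, 0.1], points, h=1.0))  # ends near (0, 0)
print(mean_shift_mode([6.3, 5.8], points, h=1.0))  # ends near (6, 6)
```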
3. Experiments

3.1 Experimental Data

The experimental data are shown in the figure below (from Reference 1):

The plotting code is as follows:

'''
Date: 20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

x = []
y = []
# Each line of "data" holds one sample: x<TAB>y
with open("data") as f:
    for line in f:
        lines = line.strip().split("\t")
        if len(lines) == 2:
            x.append(float(lines[0]))
            y.append(float(lines[1]))

plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
3.2 Source Code of the Experiment

#!/usr/bin/env python
# coding: UTF-8
'''
Date: 20160426
@author: zhaozhiyong
'''

import math

import numpy as np

MIN_DISTANCE = 0.000001  # minimum move; below this a point counts as converged


def load_data(path, feature_num=2):
    # Read tab-separated samples, skipping lines without feature_num fields
    data = []
    with open(path) as f:
        for line in f:
            lines = line.strip().split("\t")
            if len(lines) != feature_num:
                continue
            data.append([float(lines[i]) for i in range(feature_num)])
    return data


def gaussian_kernel(distance, bandwidth):
    # Gaussian kernel value for each entry of the distance array
    right = np.exp(-0.5 * distance * distance / (bandwidth * bandwidth))
    left = 1 / (bandwidth * math.sqrt(2 * math.pi))
    return left * right


def shift_point(point, points, kernel_bandwidth):
    # point: (n,) array; points: (m, n) array of all samples
    # Distance from point to every sample, as an (m, 1) column
    point_distances = np.sqrt(np.sum((points - point) ** 2, axis=1, keepdims=True))

    # Gaussian kernel weights
    point_weights = gaussian_kernel(point_distances, kernel_bandwidth)

    # Denominator: sum of all weights
    total = np.sum(point_weights)

    # Shifted mean: kernel-weighted average of the samples
    point_shifted = np.sum(point_weights * points, axis=0) / total
    return point_shifted


def euclidean_dist(pointA, pointB):
    # Euclidean distance between pointA and pointB
    return float(np.linalg.norm(np.asarray(pointA) - np.asarray(pointB)))


def distance_to_group(point, group):
    min_distance = 10000.0
    for pt in group:
        dist = euclidean_dist(point, pt)
        if dist < min_distance:
            min_distance = dist
    return min_distance


def group_points(mean_shift_points):
    # Points whose converged coordinates agree to two decimals share a group
    keys = ["_".join("%5.2f" % v for v in row) for row in mean_shift_points]
    index_dict = {}
    for key in keys:
        print(key)
        if key not in index_dict:
            index_dict[key] = len(index_dict)
    return [index_dict[key] for key in keys]


def train_mean_shift(points, kernel_bandwidth=2):
    points = np.array(points, dtype=float)
    mean_shift_points = points.copy()
    max_min_dist = 1
    iteration = 0
    m = points.shape[0]
    need_shift = [True] * m

    # Shift every point until the largest move in a pass is below MIN_DISTANCE
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        print("iter : " + str(iteration))
        for i in range(m):
            # Skip points that have already converged
            if not need_shift[i]:
                continue
            p_new_start = mean_shift_points[i].copy()
            p_new = shift_point(p_new_start, points, kernel_bandwidth)
            dist = euclidean_dist(p_new, p_new_start)

            if dist > max_min_dist:  # record the largest move among all points
                max_min_dist = dist
            if dist < MIN_DISTANCE:  # converged, no need to move further
                need_shift[i] = False

            mean_shift_points[i] = p_new
    # Assign the final groups
    group = group_points(mean_shift_points)

    return points, mean_shift_points, group


if __name__ == "__main__":
    # Load the data set
    path = "./data"
    data = load_data(path, 2)

    # Train with bandwidth h = 2
    points, shift_points, cluster = train_mean_shift(data, 2)

    for i in range(len(cluster)):
        print("%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i, 0], points[i, 1],
                                                shift_points[i, 0], shift_points[i, 1],
                                                cluster[i]))
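One caveat about `group_points` above: it keys each converged point on its coordinates formatted with `%5.2f`, so two points that converged to the same mode can still land in different groups if they straddle a rounding boundary (2.004 and 2.006 format as 2.00 and 2.01). A hedged alternative, closer in spirit to the unused `distance_to_group` helper, merges points whose converged positions lie within a tolerance. The function name `group_by_distance` and its `tol` parameter are hypothetical; `tol` should be chosen well below the bandwidth.

```python
import numpy as np


def group_by_distance(shifted_points, tol=0.1):
    # A converged point joins the first group whose representative lies within tol;
    # otherwise it starts a new group. tol is a hypothetical tuning parameter.
    centers = []
    labels = []
    for p in np.asarray(shifted_points, dtype=float):
        for idx, c in enumerate(centers):
            if np.linalg.norm(p - c) < tol:
                labels.append(idx)
                break
        else:
            centers.append(p)
            labels.append(len(centers) - 1)
    return labels, np.array(centers)


shifted = [[2.004, 2.0], [2.006, 2.0], [5.0, 5.0]]
labels, centers = group_by_distance(shifted, tol=0.1)
print(labels)  # [0, 0, 1]
```

Inside `train_mean_shift`, this could stand in for the call to `group_points(mean_shift_points)`, keeping only `labels`.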
3.3 Experimental Results

The data after Mean Shift clustering are shown below:


'''
Date: 20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt

# Each line of data_mean: x,y<TAB>shifted_x,shifted_y<TAB>label
# (the same format as the per-point lines printed by the program above)
cluster_x = {0: [], 1: [], 2: []}
cluster_y = {0: [], 1: [], 2: []}
center_x = []
center_y = []
center_dict = {}

with open("data_mean") as f:
    for line in f:
        lines = line.strip().split("\t")
        if len(lines) != 3:
            continue
        label = int(lines[2])
        key = label if label in (0, 1) else 2  # any other label is drawn as cluster_2
        data_1 = lines[0].strip().split(",")
        cluster_x[key].append(float(data_1[0]))
        cluster_y[key].append(float(data_1[1]))
        # Record the converged (mean) point once per label
        if label not in center_dict:
            center_dict[label] = 1
            data_2 = lines[1].strip().split(",")
            center_x.append(float(data_2[0]))
            center_y.append(float(data_2[1]))

plt.plot(cluster_x[0], cluster_y[0], 'b.', label="cluster_0")
plt.plot(cluster_x[1], cluster_y[1], 'g.', label="cluster_1")
plt.plot(cluster_x[2], cluster_y[2], 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')
#plt.legend(loc="best")
plt.show()
