
python實(shí)現(xiàn)隨機(jī)森林random forest的原理及方法
想通過(guò)隨機(jī)森林來(lái)獲取數(shù)據(jù)的主要特征
1、理論
隨機(jī)森林是一個(gè)高度靈活的機(jī)器學(xué)習(xí)方法,擁有廣泛的應(yīng)用前景,從市場(chǎng)營(yíng)銷到醫(yī)療保健保險(xiǎn)。 既可以用來(lái)做市場(chǎng)營(yíng)銷模擬的建模,統(tǒng)計(jì)客戶來(lái)源,保留和流失。也可用來(lái)預(yù)測(cè)疾病的風(fēng)險(xiǎn)和病患者的易感性。
根據(jù)個(gè)體學(xué)習(xí)器的生成方式,目前的集成學(xué)習(xí)方法大致可分為兩大類,即個(gè)體學(xué)習(xí)器之間存在強(qiáng)依賴關(guān)系,必須串行生成的序列化方法,以及個(gè)體學(xué)習(xí)器間不存在強(qiáng)依賴關(guān)系,可同時(shí)生成的并行化方法;
前者的代表是Boosting,后者的代表是Bagging和“隨機(jī)森林”(Random
Forest)
隨機(jī)森林在以決策樹為基學(xué)習(xí)器構(gòu)建Bagging集成的基礎(chǔ)上,進(jìn)一步在決策樹的訓(xùn)練過(guò)程中引入了隨機(jī)屬性選擇(即引入隨機(jī)特征選擇)。
簡(jiǎn)單來(lái)說(shuō),隨機(jī)森林就是對(duì)決策樹的集成,但有兩點(diǎn)不同:
(2)特征選取的差異性:每個(gè)決策樹的n個(gè)分類特征是在所有特征中隨機(jī)選擇的(n是一個(gè)需要我們自己調(diào)整的參數(shù))
隨機(jī)森林,簡(jiǎn)單理解, 比如預(yù)測(cè)salary,就是構(gòu)建多個(gè)決策樹job,age,house,然后根據(jù)要預(yù)測(cè)的量的各個(gè)特征(teacher,39,suburb)分別在對(duì)應(yīng)決策樹的目標(biāo)值概率(salary<5000,salary>=5000),從而,確定預(yù)測(cè)量的發(fā)生概率(如,預(yù)測(cè)出P(salary<5000)=0.3).
隨機(jī)森林是一個(gè)可做能夠回歸和分類。 它具備處理大數(shù)據(jù)的特性,而且它有助于估計(jì)或變量是非常重要的基礎(chǔ)數(shù)據(jù)建模。
參數(shù)說(shuō)明:
最主要的兩個(gè)參數(shù)是n_estimators和max_features。
n_estimators:表示森林里樹的個(gè)數(shù)。理論上是越大越好。但是伴隨著就是計(jì)算時(shí)間的增長(zhǎng)。但是并不是取得越大就會(huì)越好,預(yù)測(cè)效果最好的將會(huì)出現(xiàn)在合理的樹個(gè)數(shù)。
max_features:隨機(jī)選擇特征集合的子集合,并用來(lái)分割節(jié)點(diǎn)。子集合的個(gè)數(shù)越少,方差就會(huì)減少的越快,但同時(shí)偏差就會(huì)增加的越快。根據(jù)較好的實(shí)踐經(jīng)驗(yàn)。如果是回歸問(wèn)題則:
max_features=n_features,如果是分類問(wèn)題則max_features=sqrt(n_features)。
如果想獲取較好的結(jié)果,必須將max_depth=None,同時(shí)min_sample_split=1。
同時(shí)還要記得進(jìn)行cross_validated(交叉驗(yàn)證),除此之外記得在random forest中,bootstrap=True。但在extra-trees中,bootstrap=False。
2、隨機(jī)森林python實(shí)現(xiàn)
2.1Demo1
實(shí)現(xiàn)隨機(jī)森林基本功能
#隨機(jī)森林
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
import numpy as np
from sklearn.datasets import load_iris
iris=load_iris()
#print iris#iris的4個(gè)屬性是:萼片寬度 萼片長(zhǎng)度 花瓣寬度 花瓣長(zhǎng)度 標(biāo)簽是花的種類:setosa versicolour virginica
print(iris['target'].shape)
rf=RandomForestRegressor()#這里使用了默認(rèn)的參數(shù)設(shè)置
rf.fit(iris.data[:150],iris.target[:150])#進(jìn)行模型的訓(xùn)練
#隨機(jī)挑選兩個(gè)預(yù)測(cè)不相同的樣本
instance=iris.data[[100,109]]
print(instance)
rf.predict(instance[[0]])
print('instance 0 prediction;',rf.predict(instance[[0]]))
print( 'instance 1 prediction;',rf.predict(instance[[1]]))
print(iris.target[100],iris.target[109])
運(yùn)行結(jié)果
(150,)
[[ 6.3 3.3 6. 2.5]
[ 7.2 3.6 6.1 2.5]]
instance 0 prediction; [ 2.]
instance 1 prediction; [ 2.]
2 2
2.2 Demo2
3種方法的比較
#random forest test
from sklearn.model_selection import cross_val_score
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
X, y = make_blobs(n_samples=10000, n_features=10, centers=100,random_state=0)
clf = DecisionTreeClassifier(max_depth=None, min_samples_split=2,random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())
clf = RandomForestClassifier(n_estimators=10, max_depth=None,min_samples_split=2, random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())
clf = ExtraTreesClassifier(n_estimators=10, max_depth=None,min_samples_split=2, random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())
運(yùn)行結(jié)果:
0.979408793821
0.999607843137
0.999898989899
2.3 Demo3-實(shí)現(xiàn)特征選擇
#隨機(jī)森林2
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
import numpy as np
from sklearn.datasets import load_iris
iris=load_iris()
from sklearn.model_selection import cross_val_score, ShuffleSplit
X = iris["data"]
Y = iris["target"]
names = iris["feature_names"]
rf = RandomForestRegressor()
scores = []
for i in range(X.shape[1]):
score = cross_val_score(rf, X[:, i:i+1], Y, scoring="r2",
cv=ShuffleSplit(len(X), 3, .3))
scores.append((round(np.mean(score), 3), names[i]))
print(sorted(scores, reverse=True))
運(yùn)行結(jié)果:
[(0.89300000000000002, 'petal width (cm)'), (0.82099999999999995, 'petal length
(cm)'), (0.13, 'sepal length (cm)'), (-0.79100000000000004, 'sepal width (cm)')]
2.4 demo4-隨機(jī)森林
本來(lái)想利用以下代碼來(lái)構(gòu)建隨機(jī)隨機(jī)森林決策樹,但是,遇到的問(wèn)題是,程序一直在運(yùn)行,無(wú)法響應(yīng),還需要調(diào)試。
#隨機(jī)森林4
#coding:utf-8
import csv
from random import seed
from random import randrange
from math import sqrt
def loadCSV(filename):#加載數(shù)據(jù),一行行的存入列表
dataSet = []
with open(filename, 'r') as file:
csvReader = csv.reader(file)
for line in csvReader:
dataSet.append(line)
return dataSet
# 除了標(biāo)簽列,其他列都轉(zhuǎn)換為float類型
def column_to_float(dataSet):
featLen = len(dataSet[0]) - 1
for data in dataSet:
for column in range(featLen):
data[column] = float(data[column].strip())
# 將數(shù)據(jù)集隨機(jī)分成N塊,方便交叉驗(yàn)證,其中一塊是測(cè)試集,其他四塊是訓(xùn)練集
def spiltDataSet(dataSet, n_folds):
fold_size = int(len(dataSet) / n_folds)
dataSet_copy = list(dataSet)
dataSet_spilt = []
for i in range(n_folds):
fold = []
while len(fold) < fold_size: # 這里不能用if,if只是在第一次判斷時(shí)起作用,while執(zhí)行循環(huán),直到條件不成立
index = randrange(len(dataSet_copy))
fold.append(dataSet_copy.pop(index)) # pop() 函數(shù)用于移除列表中的一個(gè)元素(默認(rèn)最后一個(gè)元素),并且返回該元素的值。
dataSet_spilt.append(fold)
return dataSet_spilt
# 構(gòu)造數(shù)據(jù)子集
def get_subsample(dataSet, ratio):
subdataSet = []
lenSubdata = round(len(dataSet) * ratio)#返回浮點(diǎn)數(shù)
while len(subdataSet) < lenSubdata:
index = randrange(len(dataSet) - 1)
subdataSet.append(dataSet[index])
# print len(subdataSet)
return subdataSet
# 分割數(shù)據(jù)集
def data_spilt(dataSet, index, value):
left = []
right = []
for row in dataSet:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right
# 計(jì)算分割代價(jià)
def spilt_loss(left, right, class_values):
loss = 0.0
for class_value in class_values:
left_size = len(left)
if left_size != 0: # 防止除數(shù)為零
prop = [row[-1] for row in left].count(class_value) / float(left_size)
loss += (prop * (1.0 - prop))
right_size = len(right)
if right_size != 0:
prop = [row[-1] for row in right].count(class_value) / float(right_size)
loss += (prop * (1.0 - prop))
return loss
# 選取任意的n個(gè)特征,在這n個(gè)特征中,選取分割時(shí)的最優(yōu)特征
def get_best_spilt(dataSet, n_features):
features = []
class_values = list(set(row[-1] for row in dataSet))
b_index, b_value, b_loss, b_left, b_right = 999, 999, 999, None, None
while len(features) < n_features:
index = randrange(len(dataSet[0]) - 1)
if index not in features:
features.append(index)
# print 'features:',features
for index in features:#找到列的最適合做節(jié)點(diǎn)的索引,(損失最?。?
for row in dataSet:
left, right = data_spilt(dataSet, index, row[index])#以它為節(jié)點(diǎn)的,左右分支
loss = spilt_loss(left, right, class_values)
if loss < b_loss:#尋找最小分割代價(jià)
b_index, b_value, b_loss, b_left, b_right = index, row[index], loss, left, right
# print b_loss
# print type(b_index)
return {'index': b_index, 'value': b_value, 'left': b_left, 'right': b_right}
# 決定輸出標(biāo)簽
def decide_label(data):
output = [row[-1] for row in data]
return max(set(output), key=output.count)
# 子分割,不斷地構(gòu)建葉節(jié)點(diǎn)的過(guò)程對(duì)對(duì)對(duì)
def sub_spilt(root, n_features, max_depth, min_size, depth):
left = root['left']
# print left
right = root['right']
del (root['left'])
del (root['right'])
# print depth
if not left or not right:
root['left'] = root['right'] = decide_label(left + right)
# print 'testing'
return
if depth > max_depth:
root['left'] = decide_label(left)
root['right'] = decide_label(right)
return
if len(left) < min_size:
root['left'] = decide_label(left)
else:
root['left'] = get_best_spilt(left, n_features)
# print 'testing_left'
sub_spilt(root['left'], n_features, max_depth, min_size, depth + 1)
if len(right) < min_size:
root['right'] = decide_label(right)
else:
root['right'] = get_best_spilt(right, n_features)
# print 'testing_right'
sub_spilt(root['right'], n_features, max_depth, min_size, depth + 1)
# 構(gòu)造決策樹
def build_tree(dataSet, n_features, max_depth, min_size):
root = get_best_spilt(dataSet, n_features)
sub_spilt(root, n_features, max_depth, min_size, 1)
return root
# 預(yù)測(cè)測(cè)試集結(jié)果
def predict(tree, row):
predictions = []
if row[tree['index']] < tree['value']:
if isinstance(tree['left'], dict):
return predict(tree['left'], row)
else:
return tree['left']
else:
if isinstance(tree['right'], dict):
return predict(tree['right'], row)
else:
return tree['right']
# predictions=set(predictions)
def bagging_predict(trees, row):
predictions = [predict(tree, row) for tree in trees]
return max(set(predictions), key=predictions.count)
# 創(chuàng)建隨機(jī)森林
def random_forest(train, test, ratio, n_feature, max_depth, min_size, n_trees):
trees = []
for i in range(n_trees):
train = get_subsample(train, ratio)#從切割的數(shù)據(jù)集中選取子集
tree = build_tree(train, n_features, max_depth, min_size)
# print 'tree %d: '%i,tree
trees.append(tree)
# predict_values = [predict(trees,row) for row in test]
predict_values = [bagging_predict(trees, row) for row in test]
return predict_values
# 計(jì)算準(zhǔn)確率
def accuracy(predict_values, actual):
correct = 0
for i in range(len(actual)):
if actual[i] == predict_values[i]:
correct += 1
return correct / float(len(actual))
if __name__ == '__main__':
seed(1)
dataSet = loadCSV(r'G:\0研究生\tianchiCompetition\訓(xùn)練小樣本2.csv')
column_to_float(dataSet)
n_folds = 5
max_depth = 15
min_size = 1
ratio = 1.0
# n_features=sqrt(len(dataSet)-1)
n_features = 15
n_trees = 10
folds = spiltDataSet(dataSet, n_folds)#先是切割數(shù)據(jù)集
scores = []
for fold in folds:
train_set = folds[
:] # 此處不能簡(jiǎn)單地用train_set=folds,這樣用屬于引用,那么當(dāng)train_set的值改變的時(shí)候,folds的值也會(huì)改變,所以要用復(fù)制的形式。(L[:])能夠復(fù)制序列,D.copy() 能夠復(fù)制字典,list能夠生成拷貝 list(L)
train_set.remove(fold)#選好訓(xùn)練集
# print len(folds)
train_set = sum(train_set, []) # 將多個(gè)fold列表組合成一個(gè)train_set列表
# print len(train_set)
test_set = []
for row in fold:
row_copy = list(row)
row_copy[-1] = None
test_set.append(row_copy)
# for row in test_set:
# print row[-1]
actual = [row[-1] for row in fold]
predict_values = random_forest(train_set, test_set, ratio, n_features, max_depth, min_size, n_trees)
accur = accuracy(predict_values, actual)
scores.append(accur)
print ('Trees is %d' % n_trees)
print ('scores:%s' % scores)
print ('mean score:%s' % (sum(scores) / float(len(scores))))
2.5 隨機(jī)森林分類sonic data
# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader
# Load a CSV file
def load_csv(filename):
file = open(filename, "r")
lines = reader(file)
dataset = list(lines)
return dataset
# Convert string column to float
def str_column_to_float(dataset, column):
for row in dataset:
row[column] = float(row[column].strip())
# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):
dataset_split = list()
dataset_copy = list(dataset)
fold_size = int(len(dataset) / n_folds)
for i in range(n_folds):
fold = list()
while len(fold) < fold_size:
index = randrange(len(dataset_copy))
fold.append(dataset_copy.pop(index))
dataset_split.append(fold)
return dataset_split
# Calculate accuracy percentage
def accuracy_metric(actual, predicted):
correct = 0
for i in range(len(actual)):
if actual[i] == predicted[i]:
correct += 1
return correct / float(len(actual)) * 100.0
# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
folds = cross_validation_split(dataset, n_folds)
scores = list()
for fold in folds:
train_set = list(folds)
train_set.remove(fold)
train_set = sum(train_set, [])
test_set = list()
for row in fold:
row_copy = list(row)
test_set.append(row_copy)
row_copy[-1] = None
predicted = algorithm(train_set, test_set, *args)
actual = [row[-1] for row in fold]
accuracy = accuracy_metric(actual, predicted)
scores.append(accuracy)
return scores
# Split a data set based on an attribute and an attribute value
def test_split(index, value, dataset):
left, right = list(), list()
for row in dataset:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right
# Calculate the Gini index for a split dataset
def gini_index(groups, class_values):
gini = 0.0
for class_value in class_values:
for group in groups:
size = len(group)
if size == 0:
continue
proportion = [row[-1] for row in group].count(class_value) / float(size)
gini += (proportion * (1.0 - proportion))
return gini
# Select the best split point for a dataset
def get_split(dataset):
class_values = list(set(row[-1] for row in dataset))
b_index, b_value, b_score, b_groups = 999, 999, 999, None
for index in range(len(dataset[0])-1):
for row in dataset:
groups = test_split(index, row[index], dataset)
gini = gini_index(groups, class_values)
if gini < b_score:
b_index, b_value, b_score, b_groups = index, row[index], gini, groups
print ({'index':b_index, 'value':b_value})
return {'index':b_index, 'value':b_value, 'groups':b_groups}
# Create a terminal node value
def to_terminal(group):
outcomes = [row[-1] for row in group]
return max(set(outcomes), key=outcomes.count)
# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
left, right = node['groups']
del(node['groups'])
# check for a no split
if not left or not right:
node['left'] = node['right'] = to_terminal(left + right)
return
# check for max depth
if depth >= max_depth:
node['left'], node['right'] = to_terminal(left), to_terminal(right)
return
# process left child
if len(left) <= min_size:
node['left'] = to_terminal(left)
else:
node['left'] = get_split(left)
split(node['left'], max_depth, min_size, depth+1)
# process right child
if len(right) <= min_size:
node['right'] = to_terminal(right)
else:
node['right'] = get_split(right)
split(node['right'], max_depth, min_size, depth+1)
# Build a decision tree
def build_tree(train, max_depth, min_size):
root = get_split(train)
split(root, max_depth, min_size, 1)
return root
# Make a prediction with a decision tree
def predict(node, row):
if row[node['index']] < node['value']:
if isinstance(node['left'], dict):
return predict(node['left'], row)
else:
return node['left']
else:
if isinstance(node['right'], dict):
return predict(node['right'], row)
else:
return node['right']
# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):
tree = build_tree(train, max_depth, min_size)
predictions = list()
for row in test:
prediction = predict(tree, row)
predictions.append(prediction)
return(predictions)
# Test CART on Bank Note dataset
seed(1)
# load and prepare data
filename = r'G:\0pythonstudy\決策樹\sonar.all-data.csv'
dataset = load_csv(filename)
# convert string attributes to integers
for i in range(len(dataset[0])-1):
str_column_to_float(dataset, i)
# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))
運(yùn)行結(jié)果:
{'index': 38, 'value': 0.0894}
{'index': 36, 'value': 0.8459}
{'index': 50, 'value': 0.0024}
{'index': 15, 'value': 0.0906}
{'index': 16, 'value': 0.9819}
{'index': 10, 'value': 0.0785}
{'index': 16, 'value': 0.0886}
{'index': 38, 'value': 0.0621}
{'index': 5, 'value': 0.0226}
{'index': 8, 'value': 0.0368}
{'index': 11, 'value': 0.0754}
{'index': 0, 'value': 0.0239}
{'index': 8, 'value': 0.0368}
{'index': 29, 'value': 0.1671}
{'index': 46, 'value': 0.0237}
{'index': 38, 'value': 0.0621}
{'index': 14, 'value': 0.0668}
{'index': 4, 'value': 0.0167}
{'index': 37, 'value': 0.0836}
{'index': 12, 'value': 0.0616}
{'index': 7, 'value': 0.0333}
{'index': 33, 'value': 0.8741}
{'index': 16, 'value': 0.0886}
{'index': 8, 'value': 0.0368}
{'index': 33, 'value': 0.0798}
{'index': 44, 'value': 0.0298}
Scores: [48.78048780487805, 70.73170731707317, 58.536585365853654, 51.2195121951
2195, 39.02439024390244]
Mean Accuracy: 53.659%
請(qǐng)按任意鍵繼續(xù). . .
知識(shí)點(diǎn):
1.load CSV file
from csv import reader
# Load a CSV file
def load_csv(filename):
file = open(filename, "r")
lines = reader(file)
dataset = list(lines)
return dataset
filename = r'G:\0pythonstudy\決策樹\sonar.all-data.csv'
dataset=load_csv(filename)
print(dataset)
2.把數(shù)據(jù)轉(zhuǎn)化成float格式
# Convert string column to float
def str_column_to_float(dataset, column):
for row in dataset:
row[column] = float(row[column].strip())
# print(row[column])
# convert string attributes to integers
for i in range(len(dataset[0])-1):
str_column_to_float(dataset, i)
3.把最后一列的分類字符串轉(zhuǎn)化成0、1整數(shù)
def str_column_to_int(dataset, column):
class_values = [row[column] for row in dataset]#生成一個(gè)class label的list
# print(class_values)
unique = set(class_values)#set 獲得list的不同元素
print(unique)
lookup = dict()#定義一個(gè)字典
# print(enumerate(unique))
for i, value in enumerate(unique):
lookup[value] = i
# print(lookup)
for row in dataset:
row[column] = lookup[row[column]]
print(lookup['M'])
4、把數(shù)據(jù)集分割成K份
# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):
dataset_split = list()#生成空列表
dataset_copy = list(dataset)
print(len(dataset_copy))
print(len(dataset))
#print(dataset_copy)
fold_size = int(len(dataset) / n_folds)
for i in range(n_folds):
fold = list()
while len(fold) < fold_size:
index = randrange(len(dataset_copy))
# print(index)
fold.append(dataset_copy.pop(index))#使用.pop()把里邊的元素都刪除(相當(dāng)于轉(zhuǎn)移),這k份元素各不相同。
dataset_split.append(fold)
return dataset_split
n_folds=5
folds = cross_validation_split(dataset, n_folds)#k份元素各不相同的訓(xùn)練集
5.計(jì)算正確率
# Calculate accuracy percentage
def accuracy_metric(actual, predicted):
correct = 0
for i in range(len(actual)):
if actual[i] == predicted[i]:
correct += 1
return correct / float(len(actual)) * 100.0#這個(gè)是二值分類正確性的表達(dá)式
6.二分類每列
# Split a data set based on an attribute and an attribute value
def test_split(index, value, dataset):
left, right = list(), list()#初始化兩個(gè)空列表
for row in dataset:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right #返回兩個(gè)列表,每個(gè)列表以value為界限對(duì)指定行(index)進(jìn)行二分類。
7.使用gini系數(shù)來(lái)獲得最佳分割點(diǎn)
# Calculate the Gini index for a split dataset
def gini_index(groups, class_values):
gini = 0.0
for class_value in class_values:
for group in groups:
size = len(group)
if size == 0:
continue
proportion = [row[-1] for row in group].count(class_value) / float(size)
gini += (proportion * (1.0 - proportion))
return gini
# Select the best split point for a dataset
def get_split(dataset):
class_values = list(set(row[-1] for row in dataset))
b_index, b_value, b_score, b_groups = 999, 999, 999, None
for index in range(len(dataset[0])-1):
for row in dataset:
groups = test_split(index, row[index], dataset)
gini = gini_index(groups, class_values)
if gini < b_score:
b_index, b_value, b_score, b_groups = index, row[index], gini, groups
# print(groups)
print ({'index':b_index, 'value':b_value,'score':gini})
return {'index':b_index, 'value':b_value, 'groups':b_groups}
這段代碼,在求gini指數(shù),直接應(yīng)用定義式,不難理解。獲得最佳分割點(diǎn)可能比較難看懂,這里用了兩層迭代,一層是對(duì)不同列的迭代,一層是對(duì)不同行的迭代。并且,每次迭代,都對(duì)gini系數(shù)進(jìn)行更新。
8、決策樹生成
# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
left, right = node['groups']
del(node['groups'])
# check for a no split
if not left or not right:
node['left'] = node['right'] = to_terminal(left + right)
return
# check for max depth
if depth >= max_depth:
node['left'], node['right'] = to_terminal(left), to_terminal(right)
return
# process left child
if len(left) <= min_size:
node['left'] = to_terminal(left)
else:
node['left'] = get_split(left)
split(node['left'], max_depth, min_size, depth+1)
# process right child
if len(right) <= min_size:
node['right'] = to_terminal(right)
else:
node['right'] = get_split(right)
split(node['right'], max_depth, min_size, depth+1)
這里使用了遞歸編程,不斷生成左叉樹和右叉樹。
9.構(gòu)建決策樹
# Build a decision tree
def build_tree(train, max_depth, min_size):
root = get_split(train)
split(root, max_depth, min_size, 1)
return root
tree=build_tree(train_set, max_depth, min_size)
print(tree)
10、預(yù)測(cè)test集
# Build a decision tree
def build_tree(train, max_depth, min_size):
root = get_split(train)#獲得最好的分割點(diǎn),下標(biāo)值,groups
split(root, max_depth, min_size, 1)
return root
# tree=build_tree(train_set, max_depth, min_size)
# print(tree)
# Make a prediction with a decision tree
def predict(node, row):
print(row[node['index']])
print(node['value'])
if row[node['index']] < node['value']:#用測(cè)試集來(lái)代入訓(xùn)練的最好分割點(diǎn),分割點(diǎn)有偏差時(shí),通過(guò)搜索左右叉樹來(lái)進(jìn)一步比較。
if isinstance(node['left'], dict):#如果是字典類型,執(zhí)行操作
return predict(node['left'], row)
else:
return node['left']
else:
if isinstance(node['right'], dict):
return predict(node['right'], row)
else:
return node['right']
tree = build_tree(train_set, max_depth, min_size)
predictions = list()
for row in test_set:
prediction = predict(tree, row)
predictions.append(prediction)
11.評(píng)價(jià)決策樹
# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
folds = cross_validation_split(dataset, n_folds)
scores = list()
for fold in folds:
train_set = list(folds)
train_set.remove(fold)
train_set = sum(train_set, [])
test_set = list()
for row in fold:
row_copy = list(row)
test_set.append(row_copy)
row_copy[-1] = None
predicted = algorithm(train_set, test_set, *args)
actual = [row[-1] for row in fold]
accuracy = accuracy_metric(actual, predicted)
scores.append(accuracy)
return scores
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助
數(shù)據(jù)分析咨詢請(qǐng)掃描二維碼
若不方便掃碼,搜微信號(hào):CDAshujufenxi
SQL Server 中 CONVERT 函數(shù)的日期轉(zhuǎn)換:從基礎(chǔ)用法到實(shí)戰(zhàn)優(yōu)化 在 SQL Server 的數(shù)據(jù)處理中,日期格式轉(zhuǎn)換是高頻需求 —— 無(wú)論 ...
2025-09-18MySQL 大表拆分與關(guān)聯(lián)查詢效率:打破 “拆分必慢” 的認(rèn)知誤區(qū) 在 MySQL 數(shù)據(jù)庫(kù)管理中,“大表” 始終是性能優(yōu)化繞不開的話題。 ...
2025-09-18CDA 數(shù)據(jù)分析師:表結(jié)構(gòu)數(shù)據(jù) “獲取 - 加工 - 使用” 全流程的賦能者 表結(jié)構(gòu)數(shù)據(jù)(如數(shù)據(jù)庫(kù)表、Excel 表、CSV 文件)是企業(yè)數(shù)字 ...
2025-09-18DSGE 模型中的 Et:理性預(yù)期算子的內(nèi)涵、作用與應(yīng)用解析 動(dòng)態(tài)隨機(jī)一般均衡(Dynamic Stochastic General Equilibrium, DSGE)模 ...
2025-09-17Python 提取 TIF 中地名的完整指南 一、先明確:TIF 中的地名有哪兩種存在形式? 在開始提取前,需先判斷 TIF 文件的類型 —— ...
2025-09-17CDA 數(shù)據(jù)分析師:解鎖表結(jié)構(gòu)數(shù)據(jù)特征價(jià)值的專業(yè)核心 表結(jié)構(gòu)數(shù)據(jù)(以 “行 - 列” 規(guī)范存儲(chǔ)的結(jié)構(gòu)化數(shù)據(jù),如數(shù)據(jù)庫(kù)表、Excel 表、 ...
2025-09-17Excel 導(dǎo)入數(shù)據(jù)含缺失值?詳解 dropna 函數(shù)的功能與實(shí)戰(zhàn)應(yīng)用 在用 Python(如 pandas 庫(kù))處理 Excel 數(shù)據(jù)時(shí),“缺失值” 是高頻 ...
2025-09-16深入解析卡方檢驗(yàn)與 t 檢驗(yàn):差異、適用場(chǎng)景與實(shí)踐應(yīng)用 在數(shù)據(jù)分析與統(tǒng)計(jì)學(xué)領(lǐng)域,假設(shè)檢驗(yàn)是驗(yàn)證研究假設(shè)、判斷數(shù)據(jù)差異是否 “ ...
2025-09-16CDA 數(shù)據(jù)分析師:掌控表格結(jié)構(gòu)數(shù)據(jù)全功能周期的專業(yè)操盤手 表格結(jié)構(gòu)數(shù)據(jù)(以 “行 - 列” 存儲(chǔ)的結(jié)構(gòu)化數(shù)據(jù),如 Excel 表、數(shù)據(jù) ...
2025-09-16MySQL 執(zhí)行計(jì)劃中 rows 數(shù)量的準(zhǔn)確性解析:原理、影響因素與優(yōu)化 在 MySQL SQL 調(diào)優(yōu)中,EXPLAIN執(zhí)行計(jì)劃是核心工具,而其中的row ...
2025-09-15解析 Python 中 Response 對(duì)象的 text 與 content:區(qū)別、場(chǎng)景與實(shí)踐指南 在 Python 進(jìn)行 HTTP 網(wǎng)絡(luò)請(qǐng)求開發(fā)時(shí)(如使用requests ...
2025-09-15CDA 數(shù)據(jù)分析師:激活表格結(jié)構(gòu)數(shù)據(jù)價(jià)值的核心操盤手 表格結(jié)構(gòu)數(shù)據(jù)(如 Excel 表格、數(shù)據(jù)庫(kù)表)是企業(yè)最基礎(chǔ)、最核心的數(shù)據(jù)形態(tài) ...
2025-09-15Python HTTP 請(qǐng)求工具對(duì)比:urllib.request 與 requests 的核心差異與選擇指南 在 Python 處理 HTTP 請(qǐng)求(如接口調(diào)用、數(shù)據(jù)爬取 ...
2025-09-12解決 pd.read_csv 讀取長(zhǎng)浮點(diǎn)數(shù)據(jù)的科學(xué)計(jì)數(shù)法問(wèn)題 為幫助 Python 數(shù)據(jù)從業(yè)者解決pd.read_csv讀取長(zhǎng)浮點(diǎn)數(shù)據(jù)時(shí)的科學(xué)計(jì)數(shù)法問(wèn)題 ...
2025-09-12CDA 數(shù)據(jù)分析師:業(yè)務(wù)數(shù)據(jù)分析步驟的落地者與價(jià)值優(yōu)化者 業(yè)務(wù)數(shù)據(jù)分析是企業(yè)解決日常運(yùn)營(yíng)問(wèn)題、提升執(zhí)行效率的核心手段,其價(jià)值 ...
2025-09-12用 SQL 驗(yàn)證業(yè)務(wù)邏輯:從規(guī)則拆解到數(shù)據(jù)把關(guān)的實(shí)戰(zhàn)指南 在業(yè)務(wù)系統(tǒng)落地過(guò)程中,“業(yè)務(wù)邏輯” 是連接 “需求設(shè)計(jì)” 與 “用戶體驗(yàn) ...
2025-09-11塔吉特百貨孕婦營(yíng)銷案例:數(shù)據(jù)驅(qū)動(dòng)下的精準(zhǔn)零售革命與啟示 在零售行業(yè) “流量紅利見頂” 的當(dāng)下,精準(zhǔn)營(yíng)銷成為企業(yè)突圍的核心方 ...
2025-09-11CDA 數(shù)據(jù)分析師與戰(zhàn)略 / 業(yè)務(wù)數(shù)據(jù)分析:概念辨析與協(xié)同價(jià)值 在數(shù)據(jù)驅(qū)動(dòng)決策的體系中,“戰(zhàn)略數(shù)據(jù)分析”“業(yè)務(wù)數(shù)據(jù)分析” 是企業(yè) ...
2025-09-11Excel 數(shù)據(jù)聚類分析:從操作實(shí)踐到業(yè)務(wù)價(jià)值挖掘 在數(shù)據(jù)分析場(chǎng)景中,聚類分析作為 “無(wú)監(jiān)督分組” 的核心工具,能從雜亂數(shù)據(jù)中挖 ...
2025-09-10統(tǒng)計(jì)模型的核心目的:從數(shù)據(jù)解讀到?jīng)Q策支撐的價(jià)值導(dǎo)向 統(tǒng)計(jì)模型作為數(shù)據(jù)分析的核心工具,并非簡(jiǎn)單的 “公式堆砌”,而是圍繞特定 ...
2025-09-10