当前位置:   article > 正文

Sklearn的决策树算法实现鸢尾花分类_基于sklearn决策树算法对鸢尾花数据进行分类

基于sklearn决策树算法对鸢尾花数据进行分类

Iris Data Set

Iris Data Set(鸢尾属植物数据集)是历史比较悠久的数据集,它首次出现在著名的英国统计学家和生物学家Ronald Fisher 1936年的论文《The use of multiple measurements in taxonomic problems》中,被用来介绍线性判别式分析。在这个数据集中,包括了三类不同的鸢尾属植物:Iris Setosa,Iris Versicolour,Iris Virginica。每类收集了50个样本,因此这个数据集一共包含了150个样本。
该数据集测量了所有150个样本的4个特征,分别是:sepal length(花萼长度)、sepal width(花萼宽度)、petal length(花瓣长度)、petal width(花瓣宽度),以上四个特征的单位都是厘米。

Sklearn的决策树算法

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier,plot_tree

# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02

plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi'] = 300 #分辨率

# Load Data
iris = load_iris() # 150,4
for pairidx,pair in enumerate([[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]]):
    # We only take the two corresponding features
    X = iris.data[:,pair]
    y = iris.target
    
    # Train
    clf = DecisionTreeClassifier().fit(X,y)
    
    # Plot the decision boundary
    plt.subplot(2,3,pairidx + 1)
    x_min,x_max = X[:,0].min() -1,X[:,0].max() +1
    y_min,y_max = X[:,1].min() -1,X[:,1].max() +1
    xx, yy = np.meshgrid(np.arange(x_min,x_max,plot_step),    np.arange(y_min,y_max,plot_step)) # 生成网格点坐标矩阵
    plt.tight_layout(h_pad=0.5,w_pad=0.5,pad=2.5)
    
    #np.r_是按列连接两个矩阵,就是把两矩阵上下相加,要求列数相等。
    #np.c_是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等。
    Z = clf.predict(np.c_[xx.ravel(),yy.ravel()]) # 将多维数组转换为一维数组
    #print(xx.shape,yy.shape)
    Z = Z.reshape(xx.shape)
    cs = plt.contour(xx,yy,Z,cmap=plt.cm.RdYlBu)
    
    plt.xlabel(iris.feature_names[pair[0]])
    plt.ylabel(iris.feature_names[pair[1]])
    
    # Plot the training points
    for i, color in zip(range(n_classes),plot_colors):
        idx = np.where(y == i)
        plt.scatter(X[idx, 0],X[idx, 1],c=color, label=iris.target_names[i],
                   cmap=plt.cm.RdYlBu,edgecolor="black",s=15)   
plt.suptitle("Decision surface of a decesion tree using paired features")
plt.legend(loc="lower right",borderpad=0)
plt.axis("tight")

plt.figure()
clf = DecisionTreeClassifier().fit(iris.data,iris.target)
plot_tree(clf,filled=True)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52

两两特征之间的决策树边界在这里插入图片描述

决策树可视化
在这里插入图片描述

参考:https://scikit-learn.org/dev/modules/tree.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/701816
推荐阅读
相关标签
  

闽ICP备14008679号