赞
踩
Iris Data Set(鸢尾属植物数据集)是历史比较悠久的数据集,它首次出现在著名的英国统计学家和生物学家Ronald Fisher 1936年的论文《The use of multiple measurements in taxonomic problems》中,被用来介绍线性判别式分析。在这个数据集中,包括了三类不同的鸢尾属植物:Iris Setosa,Iris Versicolour,Iris Virginica。每类收集了50个样本,因此这个数据集一共包含了150个样本。
该数据集测量了所有150个样本的4个特征,分别是:sepal length(花萼长度)、sepal width(花萼宽度)、petal length(花瓣长度)、petal width(花瓣宽度),以上四个特征的单位都是厘米。
import numpy as np import matplotlib.pyplot as plt from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier,plot_tree # Parameters n_classes = 3 plot_colors = "ryb" plot_step = 0.02 plt.rcParams['savefig.dpi'] = 300 #图片像素 plt.rcParams['figure.dpi'] = 300 #分辨率 # Load Data iris = load_iris() # 150,4 for pairidx,pair in enumerate([[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]]): # We only take the two corresponding features X = iris.data[:,pair] y = iris.target # Train clf = DecisionTreeClassifier().fit(X,y) # Plot the decision boundary plt.subplot(2,3,pairidx + 1) x_min,x_max = X[:,0].min() -1,X[:,0].max() +1 y_min,y_max = X[:,1].min() -1,X[:,1].max() +1 xx, yy = np.meshgrid(np.arange(x_min,x_max,plot_step), np.arange(y_min,y_max,plot_step)) # 生成网格点坐标矩阵 plt.tight_layout(h_pad=0.5,w_pad=0.5,pad=2.5) #np.r_是按列连接两个矩阵,就是把两矩阵上下相加,要求列数相等。 #np.c_是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等。 Z = clf.predict(np.c_[xx.ravel(),yy.ravel()]) # 将多维数组转换为一维数组 #print(xx.shape,yy.shape) Z = Z.reshape(xx.shape) cs = plt.contour(xx,yy,Z,cmap=plt.cm.RdYlBu) plt.xlabel(iris.feature_names[pair[0]]) plt.ylabel(iris.feature_names[pair[1]]) # Plot the training points for i, color in zip(range(n_classes),plot_colors): idx = np.where(y == i) plt.scatter(X[idx, 0],X[idx, 1],c=color, label=iris.target_names[i], cmap=plt.cm.RdYlBu,edgecolor="black",s=15) plt.suptitle("Decision surface of a decesion tree using paired features") plt.legend(loc="lower right",borderpad=0) plt.axis("tight") plt.figure() clf = DecisionTreeClassifier().fit(iris.data,iris.target) plot_tree(clf,filled=True) plt.show()
两两特征之间的决策树边界
决策树可视化
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。