当前位置:   article > 正文

API详解:sklearn.pipeline.Pipeline_sklearn pipeline api

sklearn pipeline api

sklearn中提供了Pipeline(管道操作)

可以将多个estimators组装成一个。对于固定流程的一个项目来说,在一个Pipline中可以定义一些列的操作例如(特征提取,标准化,分类)并将它定义成一个estimator,实现便捷的代码附用。

总的来说pipline的意义有:

便捷:只需要使用fit和predict两个methods,就可以基于定义好的Pipeline对数据进行一系列的操作。这样的做法,方便了对超参的选择

安全:Pipeline 能够保证相同的样本被用于数据处理和预测

在Pipeline中,中间步骤必须是变换操作(transform),至少含有一个transform的method,最后的estimator可以是任意形式
pipeline的目的是组装一些列的操作,在cross-validated过程中,找到最好的超参。

调用方式

Pipeline(memory=None,steps=list)

memory: 默认为None 可以是str或者是joblib.Memory interface.
用于缓存fitted 好的transformers 默认情况下没有缓存。 如果给定字符串,那么该字符串是缓存地址。 给定缓存时,在fit之前复制transformers,因此在transformer的实例中不能被直接查到。可以用name_steps 或者 steps检查pipe中的estimators. 当fit耗时的时候,将transformer放在缓存中更有利。

list: 一系列的tuple(name,transform)在list中定义一些列的操作

可查的属性

named_steps : keys是step names 返还的值是设定的参数值

举个例子说明下用法

import numpy as np
import pandas as pd 
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import matplotlib as mpl
import matplotlib.pyplot as plt

## 设置字符集,防止中文乱码
mpl.rcParams['font.sans-serif']=[u'simHei']
mpl.rcParams['axes.unicode_minus']=False

# 定义目标函数
def l_model(x):
    params = np.arange(1,x.shape[-1]+3)
    y = np.sum(params[:-2]*x)+np.random.randn(1)*0.1+5*params[-2]*x[0]*x[1]+5*params[-1]*x[1]*x[2]
    return y

# 定义数据集
x = pd.DataFrame(np.random.rand(500,6))
y = x.apply(lambda x_rows:pd.Series(l_model(x_rows)),axis=1)

# 划分训练集和测试集
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.3,random_state=2)

# 定义管道,在models中可以定义多个Pipeline
models = [
    Pipeline(memory=None,
            steps=[
            ('StandardScaler',StandardScaler()), #数据标准化
            ('Poly',PolynomialFeatures()), #多项式扩展
            ('LinearRegression',LinearRegression()), #线性回归
        ])
]
model = models[0]
print(model)

"""
Pipeline(memory=None,
     steps=[('StandardScaler', StandardScaler(copy=True, with_mean=True, with_std=True)), ('Poly', PolynomialFeatures(degree=2, include_bias=True, interaction_only=False)), ('LinearRegression', LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False))])
"""
# 定义要遍历的参数
t = np.arange(len(x_test))
N=4
scale_pool = [True,False]
degree_pool = np.arange(1,N,1)
regressor_pool = [True,False]
gsize = len(scale_pool)*len(degree_pool)*len(regressor_pool)

# 管道参数遍历训练模型
line_width=3
plt.figure(figsize=(12,15),facecolor='w')#创建一个绘图窗口,设置大小,设置颜色
ical = 1 
for i,s in enumerate(scale_pool):
    for j,d in enumerate(degree_pool):
        for k,r in enumerate(regressor_pool):
            plt.subplot(gsize,1,ical)
            plt.plot(t, y_test, 'r-', label=u'真实值')
            # 设置管道参数
            model.set_params(StandardScaler__with_mean=s) # 标准化的时候是否要中心化
            model.set_params(Poly__degree=d) # 多项式拓展的阶数
            model.set_params(LinearRegression__fit_intercept=r) # 回归的时候是否考虑截距 
            ical +=1
            # 训练
            model.fit(x_train,y_train)
            # 预测
            y_predict = model.predict(x_test)
            # 评估
            score = model.score(x_test,y_test)
            # 画图
            label = u'%d阶, 准确率=%.3f,中心化=%s,截距=%s' % (d,score,s,r)
            plt.plot(t, y_predict, 'b-', lw=line_width, alpha=0.75,label=label)
            plt.legend(loc = 'upper left')
            plt.grid(True)
plt.suptitle(u"Pipeline参数对比", fontsize=20)
plt.grid(b=True)
plt.show()
plt.savefig('pipeline.png')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80

效果对比图
使用下来的感受是,所有的过程,只需要一个fit和一个predict就可以实现。大大减少了代码量。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/889615
推荐阅读
相关标签
  

闽ICP备14008679号