赞
踩
主要需要 xgboost4j-spark-0.90.jar、xgboost4j-0.90.jar,以及调用代码 sparkxgb.zip。
GitHub上面有xgboost java 实现的包,链接:xgboost;
但为了省事,我直接用了知乎文章《xgboost的分布式版本(pyspark)使用测试》中提供的下载链接。
注意:xgboost 的版本号必须与 sparkxgb 内的内容相对应。
我使用 pyspark 运行,通过 pyspark --jars <jar 包路径>(多个 jar 包之间用逗号分隔)
把用到的这两个 jar 包引入。
#!/usr/bin/env python
# -*- coding:utf8 -*-
"""Train and evaluate a distributed XGBoost classifier with PySpark.

The script builds a synthetic multi-class data set with scikit-learn,
converts it to Spark DataFrames, fits ``sparkxgb.XGBoostClassifier`` and
prints the accuracy on the held-out split.

The xgboost4j / xgboost4j-spark jars must be available to Spark and their
version must match the ``sparkxgb`` Python wrapper.
"""
import os
import sys
import time
import pandas as pd
import numpy as np
from pyspark import SparkConf, SparkContext
import pyspark.sql.types as typ
import pyspark.ml.feature as ft
from pyspark.sql.functions import isnan, isnull, col
import pyspark
from pyspark.sql.session import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from sparkxgb import XGBoostClassifier
import sklearn.datasets as datasets

# Number of blob centers generated by get_data() and, by default, the number
# of classes the classifier is configured for.
NUM_CLASSES = 10
# Fraction of the generated samples used for training; the rest is held out.
TRAIN_FRACTION = 0.8
# Comma-separated jar list handed to Spark through "spark.jars".
XGBOOST_JARS = "/home/xgboost4j-0.90.jar,/home/xgboost4j-spark-0.90.jar"


def normalize(x):
    """Min-max scale *x* into [0, 1] using its global minimum and maximum.

    The minimum and maximum are taken over the whole array (no axis), so all
    features share one scale.

    Args:
        x: Array-like of numbers.

    Returns:
        A float ``numpy.ndarray`` with the same shape as *x*. A constant
        input maps to all zeros instead of producing NaN from 0 / 0.
    """
    values = np.asarray(x, dtype=float)
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


def get_data():
    """Generate the synthetic data set and wrap it in Spark DataFrames.

    Reads the module-level ``spark`` session created in the ``__main__``
    section, so it must be called after that session exists.

    Returns:
        A tuple ``(spark_train, spark_test, y_train, y_test)`` where
        ``spark_train`` has columns ``label`` and ``features``,
        ``spark_test`` has the single column ``features``, ``y_train`` is
        the training labels reshaped to a column (n, 1) and ``y_test`` is
        the 1-D array of held-out labels.
    """
    # Input data set.
    X, y = datasets.make_blobs(
        n_samples=100000,
        centers=NUM_CLASSES,
        n_features=10,
        random_state=0,
    )
    # Normalization.
    X_norm = normalize(X)

    split = int(len(X_norm) * TRAIN_FRACTION)
    X_train, X_test = X_norm[:split], X_norm[split:]
    y_train, y_test = y[:split], y[split:]

    # Spark DataFrames. Labels are paired with their feature rows directly
    # rather than being concatenated into a float matrix and cast back.
    train_rows = [
        (int(label), Vectors.dense(features))
        for label, features in zip(y_train, X_train)
    ]
    spark_train = spark.createDataFrame(train_rows, schema=["label", "features"])

    test_rows = [(Vectors.dense(features),) for features in X_test]
    spark_test = spark.createDataFrame(test_rows, schema=["features"])

    return spark_train, spark_test, y_train.reshape(-1, 1), y_test


def train_model(trainDF, num_class=NUM_CLASSES):
    """Fit a one-stage pipeline holding an XGBoost multi-class classifier.

    Args:
        trainDF: Spark DataFrame with ``features`` and ``label`` columns.
        num_class: Number of target classes passed to XGBoost.

    Returns:
        The fitted pipeline model; call ``transform`` on it to predict.
    """
    # NOTE(review): missing=0.0 tells XGBoost to treat exact zeros as missing
    # values, and min-max scaling yields 0.0 for the global minimum; confirm
    # this is intended for the installed xgboost4j-spark version.
    xgboost = XGBoostClassifier(
        featuresCol="features",
        labelCol="label",
        predictionCol="prediction",
        objective='multi:softprob',
        numClass=num_class,
        missing=0.0,
    )
    pipeline = Pipeline(stages=[xgboost])
    # The fitted model can be persisted with
    # model.write().overwrite().save(<path>) and restored with load(<path>).
    return pipeline.fit(trainDF)


def test():
    """Smoke-test the cluster by round-tripping a small list through an RDD.

    Reads the module-level ``sc`` SparkContext created in the ``__main__``
    section.
    """
    data = [1, 2, 3, 4, 5]
    distData = sc.parallelize(data)
    print("done", distData.collect())


def cal_acc(pred, true):
    """Return the fraction of predictions equal to the true labels.

    Args:
        pred: Sequence of predictions. Each item is either a scalar or a
            one-field row/tuple (as returned by
            ``df.select("prediction").collect()``), in which case its first
            field is used.
        true: Sequence of expected labels, same length as *pred*.

    Returns:
        Accuracy rounded to 4 decimal places; 0.0 when *true* is empty.

    Raises:
        ValueError: If *pred* and *true* differ in length.
    """
    predictions = list(pred)
    total = len(true)
    if len(predictions) != total:
        raise ValueError(
            "pred and true must have the same length: {} != {}".format(
                len(predictions), total
            )
        )
    if total == 0:
        return 0.0

    correct = 0
    for row, expected in zip(predictions, true):
        # Unwrap one-column rows so scalars are compared with scalars instead
        # of relying on tuple-vs-number comparison semantics.
        value = row[0] if isinstance(row, (tuple, list)) else row
        if value == expected:
            correct += 1
    # float() keeps true division under Python 2 as well.
    return round(correct / float(total), 4)


if __name__ == "__main__":
    conf = SparkConf().set("spark.jars", XGBOOST_JARS)
    # `spark` and `sc` stay module-level because get_data() and test() read
    # them as globals.
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    sc = spark.sparkContext

    trainDf, testDf, y_train, y_test = get_data()
    print('get df')
    model = train_model(trainDf)
    prediction = model.transform(testDf).select("prediction").collect()
    acc = cal_acc(prediction, y_test)
    print("acc:{}".format(acc))
运行结果:acc:0.9992
预测结果:
model.transform(testDf).show()
+--------------------+--------------------+--------------------+----------+
|            features|       rawPrediction|         probability|prediction|
+--------------------+--------------------+--------------------+----------+
|[0.36383649267021...|[0.33353492617607...|[0.06999947130680...|       9.0|
|[0.85080275306445...|[0.33345550298690...|[0.06996602565050...|       2.0|
|[0.54471116142668...|[1.99881935119628...|[0.37008801102638...|       0.0|
|[0.61089833342796...|[0.33345550298690...|[0.06995990127325...|       5.0|
|[0.25437385667790...|[0.33415806293487...|[0.07003305852413...|       6.0|
|[0.47371795998355...|[1.99881935119628...|[0.37008947134017...|       0.0|
|[0.75258857302126...|[0.33345550298690...|[0.07017561793327...|       2.0|
|[0.38430822786126...|[0.33345550298690...|[0.06999430805444...|       9.0|
|[0.84192691973241...|[0.33345550298690...|[0.06999272853136...|       7.0|
|[0.89822104638187...|[0.33345550298690...|[0.06999462842941...|       2.0|
|[0.87335367752325...|[0.33345550298690...|[0.06999401748180...|       2.0|
|[0.34598394310439...|[0.33365276455879...|[0.07000749558210...|       9.0|
|[0.37907532566580...|[0.33345550298690...|[0.06999314576387...|       8.0|
|[0.85996665363900...|[0.33345550298690...|[0.06998810172080...|       7.0|
|[0.52503470825319...|[1.99881935119628...|[0.37008947134017...|       0.0|
|[0.51847376135870...|[0.33345550298690...|[0.06998340785503...|       5.0|
|[0.51366954373353...|[1.98586511611938...|[0.36707320809364...|       0.0|
|[0.38344970186248...|[0.33345550298690...|[0.06998835504055...|       4.0|
|[0.31206934826790...|[0.33353492617607...|[0.06996974349021...|       6.0|
|[0.68235540326326...|[0.33345550298690...|[0.06998881697654...|       1.0|
+--------------------+--------------------+--------------------+----------+
参考:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。