当前位置:   article > 正文

Python训练的机器学习模型【保存】 和【加载】的方法?_python保存模型

python保存模型

一.为什么要保存训练好的模型

        由于传统训练机器学习模型,需要耗费大量的人力和资源。因此,将训练好的模型保存成为一件特别重要的事情。

        现有的机器学习模型保存方法有三种,分别为使用pickle(通用)、joblib(大型模型)、HDF5(存储深度学习模型的权重)

二.Python保存模型的三种方式

1.方式一:pickle模块【通用】
        pickle是Python标准库中的一个模块,它可以将Python对象序列化为二进制格式,以便于存储和传输。可以使用pickle将训练好的模型保存到磁盘上,以备将来使用。

1)保存模型

import pickle

# 训练模型
model = ... 

# 保存模型
with open('model.pkl', 'wb') as file:
    pickle.dump(model, file)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2)加载模型

# 加载模型
with open('model.pkl', 'rb') as file:
    model = pickle.load(file)
  • 1
  • 2
  • 3

2.方式二:joblib模块【大型模型】
        joblib是一个用于将Python对象序列化为磁盘文件的库,专门用于 大型数组。它可以高效地处理大型数据集和模型。对于大型机器学习模型,使用joblib可能比pickle更快。

1)保存模型

import joblib

# train the model
model = ...

# save the model
joblib.dump(model, 'my_model.joblib')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2)加载模型

import joblib

# load the saved model
model = joblib.load('my_model.joblib')

# predict using the loaded model
model.predict(X)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

3.方式三:使用HDF5——大型模型(保存权重)

HDF5是一种用于存储大型科学数据集的文件格式,常用于存储深度学习模型的权重。

1)保存模型(权重)

# train the model
model = ...

# save the model weights to HDF5 file
model.save_weights('my_model_weights.h5')
  • 1
  • 2
  • 3
  • 4
  • 5

2)使用模型

import tensorflow as tf

# define the model architecture
model = ...

# load the saved model weights
model.load_weights('my_model_weights.h5')

# predict using the loaded model weights
model.predict(X)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在这个例子中,我们首先定义了模型的架构,然后使用model.load_weights()函数加载之前保存的模型权重,并在新数据上进行预测。

4.方式四:使用ONNX——不同平台
ONNX是一种开放式的格式,可以用于表示机器学习模型。使用ONNX,您可以将模型从一个框架转换为另一个框架,或者在不同平台上使用模型。
1)保存模型

import onnx

# train the model
model = ...

# convert the model to ONNX format
onnx_model = onnx.convert(model)

# save the model
onnx.save_model(onnx_model, 'my_model.onnx')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2)加载(使用)模型

import onnxruntime

# load the saved model
onnx_session = onnxruntime.InferenceSession('my_model.onnx')

# predict using the loaded model
input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
result = onnx_session.run([output_name], {input_name: X})
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

参考链接:https://blog.csdn.net/qq_22841387/article/details/130194553

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/818313
推荐阅读
相关标签
  

闽ICP备14008679号