当前位置:   article > 正文

keras保存训练好模型的方法_keras保存训练好的模型

keras保存训练好的模型

前言:本博客讲的比较基础,只适合刚入坑的小伙伴们阅读。

  一.加载好数据

  读者可以自己加载好任意数据集,笔者这里使用的是mnist数据集。

  1. '''
  2. 制作人:追天一方
  3. 功能:keras模型的保存方法示例
  4. 笔者水平有限错误之处请包容
  5. '''
  6. import os
  7. import tensorflow as tf
  8. from tensorflow import keras
  9. #示例数据集是mnist,只使用前一千个示例
  10. (train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.mnist.load_data()
  11. #选取前一千个并且预处理
  12. train_labels=train_labels[:1000]
  13. test_labels=test_labels[:1000]
  14. train_images=train_images[:1000].reshape(-1,28*28)/255.0
  15. test_images=test_images[:1000].reshape(-1,28*28)/255.0

二.创建一个模型。

  笔者创建了一个简单的模型,读者也可以根据自己的需求创造模型。

  1. #定义一个模型
  2. def create_model():
  3. model=tf.keras.models.Sequential([
  4. keras.layers.Dense(512,activation='relu',input_shape=(784,)),
  5. keras.layers.Dropout(0.2),
  6. keras.layers.Dense(10)
  7. ])
  8. model.compile(optimizer='adam',
  9. loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
  10. metrics=[tf.metrics.SparseCategoricalAccuracy()])
  11. return model
  12. #实列化模型
  13. model=create_model()
  14. #查看模型结构和参数
  15. print(model.summary())

三.保存模型

  keras可以通过tf.keras.callbacks.ModelCheckpoint创建回调函数,然后将回调传入model.fit()函数当中,就可以实现间隔一定的训练批次保存模型或者模型权重,将tf.keras.callbacks.ModelCheckpoint函数中的save_weights_only参数设置为True表示只保存权重,反之则会保存整个模型。笔者代码示例是保存模型权重,如下:
  1. checkpoint_path2=r'.\model_test\cp-{epoch:04d}.ckpt'
  2. checkpoint_dir2=os.path.dirname(checkpoint_path2)
  3. batch_size=32
  4. #创建五个批次保存一次模型参数的回调
  5. cp_callback2=tf.keras.callbacks.ModelCheckpoint(
  6. filepath=checkpoint_path2,
  7. verbose=1,
  8. save_weights_only=True,
  9. save_freq=5*batch_size)
  10. #实例化模型
  11. model=create_model()
  12. #使用checkpoint_path保存模型权重
  13. model.save_weights(checkpoint_path2.format(epoch=0))
  14. model.fit(train_images,train_labels,epochs=20,callbacks=[cp_callback2],
  15. validation_data=(test_images,test_labels),verbose=0)
  16. # latest=tf.train.latest_checkpoint(checkpoint_dir2)
  17. #加载模型评估
  18. model=create_model()
  19. #加载权重
  20. model.load_weights(r'.\model_test\cp-0020.ckpt')
  21. #评估模型
  22. loss,acc=model.evaluate(test_images,test_labels,verbose=2)
  23. print("准确率:{:5.2f}%".format(100*acc))

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/53755?site
推荐阅读
相关标签
  

闽ICP备14008679号