This note walks through how pytorch lightning saves models, using the ModelCheckpoint callback interface as the entry point.
The callback periodically saves the model by monitoring a chosen metric; every metric recorded with log() or log_dict() in a LightningModule is a candidate for the monitored quantity (see the official ModelCheckpoint documentation for more details). After training finishes, use best_model_path to retrieve the path of the best checkpoint and best_model_score to retrieve its score.
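As a minimal sketch (LitModel, its single linear layer, and the dataloaders are hypothetical placeholders, not part of the original article), logging a value with self.log() in validation_step makes it usable as the monitor, and the best checkpoint can then be queried from the callback after fit():

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("val_loss", loss)  # "val_loss" becomes a valid value for monitor=

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=3)
# after trainer.fit(model, train_loader, val_loader):
#   checkpoint_callback.best_model_path   -> path of the best checkpoint
#   checkpoint_callback.best_model_score  -> monitored value at that checkpoint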
pytorch_lightning.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None)
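The arguments can be combined freely; the following configuration is only an illustrative example (the directory, filename template and numeric values are arbitrary):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/',             # where the checkpoints are written
    filename='{epoch}-{val_loss:.2f}',  # filename template built from logged metrics
    monitor='val_loss',                 # metric logged via self.log() to track
    mode='min',                         # 'min' because a lower val_loss is better
    save_top_k=3,                       # keep only the 3 best checkpoints
    save_last=True,                     # additionally keep a last.ckpt
    every_n_epochs=1,                   # evaluate the monitor once per epoch
)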
# custom path
# saves a file like: my/path/epoch=0-step=10.ckpt (the filename is built from epoch and step automatically)
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

# save any arbitrary metrics like `val_loss`, etc. in the filename
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath='my/path',
    filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid
# problems with Tensorboard or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path  # directly get the path where the best model was saved
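To actually reuse the best checkpoint, it can be loaded back into the model class; the sketch below assumes a LightningModule subclass named LitModel (a placeholder for your own model, not part of the original article):

best_path = checkpoint_callback.best_model_path
best_score = checkpoint_callback.best_model_score  # value of the monitored metric at that checkpoint
model = LitModel.load_from_checkpoint(best_path)
model.eval()  # ready for inference or further validation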
Using several checkpoint callbacks at the same time, each saving and restoring its own checkpoints, is also supported; see the official documentation for details.
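A minimal sketch of two ModelCheckpoint callbacks attached to one Trainer, assuming the LightningModule logs both 'val_loss' and 'val_acc' (the metric names and filename templates here are illustrative assumptions):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

loss_ckpt = ModelCheckpoint(monitor='val_loss', mode='min', filename='best-loss-{epoch}')
acc_ckpt = ModelCheckpoint(monitor='val_acc', mode='max', filename='best-acc-{epoch}')
trainer = Trainer(callbacks=[loss_ckpt, acc_ckpt])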
# a checkpoint can also be saved manually at any point via trainer.save_checkpoint()
model = Pytorch_Lightning_Model(args)
trainer = Trainer()
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
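The manually saved file can later be restored; the lines below are only a sketch (load_from_checkpoint works without extra arguments only if the module saved its hyperparameters via self.save_hyperparameters(), and the ckpt_path argument to fit() is available in recent pytorch_lightning versions):

# reload the weights into the same LightningModule class
model = Pytorch_Lightning_Model.load_from_checkpoint("example.ckpt")

# or resume training from that checkpoint
trainer.fit(model, ckpt_path="example.ckpt")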