赞
踩
在keras环境中,使用多GPU进行训练,但是如何使保存的模型为能在单GPU上运行的模型呢?4块GPU环境下训练的模型,放到其他的机器上,那么也必须使用4GPU的机器才行。
偷梁换柱!!
即在多GPU的环境下加载权重(或者模型),再保存单GPU模型。
前提条件,此时我们已经保存了最优的模型或者仅仅是模型参数:multi_model.h5
流程如下:

keras自带模块 multi_gpu_model,此方式为数据并行的方式,将将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据,最高支持在8片GPU上并行。
使用方式:
def get_model(input_shape): . . . return model model = get_model(input_shape) #此时为单GPU 搭建的model from keras.utils import multi_gpu_model # Replicates `model` on 4 GPUs. # This assumes that your machine has 4 available GPUs. model = multi_gpu_model(model, gpus=4) #将搭建的model复制到4个GPU中 # for train model.compile(loss='categorical_crossentropy', optimizer='adam') # fit data for train
因为定义的多核训练,所以网络的每一层都是按GPU来命名的,训练时采用多个GPU那么当导入参数的时候必须指定相同数量的GPU才行,如上代码的指定方式。但是,但我们将model切换到单GPU的环境中时,则会出现错误,此时我们必须将参数保存为单GPU的形式。
方法:
在原多GPU环境中导入模型,保存为单GPU版本,修改训练代码(fit),改为加载已经训练的权重。
此时训练已经结束。
def get_model(input_shape): . . . return model model = get_model(input_shape) #此时为单GPU 搭建的model # metric # loss from keras.utils import multi_gpu_model paralleled_model = multi_gpu_model(model,gpus=4) # 此时paralleled_model为4个GPU的模型,已经进行复制,但是seg_model仍然为单GPU model。 #seg_model.compile(optimizer=Adam) #训练结束,注释掉 paralleled_model.load_weights("multi_model.h5") # 加载之前训练保存的在多GPU上训练的模型参数 model.save('single_gpu_model.h5') # 保存单GPU的模型seg_model此时,保存的就是单模型参数!!
加载单GPU模型:
model.load_weights("single_gpu_model.h5")
def get_model(input_shape): . . . return model model = get_model(input_shape) #此时为单GPU 搭建的model from keras.utils import multi_gpu_model # Replicates `model` on 4 GPUs. # This assumes that your machine has 4 available GPUs. paralleled_model = multi_gpu_model(model, gpus=4) #将搭建的model复制到4个GPU中 # for train paralleled_model.compile(loss='categorical_crossentropy', optimizer='adam') model.save_weights("single_gpu_model.h5") # fit data for train
使用多个GPU训练模型,使用multi_gpu_model和ModelCheckpoint来保存最佳模型,则在检查点保存的模型上调用load_model时会出现此错误:[4]
"ValueError: axes don't match array"
具体问题记录如下:

目前只能通过以上两种方式解决。
但是,之前训练Unet时并未出现这个问题。。。。。。。
-----2019-01-24
参考:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。