当前位置:   article > 正文

TensorFlow应用.pb文件保存和加载模型方法及相关注意事项_加载pb模型

加载pb模型

其他参考链接:https://blog.csdn.net/guyuealian/article/details/82218092

 

一、.ckpt转.pb用于模型上线

.ckpt转.pb主要应用于将训练模型发布上线,.pb模型的跨平台和跨框架性能更好。这里由于在保存.pb模型前需要将模型变量freezing。在应用tensorflow训练模型时,输入数据的batch_size>1,直接保存.pb模型时会在inference阶段出现问题,所以需要从.ckpt转为.pb。在加载.ckpt时可以重新定义输入数据的batch_size=1,以解决该问题。应用步骤主要分为:

1.加载.ckpt并且.ckpt转.pb:

(1)定义图模型,在inference阶段加载.ckpt文件:

  1. sample_graph = tf.Graph()
  2. with sample_graph.as_default():
  3. input_data = tf.placeholder(tf.float32, shape=input_shape, name='input_data_gt') # input_shape的batch_size维度为1
  4. output = sample_net(input_data)
  5. net_saver = tf.train.Saver()
  6. sess = tf.Session(graph = sample_graph)
  7. net_saver.restore(sess, model_path)

首先定义输入数据变量placehold数据类型,这里定义输入数据变量时将batch_size=1,再加载图模型,最后restore .ckpt模型。

(2).ckpt转.pb:

  1. tf.train.write_graph(sess.graph_def, pb_dir, pb_name)
  2. freeze_graph.freeze_graph(pb_path, '', False, model_path, nodes_to_be_saved, save/restore_all, 'save/Const:0', pb_path, False, '')

首先,tf.train.write_graph将图结构保存.pb文件中,再调用tensorflow中包装好的接口(freeze_graph)保存模型的输出变量。其中,模型的输入变量会根据图结构自动回溯保存到.pb文件中。 注:这种方式主要应用.ckpt转.pb,由于.pb已经将变量freeze化,这里需要将input的batch_size定义为1。 2.应用.pb加载模型及参数做inference:

  1. with tf.Graph().as_default():
  2. output_graph_def = tf.GraphDef()
  3. with open(output_graph_path, "rb") as f:
  4. output_graph_def.ParseFromString(f.read())
  5. _ = tf.import_graph_def(output_graph_def, name="")
  6. with tf.Session() as sess:
  7. gnet_output = sess.graph.get_tensor_by_name(node_to_be_loaded)
  8. out_put = sess.run(gnet_output, feed_dict={input_img: input_img})

通过get_tensor_by_name获取模型的输入变量和输出变量,输入变量在feed_dict中加载数据,输出变量用于接受模型结果。

二、.pb用于finetune

由于tensorflow没有提供类似于加载.ckpt文件的restore接口,在做.pb文件用于模型finetune时,需要将模型中的trainable variables全部保存下来,并且在加载.pb文件时需要根据变量名称将变量值一一赋值到模型中。 应用步骤分为保存trainable variables到.pb文件和从.pb文件加载trainable variables:

1.保存trainable variables到.pb文件:

  1. var_list = tf.trainable_variables()
  2. constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, [var_list[i].name[:-2] for i in range(len(var_list))])
  3. with tf.gfile.FastGFile(pb_path, mode='wb') as f:
  4. f.write(constant_graph.SerializeToString())

finetune需要加载可训练的参数即可,因此这里只要保存trainable variables即可。

2.从.pb文件加载trainable variables:

  1. pb_para_dic = {} # 用于存储pb文件变量信息的字典结构 key-变量名,value-变量值
  2. with tf.Graph().as_default():
  3. output_graph_def = tf.GraphDef()
  4. with open(output_graph_path, "rb") as f:
  5. output_graph_def.ParseFromString(f.read())
  6. _ = tf.import_graph_def(output_graph_def, name="")
  7. with tf.Session() as sess:
  8. for item_var in var_list:
  9. pb_para_dic[item_var.name] = sess.run(sess.graph.get_tensor_by_name(item_var.name))
  10. with tf.Session(graph=graph, config=tfconfig) as sess:
  11. var_list = tf.trainable_variables()
  12. for item_var in var_list:
  13. sess.run(tf.assign(item_var, pb_para_dic[item_var.name])) # 将字典中的变量信息赋值图结构中

加载模型时需要重新定义一个新的图结构用于加载.pb文件中的权重,并且将权重keys和values放到一个字典中,然后在默认图结构中根据keys给session中的变量赋值。 注:在加载.pb模型时需要重新定义一个临时的graph域空间和临时的session,避免和图模型的定义空间冲突。

三、合并训练和测试模型到.pb文件

在合并训练和测试模型时,只需要将模型的trainable variables保存到.pb文件中,然后根据训练/测试代码建立的模型一一加载.pb文件中的trainable variables恢复模型。其中,应用方法与finetune的步骤基本一致。

应用步骤分为保存trainable variables到.pb文件,以及在训练和测试阶段,从.pb文件中一一加载trainable variables:

1.保存trainable variables到.pb文件:

  1. var_list = tf.trainable_variables()
  2. constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, [var_list[i].name[:-2] for i in range(len(var_list))])
  3. with tf.gfile.FastGFile(pb_path, mode='wb') as f:
  4. f.write(constant_graph.SerializeToString())

finetune需要加载可训练的参数即可,因此这里只要保存trainable variables即可。

2.训练/测试阶段从.pb文件加载trainable variables:

  1. pb_para_dic = {}
  2. with tf.Graph().as_default():
  3. output_graph_def = tf.GraphDef()
  4. with open(output_graph_path, "rb") as f:
  5. output_graph_def.ParseFromString(f.read())
  6. _ = tf.import_graph_def(output_graph_def, name="")
  7. with tf.Session() as sess:
  8. for item_var in var_list:
  9. pb_para_dic[item_var.name] = sess.run(sess.graph.get_tensor_by_name(item_var.name))
  10. with tf.Session(graph=graph, config=tfconfig) as sess:
  11. var_list = tf.trainable_variables()
  12. for item_var in var_list:
  13. sess.run(tf.assign(item_var, pb_para_dic[item_var.name]))

加载模型时需要重新定义一个新的图结构用于加载.pb文件中的权重,并且将权重keys和values放到一个字典中,然后在默认图结构中根据keys给session中的变量赋值。

注:在加载.pb模型时需要重新定义一个临时的graph域空间和临时的session,避免和图模型的定义空间冲突。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/958797
推荐阅读
相关标签
  

闽ICP备14008679号