当前位置:   article > 正文

mxnet如何将.json、.params模型转换为gluon模型_mxnet模型文件json

mxnet模型文件json

Gluon是MXNet的动态图接口;Gluon学习了Keras,Chainer,和Pytorch的优点,并加以改进。接口更简单,且支持动态图(Imperative)编程。相比TF,Caffe2等静态图(Symbolic)框架更加灵活易用。同时Gluon还继承了MXNet速度快,省显存,并行效率高的优点,并支持静、动态图混用,比Pytorch更快。——转自解浚源知乎

题目中提及的.json、.params模型如下所示:

 mxnet版本高于1.2.1可以使用如下方法:

  1. net = gluon.nn.SymbolBlock.imports('resnet18-symbol.json',
  2. ['data'],
  3. param_file='resnet18-0000.params',
  4. ctx=mx.gpu())

以前的版本可以使用:

  1. sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18', 0)
  2. net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
  3. # Set the params
  4. net_params = net.collect_params()
  5. for param in arg_params:
  6. if param in net_params:
  7. net_params[param]._load_init(arg_params[param], ctx=ctx)
  8. for param in aux_params:
  9. if param in net_params:
  10. net_params[param]._load_init(aux_params[param], ctx=ctx)

net_params是ParameterDict类型,也就是value为Parameter类型的字典,其可以通过data()函数获得其具体参数,参数类型为NDArray,如:

arraya = net_params['stage4_unit3_bn2_beta'].data()

arg_params,aux_params均是一个字典类型,他们的结构均为"参数名称":NDarray,如:

arrayb = arg_params['stage4_unit3_bn2_beta']

 需要说明的是:

inputs=mx.sym.var('data')

是使用静态图的方法生成一个输入节点名为'data' ,arg_params是主要参数如weights,aux_params是辅助参数主要是bias或者是batchnorm中的一些参数。

疑问:以上方法是对模型中的参数一个个的load,虽然已经装载进去了但是net_params内部的参数的shape扔然是None这是不解的地方,如下所示;

 疑问5.29日解决,原因是:虽然模型函数已经加载了参数,但是mxnet模型推断机制是在模型进行一次前向计算(forward)后才完成,如下图所示:

SymbolBlock是继承于block有好多的Sequence的方法,其并不能使用,如net[0]因为其内部并没有__getitems__()函数所以这种访问模型内部参数字典的方法并不适用

第一种方法的import函数内容如下:

  1. def imports(symbol_file, input_names, param_file=None, ctx=None):
  2. """Import model previously saved by `HybridBlock.export` or
  3. `Module.save_checkpoint` as a SymbolBlock for use in Gluon.
  4. Parameters
  5. ----------
  6. symbol_file : str
  7. Path to symbol file.
  8. input_names : list of str
  9. List of input variable names
  10. param_file : str, optional
  11. Path to parameter file.
  12. ctx : Context, default None
  13. The context to initialize SymbolBlock on.
  14. Returns
  15. -------
  16. SymbolBlock
  17. SymbolBlock loaded from symbol and parameter files.
  18. Examples
  19. --------
  20. >>> net1 = gluon.model_zoo.vision.resnet18_v1(
  21. ... prefix='resnet', pretrained=True)
  22. >>> net1.hybridize()
  23. >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
  24. >>> out1 = net1(x)
  25. >>> net1.export('net1', epoch=1)
  26. >>>
  27. >>> net2 = gluon.SymbolBlock.imports(
  28. ... 'net1-symbol.json', ['data'], 'net1-0001.params')
  29. >>> out2 = net2(x)
  30. """
  31. sym = symbol.load(symbol_file)
  32. if isinstance(input_names, str):
  33. input_names = [input_names]
  34. inputs = [symbol.var(i) for i in input_names]
  35. ret = SymbolBlock(sym, inputs)
  36. if param_file is not None:
  37. ret.collect_params().load(param_file, ctx=ctx)
  38. return ret

特殊情况下可以用到

  1. def load_model_finetune(model_path, ctx, output_layer=None, expect_prefix=None, is_train=True):
  2. model_name, epoch = model_path.split(SLASH)[-1].split('-')
  3. model_path = os.path.join(SLASH, *model_path.split(SLASH)[0:-1], model_name)
  4. syms, arg_params, aux_params = mx.model.load_checkpoint(model_path, int(epoch))
  5. if output_layer is not None:
  6. all_layers = syms.get_internals()
  7. syms = all_layers[output_layer]
  8. net = mx.gluon.nn.SymbolBlock(outputs=syms, inputs=mx.sym.var('data'))
  9. net_params = net.collect_params()
  10. print(net_params.keys())
  11. for param in arg_params:
  12. if param in net_params:
  13. if expect_prefix is not None:
  14. if expect_prefix in param:
  15. continue
  16. net_params[param]._load_init(arg_params[param], ctx=ctx)
  17. if not is_train:
  18. net_params[param].setattr("grad_req", "null")
  19. for param in aux_params:
  20. if param in net_params:
  21. if expect_prefix is not None:
  22. if expect_prefix in param:
  23. continue
  24. net_params[param]._load_init(aux_params[param], ctx=ctx)
  25. if not is_train:
  26. net_params[param].setattr("grad_req", "null")
  27. return net

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/346271
推荐阅读