当前位置:   article > 正文

【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝_depgraph: towards any structural pruning

depgraph: towards any structural pruning

这里提出了一种非深度图算法DepGraph,实现了架构通用的结构化剪枝,适用于CNNs, Transformers, RNNs, GNNsLLM语言模型等网络。
该算法能够自动地分析复杂的结构耦合,从而正确地移除参数实现网络加速。基于DepGraph算法,我们开发了PyTorch结构化剪枝框架 Torch-Pruning。不同于依赖Masking实现的“模拟剪枝”,该框架能够实际地移除参数和通道,降低模型推理成本。在DepGraph的帮助下,研究者和工程师无需再与复杂的网络结构斗智斗勇,可以轻松完成复杂模型的一键剪枝。
DepGraph算法 论文标题:DepGraph: Towards Any Structural Pruning
DepGraph算法 论文地址:https://arxiv.org/abs/2301.12900
Torch-Pruning工具 github仓库:https://github.com/VainF/Torch-Pruning

一、 下载Torch-Pruning工具

在这里插入图片描述

二、 准备DeepLabV3+代码

这里我使用的是B站UP主bubbliiiing复现的DeepLabV3+代码
github仓库地址:https://github.com/bubbliiiing/deeplabv3-plus-pytorch

在这里插入图片描述

三、 baseline模型训练

在剪枝之前,我们需要正常准备数据,训练出最佳的模型。在剪枝之前,模型的大小为12.9MB,测试效果如下
在这里插入图片描述
在这里插入图片描述

四、 开始剪枝

步骤一、将下载的Torch-Pruning工具库中的torch_pruning文件夹复制到DeepLabV3+代码根目录下
在这里插入图片描述

步骤二、运行下面代码,实现结构化剪枝。

# DeeplabV3 prune code
# 2024/5/1

import torch
import torch_pruning as tp

device = 'cuda'

# Step 0. 加载模型和权重
# 加载模型和预训练权重
model = torch.load('logs/before_prune.pth',map_location=device)
model.eval()
inputs = torch.randn(1, 3, 640, 640).to(device)

# 统计剪枝前参数量
macs, nparams = tp.utils.count_ops_and_params(model, inputs)
print("剪枝前: macs=%d, nparams=%d"%(macs, nparams))

# Step 1. 重要性评判器
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning

# Step 2. 初始化剪枝器
# Step 2.1. head不参与剪枝
# 我这样用的是语义分割模型,cls_conv是里面的结构命名,具体的参数名需要根据自己实际模型中的网络命名进行修改
ignored_layers = []
for name, m in model.named_modules():
    if 'cls_conv' in name:
        ignored_layers.append(m)
# Step 2.2. 初始化剪枝器
iterative_steps = 1 # progressive pruning
prune_rate = 0.5 # 剪枝率
pruner = tp.pruner.MagnitudePruner(model=model,
                                   example_inputs=inputs,
                                   importance=imp,
                                   iterative_steps=iterative_steps,
                                   pruning_ratio=prune_rate,
                                   ignored_layers=ignored_layers,)


# Step 4. 进行剪枝
pruner.step()
# 统计剪枝后参数量
macs, nparams_pruned = tp.utils.count_ops_and_params(model, inputs)
print("剪枝后: macs=%d, nparams=%d"%(macs, nparams_pruned))
params_ratio = nparams_pruned / nparams
print("参数量比: ratio = %f" %(params_ratio))


# Step 6. save
torch.save(model, 'after_pruned.pth') # without .state_dict

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

在这里插入图片描述

注意,我们需要使用torch.load()torch.save()将模型结构和权重完整的保存下来,不用使用只保留权重(state_dict)的方式。可以看到剪枝后,大小为3.54MB,体积变为了baseline模型的1/4,在不进行精度恢复训练之前,测试一下模型效果,发现完全无效,这是因为模型结构发生了破坏(剪枝),所以下一步还需要精度恢复训练。
在这里插入图片描述
在这里插入图片描述

五、 精度恢复训练

剪枝完后,我们需要使用torch.load的方式加载3.54MB的剪枝模型,然后按照正常的训练流程,对剪枝模型进行精度恢复训练。

model = torch.load(model_path,map_location=device)
  • 1

训练后,我们再一次测试3.54MB剪枝模型的效果,发现精度已经恢复,且几乎无损,模型大小却已经压缩为原来的1/4
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/745388
推荐阅读
相关标签
  

闽ICP备14008679号