This is the source code for our paper: Channel Pruning Guided by Spatial and Channel Attention for DNNs in Intelligent Edge Computing. A brief introduction of this work is as follows:
Deep Neural Networks (DNNs) have recently achieved remarkable success in many computer vision tasks, but their huge number of parameters and high computation overhead hinder their deployment on resource-constrained edge devices. Channel pruning is an effective approach for compressing DNN models. A critical challenge is to determine which channels to remove so that model accuracy is not negatively affected. In this paper, we first propose Spatial and Channel Attention (SCA), a new attention module combining spatial and channel attention, which respectively focus on "where" and "what" the most informative parts are. Guided by the scale values that SCA generates for measuring channel importance, we further propose a new channel pruning approach called Channel Pruning guided by Spatial and Channel Attention (CPSCA). Experimental results indicate that SCA achieves the best inference accuracy compared to other state-of-the-art attention modules, while incurring negligible extra resource consumption. Our evaluation on two benchmark datasets shows that, with the guidance of SCA, our CPSCA approach achieves higher inference accuracy than other state-of-the-art pruning methods under the same pruning ratios.
This paper has been published in Applied Soft Computing (ASOC). Both a preprint version and the formal version are available for download.
We only provide our SCA and CPSCA implementations here. Implementations of the other attention modules mentioned in our paper can be found in PytorchInsight. The released code does not include the scaler_for_prune.txt file required by prune.py. If you want to know how to generate it, please contact the first author at 18800191663@163.com.
- Python 3.5+
- PyTorch (1.x)
- torchvision
- NumPy
```
CPSCA/
├── source_code/
│   ├── SCA.py              # SCA attention module (SpatialGroupEnhance + ChannelGate)
│   ├── resnet.py           # Pruned ResNet (used in prune.py after channel selection)
│   ├── resnet_SCA.py       # ResNet-56 with SCA modules embedded in each BasicBlock
│   ├── train.py            # Train the SCA-augmented ResNet and record scale values
│   ├── prune.py            # Channel pruning guided by SCA scale values
│   ├── compute_flops.py    # FLOPs / parameter counter for the pruned model
│   └── models/             # Backbone definitions (VGG / ResNet variants for pruning)
│       ├── __init__.py
│       ├── resnet.py
│       └── vgg.py
└── README.md
```
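The Spatial and Channel Attention (SCA) module combines a spatial group-enhance branch with a channel gate. Its two main components, `SpatialGroupEnhance` and `ChannelGate`, are listed below together with the pooling helper and the `SCA` wrapper: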
| Class | Description |
|---|---|
| `ChannelPool` | Concatenates `avg_pool` and `max_pool` outputs along the channel dimension. |
| `SpatialGroupEnhance(groups)` | Splits channels into `groups` groups, computes a per-group spatial descriptor, normalizes it, and produces a learnable spatial attention map via `sigmoid(weight * t + bias)`. |
| `ChannelGate(gate_channels)` | Applies avg/max pooling + a shared MLP to produce a per-channel attention vector (the "scale value" used for pruning). |
| `SCA(planes, groups=16)` | Concatenates `SpatialGroupEnhance` and `ChannelGate` outputs along the channel axis and projects back to `planes` channels via a 1×1 convolution. |
SCA outputs two tensors: the attention-refined feature map and the channel scale values produced by ChannelGate. These scale values are the channel-importance scores used by prune.py.
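To make the idea of a per-channel scale value more concrete, here is a minimal pure-Python sketch of a channel gate. It is not the code in SCA.py: the function names, the nested-list data layout, and the choice to sum the two MLP outputs before a sigmoid (as in CBAM-style gates) are all assumptions made for illustration. The real module operates on PyTorch tensors.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def shared_mlp(vec, w1, w2):
    """Two-layer perceptron with ReLU. w1 is (hidden x C), w2 is (C x hidden)."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, vec))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def channel_gate(feature_map, w1, w2):
    """Hypothetical channel gate.

    feature_map: list of C channels, each a flat list of H*W activations.
    Returns one scale value in (0, 1) per channel.
    """
    avg = [sum(ch) / len(ch) for ch in feature_map]   # global average pooling
    mx = [max(ch) for ch in feature_map]              # global max pooling
    # Assumption: both pooled vectors pass through the same MLP and are summed.
    summed = [a + b for a, b in zip(shared_mlp(avg, w1, w2),
                                    shared_mlp(mx, w1, w2))]
    return [sigmoid(s) for s in summed]


# Two channels, two spatial positions each; identity weights for readability.
identity = [[1.0, 0.0], [0.0, 1.0]]
scales = channel_gate([[1.0, 1.0], [0.0, 0.0]], identity, identity)
```

In this toy input the first channel is active everywhere and the second is silent, so the first receives the larger scale value. That ordering is the kind of signal a scale-guided pruning step would rank channels by.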
A CIFAR-style ResNet-56 (ResNet(BasicBlock, depth=56)) where every BasicBlock inserts an SCA(planes, 16) module after each 3×3 convolution and before the batch normalization. This is the backbone used to learn the channel-importance scale values.
| Stage | Output channels | # Blocks |
|---|---|---|
| `conv1` | 16 | 1 |
| `layer1` | 16 | n = (depth-2)/6 = 9 |
| `layer2` (stride 2) | 32 | 9 |
| `layer3` (stride 2) | 64 | 9 |
| `avgpool` → `fc` | `num_classes` (10 / 100) | - |
__all__ = ['resnet'] exposes the resnet(depth, dataset, cfg=None) factory.
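The depth arithmetic in the table can be checked with a few lines of Python. This is only an illustrative sketch: the helper names are hypothetical and not part of the repository, and the layer count assumes the standard CIFAR BasicBlock with two 3×3 convolutions per block.

```python
def blocks_per_stage(depth):
    """Number of BasicBlocks per stage for a CIFAR-style ResNet of given depth."""
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError("depth must have the form 6n + 2, e.g. 20, 32, 44, 56")
    return (depth - 2) // 6


def stage_layout(depth, widths=(16, 32, 64)):
    """Return (stage name, output channels, number of blocks) for each stage."""
    n = blocks_per_stage(depth)
    return [("layer{}".format(i + 1), w, n) for i, w in enumerate(widths)]


def count_weighted_layers(depth):
    """conv1 + (3 stages x n blocks x 2 convs per block) + final fc layer."""
    n = blocks_per_stage(depth)
    return 1 + 3 * n * 2 + 1


layout = stage_layout(56)
total_layers = count_weighted_layers(56)
```

For `depth=56` this yields nine blocks in each of the three stages, and the weighted layers add up to 56, which is where the name ResNet-56 comes from.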
A ResNet-56 (ResNet1(BasicBlock, depth=56)) without SCA modules. It accepts a cfg list specifying the number of channels for every conv layer and is used after pruning to build the compact model. Without cfg, all conv layers keep the original [16, 32, 64] widths.
Trains the SCA-augmented ResNet on CIFAR-10 / CIFAR-100 and records the channel scale values.
| Argument | Default | Description |
|---|---|---|
| `--prefix` | required | Prefix for log files and checkpoints (e.g. `resnet56-sca-cifar100`). |
| `--batch-size` | 256 | Mini-batch size. |
| `--lr` | 0.01 | Initial learning rate. |
| `--momentum` | 0.9 | SGD momentum. |
| `--weight-decay` | 5e-4 | L2 weight decay. |
| `--epochs` | 300 | Total training epochs. |
| `--workers` | 50 | Number of data-loading workers. |
| `--resume` | `""` | Path to a checkpoint to resume from. |
| `--half` / `--cpu` / `--evaluate` | flags | Mixed-precision / CPU-only / evaluation-only mode. |
During training, the per-channel scale values output by SCA are written to scaler_for_prune.txt, which is consumed by prune.py.
Loads the pre-trained SCA model and the recorded scaler_for_prune.txt, ranks channels by their scale values, removes the lowest-importance channels, and fine-tunes the compact network.
Key steps:
- Parse the scale file and build a `scale` vector of per-channel importance scores.
- Build both the pruned backbone (`resnet`) and the SCA-augmented model (`resnet_SCA`), then copy the matching weights from the trained checkpoint.
- Sort the scale values of each layer and keep the top-`k` channels, where `k` is determined by the target pruning ratio.
- Construct a new `cfg` list and instantiate the compact `resnet` with it.
- Fine-tune the pruned model on the target dataset (CIFAR-10 / CIFAR-100).
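The ranking and `cfg` construction steps above can be sketched in plain Python as follows. This is a simplified illustration rather than the logic in prune.py: the function names, the single global `prune_ratio`, and the way `k` is derived from it are assumptions, and the actual script may choose channels per layer differently depending on the `-v` option.

```python
def select_channels(scales, prune_ratio):
    """Return the sorted indices of the channels to keep in one layer.

    scales: per-channel importance scores for the layer.
    prune_ratio: fraction of channels to remove (assumed in [0, 1)).
    """
    n_prune = int(len(scales) * prune_ratio)
    k = max(1, len(scales) - n_prune)  # always keep at least one channel
    # Rank channel indices from most to least important.
    ranked = sorted(range(len(scales)), key=lambda i: scales[i], reverse=True)
    return sorted(ranked[:k])


def build_cfg(layer_scales, prune_ratio):
    """Build a width list: the number of kept channels for every layer."""
    return [len(select_channels(s, prune_ratio)) for s in layer_scales]


# Toy example with two layers of 4 and 2 channels.
layer_scales = [[0.9, 0.1, 0.5, 0.3], [0.2, 0.8]]
kept_first_layer = select_channels(layer_scales[0], prune_ratio=0.5)
cfg = build_cfg(layer_scales, prune_ratio=0.5)
```

The kept indices are what would be used to copy the surviving filters from the trained checkpoint, while `cfg` is the width list passed to the compact model.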
| Argument | Default | Description |
|---|---|---|
| `-v` | `A` | Pruning-version selector (controls which layers/channels to keep). |
| `--dataset` | `cifar100` | Target dataset (`cifar10` or `cifar100`). |
| `--epochs` | 300 | Fine-tuning epochs. |
| `--batch-size` / `--lr` | 64 / 0.1 | Fine-tuning hyperparameters. |
| `--save` | `.` | Directory for the pruned-model checkpoint. |
| `--resume` | `""` | Resume from a checkpoint. |
| `--evaluate` | flag | Evaluate without training. |
Required input file: `scaler_for_prune.txt`, produced by `train.py`. As noted above, this file is not included in the public release; please contact the first author to obtain it.
A utility (adapted from simochen/model-tools) that counts the FLOPs and parameter count of a model. Used to report the compactness of the pruned network.
| Function | Description |
|---|---|
| `print_model_param_nums(model)` | Print the total number of trainable parameters (in millions). |
| `print_model_param_flops(model, input_res=224)` | Estimate and print the total number of FLOPs. |
The utility registers forward hooks on `Conv2d`, `Linear`, `BatchNorm2d`, `ReLU`, pooling and upsample layers to accumulate the cost of every operation.
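As a rough illustration of the arithmetic such a counter accumulates, the sketch below computes the parameter count and multiply-accumulate (MAC) count of a single convolution. The helper names are hypothetical, and the convention used here (one MAC per kernel weight per output position, no grouping) is a common one that may differ from what compute_flops.py reports, for example if multiplications and additions are counted separately.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=False):
    """Number of learnable parameters in a square-kernel convolution."""
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)


def conv2d_macs(in_ch, out_ch, kernel, out_h, out_w):
    """Multiply-accumulate operations for one forward pass of the layer."""
    return kernel * kernel * in_ch * out_ch * out_h * out_w


# First 3x3 convolution of a CIFAR-style ResNet: 3 -> 16 channels, 32x32 output.
first_conv_params = conv2d_params(3, 16, 3)
first_conv_macs = conv2d_macs(3, 16, 3, 32, 32)
```

Because both quantities scale with `in_ch * out_ch`, removing channels from a layer reduces its own cost and also the cost of the layer that consumes its output.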
Reusable backbone definitions that can be plugged into the pruning pipeline:
| File | Description |
|---|---|
| `models/resnet.py` | ResNet for CIFAR-10/100 with a `cfg` list for per-layer channel widths. |
| `models/vgg.py` | VGG-11/13/16/19 (`defaultcfg[depth]`) for CIFAR-10/100, also `cfg`-aware. |
Both expose a cfg argument so that a pruned width list can be passed in directly.
```bash
# 1. Train the SCA-augmented ResNet and record channel scale values
cd source_code
python train.py --prefix resnet56-sca-cifar100 --dataset cifar100

# 2. Place scaler_for_prune.txt (obtained from the authors) where prune.py expects it (see the note below)

# 3. Run channel pruning guided by the SCA scale values
python prune.py --dataset cifar100 -v A

# 4. Measure the FLOPs / parameter count of the pruned model
python -c "from compute_flops import *; import torch; \
from resnet import resnet; m = resnet(depth=56).cuda(); \
print_model_param_nums(m); print_model_param_flops(m, input_res=32)"
```

Note: The hard-coded path `infile='/scaler_for_prune.txt'` in `prune.py` means the scale file must be located at the filesystem root `/` when running the script as-is. Adjust the path to match your environment if needed.
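One possible way to adjust the path is to resolve the scale file relative to a directory of your choice and fail early with a clear message if it is missing. The helper below is a hypothetical sketch, not part of the repository. It assumes you edit `prune.py` yourself and assign the returned path to `infile`.

```python
import tempfile
from pathlib import Path


def resolve_scale_file(filename="scaler_for_prune.txt", base_dir=None):
    """Return the path to the scale file, looked up under base_dir.

    base_dir defaults to the current working directory. Raises
    FileNotFoundError with a descriptive message if the file is absent.
    """
    base = Path(base_dir) if base_dir is not None else Path.cwd()
    path = base / filename
    if not path.is_file():
        raise FileNotFoundError(
            "Scale file not found at {}; obtain it from the authors "
            "or generate it with train.py".format(path)
        )
    return path


# Demonstration with a temporary directory standing in for source_code/.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "scaler_for_prune.txt").write_text("placeholder\n")
    found = resolve_scale_file(base_dir=tmp)
    found_name = found.name
```

Inside `prune.py`, something like `infile = str(resolve_scale_file(base_dir="."))` would then replace the root-level path, assuming the script is run from the directory that holds the file.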
If you use these models in your research, please cite:
```bibtex
@article{LIU2021107636,
  title = {Channel pruning guided by spatial and channel attention for DNNs in intelligent edge computing},
  journal = {Applied Soft Computing},
  volume = {110},
  pages = {107636},
  year = {2021},
  issn = {1568-4946},
  doi = {https://doi.org/10.1016/j.asoc.2021.107636},
  url = {https://www.sciencedirect.com/science/article/pii/S1568494621005573},
  author = {Mengran Liu and Weiwei Fang and Xiaodong Ma and Wenyuan Xu and Naixue Xiong and Yi Ding},
}
```
Other related works by the same group:
- UAV-DDPG — Computation offloading for UAV-assisted MEC with DDPG.
- VN-MADDPG — Multi-agent DDPG for vehicular networks.
- MTACP — IMPALA-based multi-task reinforcement learning.
Mengran Liu (18800191663@163.com)
Please note that the open-source code in this repository was mainly written by the graduate student author during his master's studies. Since the author did not continue in scientific research after graduation, it is difficult to maintain and update this code. We sincerely apologize that the code is provided for reference only.