
CPSCA

This is the source code for our paper: Channel Pruning Guided by Spatial and Channel Attention for DNNs in Intelligent Edge Computing. A brief introduction of this work is as follows:

Deep Neural Networks (DNNs) have recently achieved remarkable success in many computer vision tasks, but their huge number of parameters and high computational overhead hinder their deployment on resource-constrained edge devices. Channel pruning is an effective approach for compressing DNN models. A critical challenge is to determine which channels to remove so that model accuracy is not negatively affected. In this paper, we first propose Spatial and Channel Attention (SCA), a new attention module combining spatial and channel attention, which focus respectively on "where" and "what" the most informative parts are. Guided by the scale values that SCA generates to measure channel importance, we further propose a new channel pruning approach called Channel Pruning guided by Spatial and Channel Attention (CPSCA). Experimental results indicate that SCA achieves the best inference accuracy compared to other state-of-the-art attention modules, while incurring negligible extra resource consumption. Our evaluation on two benchmark datasets shows that, with the guidance of SCA, our CPSCA approach achieves higher inference accuracy than other state-of-the-art pruning methods under the same pruning ratios.


This paper has been published in Applied Soft Computing (ASOC). The preprint version can be downloaded from here, and the formal version can be downloaded from here.

We only provide SCA and CPSCA here; implementations of the other attention models mentioned in our paper can be found in PytorchInsight. For certain reasons, the scaler_for_prune.txt file required by prune.py is not included in the released code. If you want to know how to generate it, please contact the first author at 18800191663@163.com.

Required software

  • Python 3.5+
  • PyTorch (1.x)
  • torchvision
  • NumPy

Project Structure

CPSCA/
├── source_code/
│   ├── SCA.py              # SCA attention module (SpatialGroupEnhance + ChannelGate)
│   ├── resnet.py           # Pruned ResNet (used in prune.py after channel selection)
│   ├── resnet_SCA.py       # ResNet-56 with SCA modules embedded in each BasicBlock
│   ├── train.py            # Train the SCA-augmented ResNet and record scale values
│   ├── prune.py            # Channel pruning guided by SCA scale values
│   ├── compute_flops.py    # FLOPs / parameter counter for the pruned model
│   └── models/             # Backbone definitions (VGG / ResNet variants for pruning)
│       ├── __init__.py
│       ├── resnet.py
│       └── vgg.py
└── README.md

Core Modules

SCA (SCA.py)

The Spatial and Channel Attention (SCA) module combines a spatial group-enhance branch with a channel gate. It is built from the following classes, with SpatialGroupEnhance and ChannelGate as the two main components:

| Class | Description |
|---|---|
| ChannelPool | Concatenates the avg_pool and max_pool outputs along the channel dimension. |
| SpatialGroupEnhance(groups) | Splits the channels into the given number of groups, computes a per-group spatial descriptor, normalizes it, and produces a learnable spatial attention map via sigmoid(weight * t + bias). |
| ChannelGate(gate_channels) | Applies avg/max pooling and a shared MLP to produce a per-channel attention vector (the "scale value" used for pruning). |
| SCA(planes, groups=16) | Concatenates the SpatialGroupEnhance and ChannelGate outputs along the channel axis and projects them back to planes channels via a 1×1 convolution. |
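The SpatialGroupEnhance computation can be illustrated with a small pure-Python sketch. It follows the general spatial group-wise enhance recipe (average-pooled group descriptor, dot product with every spatial position, spatial normalization, then sigmoid(weight * t + bias)). The function names, the list-of-lists tensor layout, the epsilon and the single shared (weight, bias) pair are assumptions for illustration; this is not the code in SCA.py.

```python
import math

def sge_group(x, weight=1.0, bias=0.0, eps=1e-5):
    """Spatial group-enhance for ONE group of channels (illustrative sketch).

    x is a list of channels, each an H x W list of lists of floats.
    Returns a feature map of the same shape.
    """
    c, h, w = len(x), len(x[0]), len(x[0][0])
    # 1. Global average pooling of each channel gives the group descriptor g.
    g = [sum(sum(row) for row in ch) / (h * w) for ch in x]
    # 2. t[i][j]: dot product between g and the feature vector at (i, j).
    t = [[sum(g[k] * x[k][i][j] for k in range(c)) for j in range(w)]
         for i in range(h)]
    # 3. Normalize t over all spatial positions (zero mean, unit std).
    flat = [v for row in t for v in row]
    mean = sum(flat) / len(flat)
    std = math.sqrt(sum((v - mean) ** 2 for v in flat) / len(flat))
    t = [[(v - mean) / (std + eps) for v in row] for row in t]
    # 4. Affine transform + sigmoid gives the spatial attention map.
    a = [[1.0 / (1.0 + math.exp(-(weight * v + bias))) for v in row] for row in t]
    # 5. Re-weight every channel of the group with the same attention map.
    return [[[x[k][i][j] * a[i][j] for j in range(w)] for i in range(h)]
            for k in range(c)]

def sge(x, groups, weight=1.0, bias=0.0):
    """Split the channels into `groups` equal groups and enhance each one.

    Simplification (assumption): one shared (weight, bias) pair for all groups.
    """
    if len(x) % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    size = len(x) // groups
    out = []
    for start in range(0, len(x), size):
        out.extend(sge_group(x[start:start + size], weight, bias))
    return out
```

In the real module weight and bias are learned during training; here they are plain arguments so the arithmetic is easy to follow. Since the attention values lie in (0, 1), the enhanced map never exceeds the input in magnitude.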

SCA outputs two tensors: the attention-refined feature map and the channel scale values produced by ChannelGate. These scale values are the channel-importance scores used by prune.py.
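How a channel gate can turn a feature map into per-channel scale values is sketched below, assuming a CBAM-style gate: the average-pooled and max-pooled channel descriptors pass through one shared two-layer MLP (ReLU in between), the two outputs are summed, and a sigmoid maps each channel to a score in (0, 1). The summation, the reduction ratio, the toy weights and the function names are assumptions; ChannelGate in SCA.py may differ in detail.

```python
import math

def shared_mlp(v, w1, w2):
    """Two-layer MLP C -> C/r -> C with ReLU; w1 is (C/r x C), w2 is (C x C/r)."""
    hidden = [max(0.0, sum(a * b for a, b in zip(row, v))) for row in w1]
    return [sum(a * b for a, b in zip(row, hidden)) for row in w2]

def channel_scale_values(x, w1, w2):
    """Per-channel attention in (0, 1) for x given as C channels of H x W maps."""
    avg = [sum(sum(r) for r in ch) / (len(ch) * len(ch[0])) for ch in x]
    mx = [max(max(r) for r in ch) for ch in x]
    # The same MLP weights process both pooled descriptors; results are summed.
    logits = [p + q for p, q in zip(shared_mlp(avg, w1, w2), shared_mlp(mx, w1, w2))]
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

# Toy example: C = 4 channels, reduction ratio r = 2 (hypothetical weights).
x = [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]],
     [[2.0, 0.0], [0.0, 2.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
w1 = [[0.5, 0.1, 0.2, 0.0], [0.0, 0.3, 0.1, 0.4]]
w2 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.2]]
scales = channel_scale_values(x, w1, w2)
print([round(s, 3) for s in scales])
```

Channels that consistently receive low scores are the ones the pruning step treats as least important.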

ResNet-SCA (resnet_SCA.py)

A CIFAR-style ResNet-56 (ResNet(BasicBlock, depth=56)) where every BasicBlock inserts an SCA(planes, 16) module after each 3×3 convolution and before the batch normalization. This is the backbone used to learn the channel-importance scale values.

| Stage | Output channels | # Blocks |
|---|---|---|
| conv1 | 16 | 1 |
| layer1 | 16 | n = (depth - 2) / 6 = 9 |
| layer2 (stride 2) | 32 | 9 |
| layer3 (stride 2) | 64 | 9 |
| avgpool → fc | num_classes (10 / 100) | - |
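The block counts in the table follow from depth = 6n + 2. The sketch below reproduces the table for any valid depth, assuming 32×32 CIFAR inputs and that only the first block of layer2 and layer3 downsamples by 2 (the usual CIFAR ResNet layout). The function name is made up for illustration.

```python
def cifar_resnet_layout(depth, input_res=32):
    """Rows of (stage, out_channels, num_blocks, output_map_size)."""
    if (depth - 2) % 6 != 0:
        raise ValueError("CIFAR ResNet depth must be of the form 6n + 2")
    n = (depth - 2) // 6
    rows = [("conv1", 16, 1, input_res)]
    res = input_res
    # Assumption: the first block of layer2 and layer3 downsamples by 2.
    for stage, width, stride in [("layer1", 16, 1), ("layer2", 32, 2), ("layer3", 64, 2)]:
        res //= stride
        rows.append((stage, width, n, res))
    return rows

for row in cifar_resnet_layout(56):
    print(row)
# ('conv1', 16, 1, 32)
# ('layer1', 16, 9, 32)
# ('layer2', 32, 9, 16)
# ('layer3', 64, 9, 8)

# conv1 + two 3x3 convs in each of the 3 * 9 blocks + the final fc layer
print(1 + 2 * 3 * 9 + 1)  # 56
```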

The module's __all__ = ['resnet'] exports the resnet(depth, dataset, cfg=None) factory.

Pruned ResNet (resnet.py)

A ResNet-56 (ResNet1(BasicBlock, depth=56)) without SCA modules. It accepts a cfg list specifying the number of channels for every conv layer and is used after pruning to build the compact model. Without cfg, all conv layers keep the original [16, 32, 64] widths.
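Because cfg holds one width per conv layer, an unpruned ResNet-56 cfg has 1 + 6n = 55 entries; a pruned model is obtained by passing a list with smaller numbers in the same positions. The sketch below builds the default list. The ordering (conv1 first, then the two 3×3 convs of each BasicBlock, stage by stage) and the helper name are assumptions, so check resnet.py for the exact layout it expects.

```python
def default_resnet_cfg(depth=56):
    """Unpruned widths, one entry per conv layer.

    Assumed order: conv1, then the two 3x3 convs of every BasicBlock,
    stage by stage (16, 32, 64 channels).
    """
    n = (depth - 2) // 6
    cfg = [16]
    for width in (16, 32, 64):
        cfg += [width] * (2 * n)
    return cfg

cfg = default_resnet_cfg(56)
print(len(cfg))           # 55
print(cfg[:3], cfg[-3:])  # [16, 16, 16] [64, 64, 64]
```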

Training (train.py)

Trains the SCA-augmented ResNet on CIFAR-10 / CIFAR-100 and records the channel scale values.

| Argument | Default | Description |
|---|---|---|
| --prefix | (required) | Prefix for log files and checkpoints (e.g. resnet56-sca-cifar100). |
| --batch-size | 256 | Mini-batch size. |
| --lr | 0.01 | Initial learning rate. |
| --momentum | 0.9 | SGD momentum. |
| --weight-decay | 5e-4 | L2 weight decay. |
| --epochs | 300 | Total number of training epochs. |
| --workers | 50 | Number of data-loading workers. |
| --resume | "" | Path to a checkpoint to resume from. |
| --half / --cpu / --evaluate | flags | Mixed-precision / CPU-only / evaluation-only mode. |

During training, the per-channel scale values output by SCA are written to scaler_for_prune.txt, which is consumed by prune.py.
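The layout of scaler_for_prune.txt is not documented in this repository. Purely as an illustration, the sketch below assumes one line per SCA layer holding whitespace-separated scale values, and shows how such a file could be written after training and read back into per-layer lists. The format and both function names are hypothetical.

```python
import os
import tempfile

def write_scale_file(path, per_layer_scales):
    """Hypothetical format: one line per layer, values separated by spaces."""
    with open(path, "w") as f:
        for scales in per_layer_scales:
            f.write(" ".join("{:.6f}".format(s) for s in scales) + "\n")

def read_scale_file(path):
    """Inverse of write_scale_file: returns a list of per-layer float lists."""
    with open(path) as f:
        return [[float(tok) for tok in line.split()] for line in f if line.strip()]

scales = [[0.91, 0.12, 0.55], [0.33, 0.78]]
path = os.path.join(tempfile.mkdtemp(), "scaler_for_prune.txt")
write_scale_file(path, scales)
print(read_scale_file(path))  # [[0.91, 0.12, 0.55], [0.33, 0.78]]
```

Six decimal places are plenty for ranking channels; the exact precision used by train.py is unknown.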

Channel Pruning (prune.py)

Loads the pre-trained SCA model and the recorded scaler_for_prune.txt, ranks channels by their scale values, removes the lowest-importance channels, and fine-tunes the compact network.

Key steps:

  1. Parse the scale file and build a scale vector of per-channel importance scores.
  2. Build both the pruned backbone (resnet) and the SCA-augmented model (resnet_SCA), then copy the matching weights from the trained checkpoint.
  3. Sort the scale values of each layer and keep the top-k channels, where k is determined by the target pruning ratio.
  4. Construct a new cfg list and instantiate the compact resnet with it.
  5. Fine-tune the pruned model on the target dataset (CIFAR-10 / CIFAR-100).
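Steps 3 and 4 can be sketched in plain Python as follows. The sketch assumes the scale values are already grouped per layer and that a layer with C channels loses floor(pruning_ratio × C) of them, always keeping at least one. The rounding rule, the one-channel minimum and the function name are assumptions rather than details taken from prune.py.

```python
def select_channels(per_layer_scales, pruning_ratio):
    """Keep the highest-scoring channels of every layer.

    Returns (cfg, kept): cfg[i] is the new width of layer i and kept[i] lists
    the (ascending) indices of the channels that survive in layer i.
    """
    cfg, kept = [], []
    for scales in per_layer_scales:
        n_prune = int(pruning_ratio * len(scales))
        k = max(1, len(scales) - n_prune)  # never remove a layer entirely
        # Channel indices sorted by descending scale value (importance).
        order = sorted(range(len(scales)), key=lambda i: scales[i], reverse=True)
        cfg.append(k)
        kept.append(sorted(order[:k]))
    return cfg, kept

cfg, kept = select_channels([[0.9, 0.1, 0.5, 0.7], [0.2, 0.8]], pruning_ratio=0.5)
print(cfg)   # [2, 1]
print(kept)  # [[0, 3], [1]]
```

The returned cfg is what would be passed to the compact network, and the kept indices tell which slices of the trained weights to copy into it.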

| Argument | Default | Description |
|---|---|---|
| -v | A | Pruning-version selector (controls which layers/channels to keep). |
| --dataset | cifar100 | Target dataset (cifar10 or cifar100). |
| --epochs | 300 | Fine-tuning epochs. |
| --batch-size / --lr | 64 / 0.1 | Fine-tuning hyperparameters. |
| --save | . | Directory for the pruned-model checkpoint. |
| --resume | "" | Resume from a checkpoint. |
| --evaluate | flag | Evaluate without training. |

Required input file: scaler_for_prune.txt, produced by train.py. As noted above, this file is not included in the public release; please contact the authors to obtain it.

FLOPs / Parameter Counter (compute_flops.py)

A utility, adapted from simochen/model-tools, that counts the FLOPs and parameters of a model. It is used to report the compactness of the pruned network.

| Function | Description |
|---|---|
| print_model_param_nums(model) | Prints the total number of trainable parameters (in millions). |
| print_model_param_flops(model, input_res=224) | Estimates and prints the total number of FLOPs. |

It registers forward hooks on Conv2d, Linear, BatchNorm2d, ReLU, pooling and upsampling layers to accumulate the cost of every operation.
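For a single Conv2d layer the counts reduce to simple arithmetic. The sketch below uses one common convention: weights = C_out × (C_in / groups) × k × k, plus C_out bias terms when a bias is present, and one multiply-accumulate (MAC) per weight per output position. compute_flops.py may follow a different convention (for example counting multiplies and adds separately), so treat its output as the reference. The function name is hypothetical.

```python
def conv2d_cost(c_in, c_out, kernel, out_h, out_w, groups=1, bias=False):
    """Return (params, macs) for one square-kernel Conv2d layer."""
    weights = c_out * (c_in // groups) * kernel * kernel
    params = weights + (c_out if bias else 0)
    macs = weights * out_h * out_w  # one MAC per weight per output position
    return params, macs

# First 3x3 conv of a CIFAR ResNet: 3 -> 16 channels on a 32x32 map.
print(conv2d_cost(3, 16, 3, 32, 32))  # (432, 442368)
# A layer3 conv (64 -> 64 on an 8x8 map) before and after keeping 32 output channels.
print(conv2d_cost(64, 64, 3, 8, 8))   # (36864, 2359296)
print(conv2d_cost(64, 32, 3, 8, 8))   # (18432, 1179648)
```

Halving a layer's output channels halves its own cost and also shrinks the input side of the following layer, which is why channel pruning reduces FLOPs more than the per-layer ratio alone suggests.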

Backbones (models/)

Reusable backbone definitions that can be plugged into the pruning pipeline:

| File | Description |
|---|---|
| models/resnet.py | ResNet for CIFAR-10/100 with a cfg list for per-layer channel widths. |
| models/vgg.py | VGG-11/13/16/19 (defaultcfg[depth]) for CIFAR-10/100, also cfg-aware. |

Both expose a cfg argument so that a pruned width list can be passed in directly.
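A cfg-aware VGG definition commonly lists conv widths with 'M' marking max-pooling positions. Assuming the widely used CIFAR-style VGG-16 list below (the actual defaultcfg in models/vgg.py may differ), a pruned width list can be derived by replacing each conv width with the number of channels kept while leaving the 'M' entries untouched. The helper name and the kept counts are hypothetical.

```python
# Assumed VGG-16 layout: ints are conv widths, 'M' marks max-pooling.
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512]

def apply_kept_widths(cfg, kept):
    """Replace conv widths in cfg with kept[i] (one entry per conv layer)."""
    convs = [c for c in cfg if c != 'M']
    if len(kept) != len(convs):
        raise ValueError("need one kept width per conv layer")
    widths = iter(kept)
    pruned = []
    for c in cfg:
        if c == 'M':
            pruned.append('M')  # pooling layers have no channels to prune
            continue
        k = next(widths)
        if not 1 <= k <= c:
            raise ValueError("kept width must be between 1 and the original width")
        pruned.append(k)
    return pruned

kept = [c // 2 for c in VGG16_CFG if c != 'M']
print(apply_kept_widths(VGG16_CFG, kept))
# [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256]
```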

Usage

```bash
# 1. Train the SCA-augmented ResNet and record channel scale values
cd source_code
python train.py --prefix resnet56-sca-cifar100 --dataset cifar100

# 2. Place scaler_for_prune.txt (obtained from the authors) in source_code/

# 3. Run channel pruning guided by the SCA scale values
python prune.py --dataset cifar100 -v A

# 4. Measure the FLOPs / parameter count of the pruned model
python -c "from compute_flops import *; import torch; \
           from resnet import resnet; m = resnet(depth=56).cuda(); \
           print_model_param_nums(m); print_model_param_flops(m, input_res=32)"
```

Note: prune.py hard-codes the path infile='/scaler_for_prune.txt', so when the script is run as-is the scale file must be located at the filesystem root (/). Adjust the path to match your environment if needed.
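One way to avoid depending on the filesystem root is to expose the path as a command-line option that defaults to a file in the current working directory. This is a hypothetical sketch: prune.py does not provide a --scale-file flag, and only the argparse pattern is shown. The parsed value would replace the hard-coded infile.

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="SCA-guided pruning (sketch)")
    # Hypothetical option; prune.py itself hard-codes infile='/scaler_for_prune.txt'.
    parser.add_argument(
        "--scale-file",
        default="scaler_for_prune.txt",  # relative to the current working directory
        help="path to the per-channel scale values recorded during training",
    )
    return parser

args = build_parser().parse_args(["--scale-file", "./scaler_for_prune.txt"])
infile = args.scale_file  # use this instead of the hard-coded '/scaler_for_prune.txt'
print(infile)  # ./scaler_for_prune.txt
```

argparse converts the dash in --scale-file to an underscore, so the value is read back as args.scale_file.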

Citation

If you use these models in your research, please cite:

@article{LIU2021107636,
  title  = {Channel pruning guided by spatial and channel attention for DNNs in intelligent edge computing},
  journal = {Applied Soft Computing},
  volume = {110},
  pages  = {107636},
  year   = {2021},
  issn   = {1568-4946},
  doi    = {https://doi.org/10.1016/j.asoc.2021.107636},
  url    = {https://www.sciencedirect.com/science/article/pii/S1568494621005573},
  author = {Mengran Liu and Weiwei Fang and Xiaodong Ma and Wenyuan Xu and Naixue Xiong and Yi Ding},
}

For more

Other related works by the same group:

  • UAV-DDPG — Computation offloading for UAV-assisted MEC with DDPG.
  • VN-MADDPG — Multi-agent DDPG for vehicular networks.
  • MTACP — IMPALA-based multi-task reinforcement learning.

Contact

Mengran Liu (18800191663@163.com)

Please note that the open-source code in this repository was mainly written by the graduate-student author during his master's studies. Since the author did not continue in scientific research after graduation, it is difficult to keep maintaining and updating this code. We sincerely apologize; the code is provided for reference only.
