LowFormer: Hardware Efficient Design for Convolutional Transformer Backbones (paper)
This is the official repository for "LowFormer: Hardware Efficient Design for Convolutional Transformer Backbones", which was accepted at WACV2025.
Authors: Moritz Nottebaum, Matteo Dunnhofer and Christian Micheloni
This repository contains code to train and test our LowFormer model, as well as to benchmark its speed. We also provide base implementations of several backbones published in recent years, along with the means to benchmark their execution time.
GPU throughput and top-1 accuracy comparison (left), and the effect of input resolution on GPU latency (right).
- 28.02.2025:
  - added lowformer_model.py as a standalone file
  - added fast_eval.py for simplified ImageNet evaluation
  - refined structure of README.md
| Model | GPU Throughput (img/s) | GPU Latency (ms) | Params (M) | MACs (M) | Top-1 Acc (%) |
|---|---|---|---|---|---|
| LowFormer-B0 | 5988 | 0.3 | 14.1 | 944 | 78.4 |
| LowFormer-B1 | 4237 | 0.43 | 17.9 | 1410 | 79.9 |
| LowFormer-B1.5 | 2739 | 0.66 | 33.9 | 2573 | 81.2 |
| LowFormer-B3 | 1162 | 1.55 | 57.1 | 6098 | 83.6 |
All checkpoints are downloadable and already arranged in the required folder structure. Simply put the downloaded folder structure into the main directory. Please refer to our paper for more information.
The lowformer_model.py file is standalone and no longer depends on the rest of the repository.
The script below shows how to use lowformer_model.py to instantiate any LowFormer model:
import torch
from lowformer_model import get_lowformer_b0, get_lowformer_b1, get_lowformer_b15, get_lowformer_b3
# model = get_lowformer_b0(pretrained=True)
# model = get_lowformer_b1(pretrained=True)
# model = get_lowformer_b15(pretrained=True)
model = get_lowformer_b3(pretrained=True)
inp = torch.randn(5,3,224,224)
out = model(inp) # -> [5,1000]
You'll have to download the checkpoints here if you want to use the pretrained version.
To run the code, follow these steps.
Set up a conda environment and activate it:
conda create --name lowformer python=3.11
conda activate lowformer
Install requirements from requirements.txt:
pip install -r requirements.txt
You have to download ImageNet-1K and set the variable data_dir in configs/cls/imagenet/default.yaml for training and testing on ImageNet-1K.
If you want to evaluate a model or benchmark its latency or throughput, you have to set --path either in the argument parser at the beginning of the main() method in eval_cls_model.py or on the command line when executing eval_cls_model.py.
You can download the checkpoints and simply put the folder structure into the main folder (delete the existing exemplar .exp folder). The download link is the same as above.
Checkpoints for LowFormer-B0, -B1, -B1.5 and -B3 are available.
You can use the imagenet_eval() function in fast_eval.py to evaluate models on the ImageNet validation set.
from fast_eval import imagenet_eval
your_model = get_some_model_function()
imagenet_eval(your_model) # prints result
(You do, however, have to set the imagenet_path variable in fast_eval.py.)
With lowformer_imagenet_eval(modelname), you can easily evaluate any of the LowFormer models,
where modelname is one of {"b0", "b1", "b15", "b3"}.
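For example (a minimal sketch, assuming imagenet_path is set and the pretrained checkpoints have been downloaded):

from fast_eval import lowformer_imagenet_eval

# Evaluates the pretrained LowFormer-B1.5 on the ImageNet validation set
# and prints the result, analogous to imagenet_eval().
lowformer_imagenet_eval("b15")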
To run on one GPU, specify the GPU-id with CUDA_VISIBLE_DEVICES and execute the following command:
CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 \
--nproc_per_node=1 --rdzv-endpoint localhost:29411 \
train_cls_model.py configs/cls/imagenet/b1_alternative.yaml \
--data_provider.image_size "[128,160,192,224]" \
--run_config.eval_image_size "[224]" \
--path .exp/cls/imagenet/b1_alternative/
To run on 8 GPUs, just run the following command:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes 1 \
--nproc_per_node=8 --rdzv-endpoint localhost:29411 \
train_cls_model.py configs/cls/imagenet/b1_alternative.yaml \
--data_provider.image_size "[128,160,192,224]" \
--run_config.eval_image_size "[224]" \
--path .exp/cls/imagenet/b1_alternative/
Caveat: The batch size given in the config file is multiplied by the number of GPU instances, and so is the learning rate in the config file!
To simulate a bigger batch size, the configs contain a parameter called bsizemult, which is normally set to 1. The learning rate is also multiplied by it, since bsizemult increases the effective batch size.
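A purely illustrative calculation (the numbers are hypothetical and not taken from the shipped configs):

# Hypothetical values, only to illustrate the scaling described above.
config_batch_size = 128    # batch size as written in the yaml config
config_lr         = 1e-3   # learning rate as written in the yaml config
num_gpus          = 8      # GPU instances launched by torchrun
bsizemult         = 2      # multiplier to simulate a bigger batch size

effective_batch_size = config_batch_size * num_gpus * bsizemult  # 2048
effective_lr         = config_lr * num_gpus * bsizemult          # 0.016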
For testing and speed analysis, eval_cls_model.py can be used.
We also feature a large library of popular backbone architectures. We adapted their code so that they can be converted to TorchScript and ONNX for speed measurement. For a list of all featured architectures, see featured_models.txt, which contains one example model for each architecture (e.g., architecture: fastvit, model: fastvit_t8; architecture: mobileone, model: mobileones3).
To evaluate a model given in configs/cls/imagenet, just run the following command:
python eval_cls_model.py b1 --image_size 224 --batch_size 100 --gpu 6
The following command runs LowFormer-B1 (from configs/cls/imagenet) for 400 iterations with a batch size of 200, uses TorchScript optimization (--optit), and has an input resolution of 224x224 (throughput measurement):
python eval_cls_model.py b1 --image_size 224 --batch_size 200 --testrun --iterations 400 --gpu 6 --optit
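For reference, a rough stand-alone sketch of how such a throughput number can be obtained (this is not the repository's benchmarking code in eval_cls_model.py, just an illustration of the idea):

import time
import torch
from lowformer_model import get_lowformer_b1

# Time a fixed number of forward passes on the GPU and report images/second.
model = get_lowformer_b1(pretrained=True).eval().cuda()
batch = torch.randn(200, 3, 224, 224, device="cuda")

with torch.no_grad():
    for _ in range(10):            # warm-up iterations
        model(batch)
    torch.cuda.synchronize()
    start = time.time()
    iterations = 400
    for _ in range(iterations):
        model(batch)
    torch.cuda.synchronize()
    elapsed = time.time() - start

print(f"throughput: {iterations * batch.shape[0] / elapsed:.0f} img/s")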
You can benchmark latency with a TorchScript-converted version of the model and utilize torch inference optimization (see here for more information):
python eval_cls_model.py b1 --image_size 224 --batch_size 1 --testrun --latency --optit --iterations 4000 --gpu 6 --jobs 1
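Conceptually, the TorchScript path corresponds to something like the following sketch (eval_cls_model.py performs its own conversion internally; this is only an illustration):

import torch
from lowformer_model import get_lowformer_b1

# Trace the model to TorchScript and apply torch inference optimization.
model = get_lowformer_b1(pretrained=True).eval().cuda()
example = torch.randn(1, 3, 224, 224, device="cuda")

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    optimized = torch.jit.optimize_for_inference(traced)
    out = optimized(example)  # -> [1, 1000]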
You can also convert LowFormer-B1 to ONNX and benchmark its latency (the ONNX conversion is already implemented in eval_cls_model.py):
python eval_cls_model.py b1 --image_size 224 --batch_size 1 --testrun --latency --onnxrun --iterations 4000 --gpu 6 --optit --jobs 1
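A minimal stand-alone sketch of the ONNX route (again, eval_cls_model.py ships its own conversion; this assumes the onnxruntime package is installed):

import torch
import onnxruntime as ort
from lowformer_model import get_lowformer_b1

# Export the model to ONNX and run one inference with onnxruntime.
model = get_lowformer_b1(pretrained=True).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "lowformer_b1.onnx",
                  input_names=["input"], output_names=["logits"])

session = ort.InferenceSession(
    "lowformer_b1.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
logits = session.run(None, {"input": dummy.numpy()})[0]  # shape (1, 1000)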
Because of the TorchScript conversion, the checkpoint cannot be loaded completely, as the ClsHeadTorchScript class is used instead of ClsHead in lowformer/models/lowformer/cls.py. This can, however, be fixed by adapting the checkpoint if needed.
It is also possible to measure latency with n parallel processes executing the model by setting the --jobs argument:
python eval_cls_model.py b1 --image_size 224 --batch_size 1 --testrun --latency --optit --iterations 4000 --gpu 6 --jobs 4
By appending the argument --other followed by a model name, you can run many other backbones. Most of these backbones do not load their weights, so this functionality is purely for speed measurement (but it could be extended for evaluation). The following command benchmarks MobileOne-S1 [1]:
python eval_cls_model.py b1 --image_size 224 --batch_size 1 --testrun --latency --onnxrun --iterations 4000 --gpu 6 --optit --jobs 1 --other mobileones1
Please see the acknowledgements below for a link to the repository of the MobileOne publication.
To train a custom architecture, simply adapt the lowformer_cls_b1() method in lowformer/models/lowformer/cls.py: replace the call to lowformer_backbone_b1(**kwargs) (which returns a PyTorch model) with your own model (a sketch follows below). Then copy the b1.yaml config file and name it however you want.
Then run the training command specified in the "Training" section, adapting the config file path.
You need to change the variable name: b1 in the config file to your model name, adapt model_dict in the create_cls_model method in lowformer/cls_model_zoo.py accordingly, and then add your own methods in lowformer/models/lowformer/cls.py and lowformer/models/lowformer/backbone.py for your model.
If you want to add layers of your own, add them in lowformer/models/nn/ops.py.
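The following is a hypothetical, self-contained sketch of the backbone replacement described above; MyCustomBackbone is a placeholder, and the real lowformer_cls_b1() in lowformer/models/lowformer/cls.py constructs its classification head differently:

import torch
import torch.nn as nn

class MyCustomBackbone(nn.Module):
    # Placeholder for your own feature extractor.
    def __init__(self, width=64):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.out_channels = width

    def forward(self, x):
        return self.stem(x).flatten(1)

def my_cls_model(num_classes=1000, **kwargs):
    # In the repository: replace lowformer_backbone_b1(**kwargs) with your own model.
    backbone = MyCustomBackbone(**kwargs)
    head = nn.Linear(backbone.out_channels, num_classes)
    return nn.Sequential(backbone, head)

print(my_cls_model()(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 1000])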
We thank the contributors of the codebase and the paper "EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction" [2], which was used as a base for this repository.
Caveat: There are two EfficientViT papers; the other one is called "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention".
We also thank the contributors of the openly available code of the many backbone architectures we feature in this repository. Here is a list of links to all their repositories:
FastViT, Efficientmodulation, MobileViG, iFormer, MobileOne, FFNet, GhostNetV2, EfficientViT, EdgeViT, PVTv2, FAT, EfficientFormer, SHViT, RepViT
We hope you find our work useful. If you would like to acknowledge it in your project, please use the following citation:
@article{Nottebaum2024LowFormerHE,
title={LowFormer: Hardware Efficient Design for Convolutional Transformer Backbones},
author={Moritz Nottebaum and Matteo Dunnhofer and Christian Micheloni},
journal={2025 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
year={2024},
pages={7008-7018},
url={https://api.semanticscholar.org/CorpusID:272423686}
}
Papers mentioned in this README:
[1] Vasu, Pavan Kumar Anasosalu, et al. "Mobileone: An improved one millisecond mobile backbone." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2023.
[2] Cai, Han, et al. "Efficientvit: Lightweight multi-scale attention for high-resolution dense prediction." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.