`pip install mriaug` to use a 3D image library that is ~50x faster and simpler than torchio by
- only using PyTorch → full GPU (+autograd) support
- being tiny: ~200 lines of code → little room for bugs

while offering ~20 different augmentations (incl. MRI-specific operations).
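Because everything is plain PyTorch, an augmentation runs on whatever device its input lives on and gradients can flow through it. A minimal sketch of that idea, using a hypothetical `add_noise3d` function (an illustration only, not part of mriaug's API):

```python
import torch


def add_noise3d(x: torch.Tensor, std: float = 0.1) -> torch.Tensor:
    """Hypothetical additive Gaussian noise augmentation (not mriaug's API).

    It only uses PyTorch ops, so it runs on x's device (CPU or GPU)
    and is differentiable with respect to x.
    """
    return x + std * torch.randn_like(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
# (batch, channel, three spatial dims), tracked by autograd
x = torch.rand(1, 1, 8, 8, 8, device=device, requires_grad=True)
y = add_noise3d(x)
y.sum().backward()  # gradients flow back through the augmentation

print(y.device == x.device)                        # True
print(torch.allclose(x.grad, torch.ones_like(x)))  # True: d(sum(x + noise))/dx = 1
```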
Normal users should use mriaug via niftiai, a deep learning framework for 3D images, since it
- provides `aug_transforms3d`, a convenient function that combines all mriaug augmentations
- simplifies all the code needed for data loading, training, visualization, etc.

Experienced users can build their own framework upon mriaug (use `niftiai/augment.py` as a cheat sheet).
Let's create a 3D image tensor (with additional batch and channel dimensions) and apply `flip3d`:
```python
import torch
from mriaug import flip3d

shape = (1, 1, 4, 4, 4)
x = torch.linspace(0, 1, 4**3).view(*shape)
x_flipped = flip3d(x)
print(x[..., 0, 0])          # tensor([[[0.0000, 0.2540, 0.5079, 0.7619]]])
print(x_flipped[..., 0, 0])  # tensor([[[0.7619, 0.5079, 0.2540, 0.0000]]])
```

Explore the gallery to understand the usage and effect of all ~20 augmentations!
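The printed output of the `flip3d` example suggests that, with default arguments, it reverses the first spatial axis (tensor dimension 2). Assuming that, the same values can be reproduced with plain `torch.flip`. This is a sketch for intuition, not mriaug's implementation:

```python
import torch

# Same input as the flip3d example: batch=1, channel=1, a 4x4x4 volume
shape = (1, 1, 4, 4, 4)
x = torch.linspace(0, 1, 4**3).view(*shape)

# Assumption: flip3d's default reverses dim 2, which matches its printed output
x_flipped = torch.flip(x, dims=[2])

print(x_flipped[..., 0, 0])  # tensor([[[0.7619, 0.5079, 0.2540, 0.0000]]])
```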
The popular libraries torchio and MONAI (which utilizes torchio) often use ITK (CPU only) to augment a PyTorch tensor like this:

PyTorch tensor → NumPy array → NiBabel image → ITK operation (C/C++) → NumPy array → PyTorch tensor

That's complicated and does not use the GPU, which is needed for neural network training anyway.
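To make that round trip concrete, here is a simplified sketch (not torchio's actual code): `np.flip` stands in for the ITK operation and the NiBabel/ITK image objects are omitted. Both paths give the same values, but the NumPy path copies the data to the CPU and back and drops the autograd graph, while the pure-PyTorch path stays on the input's device:

```python
import numpy as np
import torch


def flip_via_numpy(x: torch.Tensor) -> torch.Tensor:
    """Round trip: tensor -> NumPy -> CPU-only op -> NumPy -> tensor."""
    device = x.device
    arr = x.detach().cpu().numpy()          # leaves autograd, copies off the GPU (if any)
    arr = np.flip(arr, axis=2).copy()       # stand-in for the ITK step; copy fixes negative strides
    return torch.from_numpy(arr).to(device) # copies back to the original device


def flip_via_torch(x: torch.Tensor) -> torch.Tensor:
    """Same result without leaving PyTorch (same device, autograd intact)."""
    return torch.flip(x, dims=[2])


x = torch.rand(1, 1, 4, 4, 4)
print(torch.equal(flip_via_numpy(x), flip_via_torch(x)))  # True
```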
Instead, mriaug directly uses PyTorch (CPU and GPU support), resulting in
- ~50x fewer lines of code (torchio: ~10,000 LOC, mriaug: ~200 LOC)
- ~50x speedup on GPU, based on the table below (run `speed.py` to reproduce; per-transform speedups range from 7.5x to 1813x)
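The runtimes come from the repository's `speed.py`. Without claiming to mirror that script, here is a minimal sketch of how a PyTorch operation can be timed fairly on GPU: run it once to warm up and call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels are launched asynchronously. `time_op` is a hypothetical helper, and the 64³ volume only keeps the example light:

```python
import time

import torch


def time_op(fn, x: torch.Tensor, n_runs: int = 10) -> float:
    """Hypothetical helper: mean runtime of fn(x) in seconds."""
    fn(x)  # warm-up run so one-time setup costs are excluded
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for pending kernels before starting the clock
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # otherwise we would only measure kernel launches
    return (time.perf_counter() - start) / n_runs


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 1, 64, 64, 64, device=device)
print(f"flip on {device}: {time_op(lambda t: torch.flip(t, dims=[2]), x):.6f} s")
```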
Runtimes in seconds on a 256³ image (AMD 5950X CPU and NVIDIA RTX 3090 GPU):

| Transformation | torchio | mriaug on CPU | mriaug on GPU | Speedup vs. torchio |
|---|---|---|---|---|
| Flip | 0.014 | 0.012 | 0.002 | 7.5x |
| Affine | 0.297 | 0.608 | 0.011 | 27.9x |
| Warp | 0.951 | 0.850 | 0.009 | 103.3x |
| Bias Field | 3.258 | 0.081 | 0.002 | 1813.0x |
| Noise | 0.117 | 0.105 | 0.001 | 230.4x |
| Downsample | 0.282 | 0.013 | 0.000 | 592.3x |
| Ghosting | 0.241 | 0.170 | 0.003 | 78.3x |
| Spike | 0.265 | 0.172 | 0.003 | 88.8x |
| Motion | 0.696 | 0.540 | 0.009 | 78.6x |
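The speedup column appears to be computed from unrounded runtimes (e.g. 0.014 / 0.002 = 7.0, while the table lists 7.5x), so recomputing it from the rounded values will not match exactly. Summarizing the listed speedups puts the ~50x headline in context:

```python
import statistics

# Speedup column copied from the table (torchio vs. mriaug on GPU)
speedups = {
    "Flip": 7.5, "Affine": 27.9, "Warp": 103.3, "Bias Field": 1813.0, "Noise": 230.4,
    "Downsample": 592.3, "Ghosting": 78.3, "Spike": 88.8, "Motion": 78.6,
}
values = list(speedups.values())
print(min(values), max(values))   # 7.5 1813.0
print(statistics.median(values))  # 88.8
```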
Let's load an example 3D image `x` and show it with niftiview (used to create all images below). Then define some arguments

```python
import torch

size = (160, 196, 160)
zoom = torch.tensor([[-.2, 0, 0]])
rotate = torch.tensor([[0, .1, 0]])
translate = torch.tensor([[0, 0, .2]])
shear = torch.tensor([[0, .05, 0]])
```

and run all augmentations (see `runall.py`):