Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
bf672ae
Update sleap-io version to allow for fixed video reading
davidasamy Jul 5, 2023
be78313
Make SleapDataset general to be used by many pipelines
davidasamy Jul 6, 2023
760228b
added KorniaAugmenter, RandomUniformNoise; tmp nb; updated ymls
alckasoc Jul 20, 2023
820dfc5
Merge branch 'vincent/augmentation' of https://github.com/talmolab/sl…
alckasoc Jul 20, 2023
bd63957
deleted old/tmp files
alckasoc Jul 20, 2023
4e38413
added random translation px
alckasoc Jul 22, 2023
a3669a5
added docstring to randomtranslatepx
alckasoc Jul 24, 2023
adc17a7
added random brightness add; kornia's default brightness max val caps…
alckasoc Jul 24, 2023
0ed7b4e
finished kornia augmenter class; issue with slow loading
alckasoc Jul 25, 2023
eaa9a51
Merge remote-tracking branch 'origin/main' into vincent/augmentation
alckasoc Jul 25, 2023
5777586
merging with main
alckasoc Jul 25, 2023
8e68b4c
added mixup and random erasing/dropout patches
alckasoc Jul 25, 2023
52c989a
added random crop
alckasoc Jul 25, 2023
0764921
removed tmp.ipynb
alckasoc Jul 27, 2023
bff9be6
updated yml; simplified KorniaAugmenter; updated test_augmentation
alckasoc Jul 28, 2023
2fb87c1
updated yml; simplified KorniaAugmenter; updated test_augmentation
alckasoc Jul 28, 2023
7b9bcf6
added docstrings
alckasoc Jul 28, 2023
84230a8
added test cases for RandomUniformNoise
alckasoc Jul 28, 2023
000577a
black formatted test_augmentation
alckasoc Jul 28, 2023
4eb1e75
updated test case
alckasoc Jul 28, 2023
caf7b22
removed torchdata; made random crop optional
alckasoc Aug 1, 2023
af1526a
black formatted test_augmentation
alckasoc Aug 1, 2023
52060c5
updated test_augmentation with pytest.raises
alckasoc Aug 3, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ dependencies:
- lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027)
- cudnn
- pytorch
- kornia
- torchvision
- imageio
- av
- ffmpeg
- kornia
- matplotlib
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ dependencies:
- pydantic<=2.0
- lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027)
- cpuonly
- kornia
- torchvision
- imageio
- av
- ffmpeg
- kornia
- matplotlib
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion environment_osx-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ dependencies:
- pydantic<2.0
- lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027)
- pytorch
- kornia
- torchvision
- imageio
- av
- ffmpeg
- kornia
- matplotlib
- pip
- pip:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ dependencies = [
"imageio",
"imageio-ffmpeg",
"av",
"kornia",
"hydra-core",
"sleap-io>=0.0.7",
"kornia"
]
dynamic = ["version", "readme"]

Expand Down
266 changes: 241 additions & 25 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,93 @@
"""This module implements data pipeline blocks for augmentations."""
from typing import Optional
"""This module implements data pipeline blocks for augmentation operations."""
from typing import Tuple, Dict, Any, Optional, Union, Text
import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe
import kornia.augmentation as K
import kornia as K
from kornia.core import Tensor
from kornia.augmentation.container import AugmentationSequential
from kornia.augmentation._2d.geometric.base import GeometricAugmentationBase2D
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.constants import Resample, SamplePadding
from kornia.geometry.transform import warp_affine
from kornia.augmentation.utils.param_validation import _range_bound


class RandomUniformNoise(IntensityAugmentationBase2D):
    """Data transformer for applying random uniform noise to input images.

    This is a custom Kornia augmentation inheriting from `IntensityAugmentationBase2D`.
    Uniform noise within (min_val, max_val) is added to the entire input image, and the
    result is clamped to [0, 1] when `clip_output` is enabled.

    Note: Inverse transform is not implemented and re-applying the same transformation
    in the example below does not work when included in an AugmentationSequential class.

    Args:
        noise: 2-tuple (min_val, max_val); 0.0 <= min_val <= max_val <= 1.0. The range
            is validated by kornia's `_range_bound` with bounds `(0.0, 1.0)`.
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls
            the augmentation probabilities batch-wise.
        clip_output: if `True` (default), clamp the noisy output to [0, 1].
        same_on_batch: apply the same transformation across the batch. Note that the
            noise tensor is sampled with the full input shape, so each batch element
            still receives its own noise values.
        keepdim: whether to keep the output shape the same as input `True` or broadcast it
            to the batch form `False`.

    Examples:
        >>> img = torch.zeros(1, 1, 2, 2)
        >>> out = RandomUniformNoise(noise=(0.0, 0.1), p=1.0)(img)
        >>> out.shape
        torch.Size([1, 1, 2, 2])
        >>> bool(((out >= 0.0) & (out <= 0.1)).all())
        True

        To apply the exact augmentation again, you may take advantage of the previous
        parameter state:
        >>> input = torch.rand(1, 3, 32, 32)
        >>> aug = RandomUniformNoise(noise=(0.0, 0.1), p=1.0)
        >>> (aug(input) == aug(input, params=aug._params)).all()
        tensor(True)

    Ref: `kornia.augmentation._2d.intensity.gaussian_noise
    <https://kornia.readthedocs.io/en/latest/_modules/kornia/augmentation/_2d/intensity/gaussian_noise.html#RandomGaussianNoise>`_.
    """

    def __init__(
        self,
        noise: Tuple[float, float],
        p: float = 0.5,
        p_batch: float = 1.0,
        clip_output: bool = True,
        same_on_batch: bool = False,
        keepdim: bool = False,
    ) -> None:
        """Initialize the augmentation and validate the noise range."""
        super().__init__(
            p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
        )
        # Static flags consumed by `apply_transform`; `_range_bound` checks that the
        # (min, max) pair lies within [0, 1].
        self.flags = {
            "uniform_noise": _range_bound(noise, "uniform_noise", bounds=(0.0, 1.0))
        }
        self.clip_output = clip_output

    def apply_transform(
        self,
        input: Tensor,
        params: Dict[str, Tensor],
        flags: Dict[str, Any],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Add uniform noise to `input` and optionally clamp the result to [0, 1].

        Args:
            input: Image batch to augment.
            params: Sampled parameters. If it already contains a `"uniform_noise"`
                tensor (e.g. when re-applying via `params=aug._params`), that noise is
                reused instead of sampling new noise.
            flags: Static flags; `flags["uniform_noise"]` holds the (min, max) range.
            transform: Unused; part of the kornia `apply_transform` interface.

        Returns:
            `input + noise`, clamped to [0, 1] when `self.clip_output` is set.
        """
        if "uniform_noise" in params:
            uniform_noise = params["uniform_noise"]
        else:
            low, high = flags["uniform_noise"]
            # Sample directly on the input's device (no CPU allocation + copy) and in
            # the input's floating dtype; fall back to float32 for non-float inputs,
            # matching the previous float32 noise in that case.
            noise_dtype = input.dtype if input.is_floating_point() else torch.float32
            uniform_noise = torch.empty(
                input.shape, dtype=noise_dtype, device=input.device
            ).uniform_(float(low), float(high))
            # Record the noise so a later call with `params=self._params` reproduces it.
            self._params["uniform_noise"] = uniform_noise
        noisy = input + uniform_noise
        if self.clip_output:
            # RandomGaussianNoise doesn't clamp; keep this output in the valid range.
            return torch.clamp(noisy, 0.0, 1.0)
        return noisy


class KorniaAugmenter(IterDataPipe):
Expand All @@ -13,14 +99,36 @@ class KorniaAugmenter(IterDataPipe):
Attributes:
source_dp: The input `IterDataPipe` with examples that contain `"instances"` and
`"image"` keys.
crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
crop_p: Probability of applying random crop.
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
scale: A scaling factor as a scalar float specifying the amount of scaling. A
random factor between `(1 - scale, 1 + scale)` will be sampled and applied
to both images and keypoints. If `None`, no scaling augmentation will be
applied.
probability: Probability of applying the transformations.
translate: tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
affine_p: Probability of applying random affine transformations.
uniform_noise: tuple of uniform noise `(min_noise, max_noise)`.
Must satisfy 0. <= min_noise <= max_noise <= 1.
uniform_noise_p: Probability of applying random uniform noise.
gaussian_noise_mean: The mean of the gaussian distribution.
gaussian_noise_std: The standard deviation of the gaussian distribution.
gaussian_noise_p: Probability of applying random gaussian noise.
contrast: The contrast factor to apply. Default: `(1.0, 1.0)`.
contrast_p: Probability of applying random contrast.
brightness: The brightness factor to apply Default: `(1.0, 1.0)`.
brightness_p: Probability of applying random brightness.
erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`.
erase_p: Probability of applying random erase.
mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.

Notes:
This block expects the "image" and "instances" keys to be present in the input
Expand All @@ -38,37 +146,145 @@ class KorniaAugmenter(IterDataPipe):
def __init__(
self,
source_dp: IterDataPipe,
rotation: float = 15.0,
rotation: Optional[float] = 15.0,
scale: Optional[float] = 0.05,
probability: float = 0.5,
translate: Optional[Tuple[float, float]] = (0.02, 0.02),
affine_p: float = 0.5,
uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04),
uniform_noise_p: float = 0.5,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.5,
contrast: Optional[Tuple[float, float]] = (0.5, 2.0),
contrast_p: float = 0.5,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.5,
erase_scale: Optional[Tuple[float, float]] = (0.02, 0.1),
erase_ratio: Optional[Tuple[float, float]] = (0.3, 1.6),
erase_p: float = 0.5,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.5,
crop_hw: Tuple[int, int] = (0, 0),
crop_p: float = 0.0,
):
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
self.rotation = rotation
self.scale = (1 - scale, 1 + scale)
self.probability = probability
self.augmenter = K.AugmentationSequential(
K.RandomAffine(
degrees=self.rotation,
scale=self.scale,
p=self.probability,
keepdim=True,
same_on_batch=True,
),
self.translate = translate
self.affine_p = affine_p
self.uniform_noise = uniform_noise
self.uniform_noise_p = uniform_noise_p
self.gaussian_noise_mean = gaussian_noise_mean
self.gaussian_noise_std = gaussian_noise_std
self.gaussian_noise_p = gaussian_noise_p
self.contrast = contrast
self.contrast_p = contrast_p
self.brightness = brightness
self.brightness_p = brightness_p
self.erase_scale = erase_scale
self.erase_ratio = erase_ratio
self.erase_p = erase_p
self.mixup_lambda = mixup_lambda
self.mixup_p = mixup_p
self.crop_hw = crop_hw
self.crop_p = crop_p

aug_stack = []
if self.affine_p > 0:
aug_stack.append(
K.augmentation.RandomAffine(
degrees=self.rotation,
translate=self.translate,
scale=self.scale,
p=self.affine_p,
keepdim=True,
same_on_batch=True,
)
)
if self.uniform_noise_p > 0:
aug_stack.append(
RandomUniformNoise(
noise=self.uniform_noise,
p=self.uniform_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if self.gaussian_noise_p > 0:
aug_stack.append(
K.augmentation.RandomGaussianNoise(
mean=self.gaussian_noise_mean,
std=self.gaussian_noise_std,
p=self.gaussian_noise_p,
keepdim=True,
same_on_batch=True,
)
)
if self.contrast_p > 0:
aug_stack.append(
K.augmentation.RandomContrast(
contrast=self.contrast,
p=self.contrast_p,
keepdim=True,
same_on_batch=True,
)
)
if self.brightness_p > 0:
aug_stack.append(
K.augmentation.RandomBrightness(
brightness=self.brightness,
p=self.brightness_p,
keepdim=True,
same_on_batch=True,
)
)
if self.erase_p > 0:
aug_stack.append(
K.augmentation.RandomErasing(
scale=self.erase_scale,
ratio=self.erase_ratio,
p=self.erase_p,
keepdim=True,
same_on_batch=True,
)
)
if self.mixup_p > 0:
aug_stack.append(
K.augmentation.RandomMixUpV2(
lambda_val=self.mixup_lambda,
p=self.mixup_p,
keepdim=True,
same_on_batch=True,
)
)
if self.crop_p > 0:
if self.crop_hw[0] > 0 and self.crop_hw[1] > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=self.crop_hw,
pad_if_needed=True,
p=self.crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"crop_hw height and width must be greater than 0.")

self.augmenter = AugmentationSequential(
*aug_stack,
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

def __iter__(self):
"""Return an example dictionary with the augmented image and instance."""
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
img = ex["image"]
pts = ex["instances"]
pts_shape = pts.shape
pts = pts.reshape(-1, pts_shape[-2], pts_shape[-1])
img, pts = self.augmenter(img, pts)
pts = pts.reshape(pts_shape)
ex["image"] = img
ex["instances"] = pts
yield ex
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
)
aug_image, aug_instances = self.augmenter(image, instances)
yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)}
Loading