35 changes: 27 additions & 8 deletions sleap_nn/data/custom_datasets.py
@@ -47,6 +47,7 @@ class BaseDataset(Dataset):
np_chunks: If `True`, `.npz` chunks are generated and samples are loaded from
these chunks during training. Else, in-memory caching is used.
np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used.
+        use_existing_chunks: Use existing chunks in the `np_chunks_path`.
"""

def __init__(
@@ -58,6 +59,7 @@ def __init__(
max_hw: Tuple[Optional[int]] = (None, None),
np_chunks: bool = False,
np_chunks_path: Optional[str] = None,
+        use_existing_chunks: bool = False,
) -> None:
"""Initialize class attributes."""
super().__init__()
@@ -70,6 +72,7 @@ def __init__(
self.max_instances = get_max_instances(self.labels)
self.np_chunks = np_chunks
self.np_chunks_path = np_chunks_path
+        self.use_existing_chunks = use_existing_chunks
if self.np_chunks_path is None:
self.np_chunks_path = "."
path = (
@@ -188,6 +191,7 @@ class BottomUpDataset(BaseDataset):
np_chunks: If `True`, `.npz` chunks are generated and samples are loaded from
these chunks during training. Else, in-memory caching is used.
np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used.
+        use_existing_chunks: Use existing chunks in the `np_chunks_path`.
"""

def __init__(
@@ -201,6 +205,7 @@ def __init__(
max_hw: Tuple[Optional[int]] = (None, None),
np_chunks: bool = False,
np_chunks_path: Optional[str] = None,
+        use_existing_chunks: bool = False,
) -> None:
"""Initialize class attributes."""
super().__init__(
@@ -211,17 +216,19 @@ def __init__(
max_hw=max_hw,
np_chunks=np_chunks,
np_chunks_path=np_chunks_path,
+            use_existing_chunks=use_existing_chunks,
)
self.confmap_head_config = confmap_head_config
self.pafs_head_config = pafs_head_config

self.edge_inds = self.labels.skeletons[0].edge_inds
-        self._fill_cache()
+        if not self.use_existing_chunks:
+            self._fill_cache()

def __getitem__(self, index) -> Dict:
"""Return dict with image, confmaps and pafs for given index."""
if self.np_chunks:
-            ex = np.load(self.cache[index])
+            ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz")
sample = {}
for k, v in ex.items():
if k != "image":
@@ -297,6 +304,7 @@ class CenteredInstanceDataset(BaseDataset):
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
(required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ).
crop_hw: Height and width of the crop in pixels.
+        use_existing_chunks: Use existing chunks in the `np_chunks_path`.

Note: If scale is provided for centered-instance model, the images are cropped out
from the scaled image with the given crop size.
@@ -313,6 +321,7 @@ def __init__(
max_hw: Tuple[Optional[int]] = (None, None),
np_chunks: bool = False,
np_chunks_path: Optional[str] = None,
+        use_existing_chunks: bool = False,
) -> None:
"""Initialize class attributes."""
super().__init__(
@@ -323,12 +332,14 @@
max_hw=max_hw,
np_chunks=np_chunks,
np_chunks_path=np_chunks_path,
+            use_existing_chunks=use_existing_chunks,
)
self.crop_hw = crop_hw
self.confmap_head_config = confmap_head_config
self.instance_idx_list = self._get_instance_idx_list()
self.cache_lf = [None, None]
-        self._fill_cache()
+        if not self.use_existing_chunks:
+            self._fill_cache()

def _fill_cache(self):
"""Load all samples to cache."""
@@ -443,7 +454,7 @@ def __len__(self) -> int:
def __getitem__(self, index) -> Dict:
"""Return dict with cropped image and confmaps of instance for given index."""
if self.np_chunks:
-            ex = np.load(self.cache[index])
+            ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz")
sample = {}
for k, v in ex.items():
if k != "instance_image":
@@ -532,6 +543,7 @@ class CentroidDataset(BaseDataset):
np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used.
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
(required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ).
+        use_existing_chunks: Use existing chunks in the `np_chunks_path`.
"""

def __init__(
@@ -544,6 +556,7 @@ def __init__(
max_hw: Tuple[Optional[int]] = (None, None),
np_chunks: bool = False,
np_chunks_path: Optional[str] = None,
+        use_existing_chunks: bool = False,
) -> None:
"""Initialize class attributes."""
super().__init__(
@@ -554,9 +567,11 @@
max_hw=max_hw,
np_chunks=np_chunks,
np_chunks_path=np_chunks_path,
+            use_existing_chunks=use_existing_chunks,
)
self.confmap_head_config = confmap_head_config
-        self._fill_cache()
+        if not self.use_existing_chunks:
+            self._fill_cache()

def _fill_cache(self):
"""Load all samples to cache."""
@@ -624,7 +639,7 @@ def _fill_cache(self):
def __getitem__(self, index) -> Dict:
"""Return dict with image and confmaps for centroids for given index."""
if self.np_chunks:
-            ex = np.load(self.cache[index])
+            ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz")
sample = {}
for k, v in ex.items():
if k != "image":
@@ -688,6 +703,7 @@ class SingleInstanceDataset(BaseDataset):
np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used.
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
(required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ).
+        use_existing_chunks: Use existing chunks in the `np_chunks_path`.
"""

def __init__(
Expand All @@ -700,6 +716,7 @@ def __init__(
max_hw: Tuple[Optional[int]] = (None, None),
np_chunks: bool = False,
np_chunks_path: Optional[str] = None,
+        use_existing_chunks: bool = False,
) -> None:
"""Initialize class attributes."""
super().__init__(
Expand All @@ -710,14 +727,16 @@ def __init__(
max_hw=max_hw,
np_chunks=np_chunks,
np_chunks_path=np_chunks_path,
+            use_existing_chunks=use_existing_chunks,
)
self.confmap_head_config = confmap_head_config
self._fill_cache()
if not self.use_existing_chunks:
self._fill_cache()

def __getitem__(self, index) -> Dict:
"""Return dict with image and confmaps for instance for given index."""
if self.np_chunks:
-            ex = np.load(self.cache[index])
+            ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz")
sample = {}
for k, v in ex.items():
if k != "image":
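The dataset-side change in summary: every dataset class gains a `use_existing_chunks` flag that skips `_fill_cache()` and makes `__getitem__` read `sample_{index}.npz` directly from `np_chunks_path`. Below is a minimal sketch of that chunk contract; the directory name is hypothetical, and only the `sample_{index}.npz` naming and the special-cased `image` key come from the diff above.

import numpy as np
from pathlib import Path

chunks = Path("chunks/train_chunks")  # hypothetical chunk directory
chunks.mkdir(parents=True, exist_ok=True)

# One compressed archive per sample, named by index. This mirrors the
# f"{np_chunks_path}/sample_{index}.npz" lookup added in __getitem__.
np.savez_compressed(chunks / "sample_0.npz", image=np.zeros((1, 1, 64, 64)))

# With use_existing_chunks=True, the datasets skip _fill_cache() and load
# each sample straight from disk:
index = 0
ex = np.load(chunks / f"sample_{index}.npz")
sample = {k: v for k, v in ex.items() if k != "image"}  # "image" is post-processed separately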
39 changes: 37 additions & 2 deletions sleap_nn/training/model_trainer.py
@@ -77,13 +77,15 @@ class ModelTrainer:
(iii) trainer_config: trainer configs like accelerator, optimiser params.
data_pipeline_fw: Framework to create the data loaders. One of [`litdata`, `torch_dataset`, `torch_dataset_np_chunks`]
np_chunks_path: Path to save `.npz` chunks created with `torch_dataset_np_chunks` data pipeline framework.
+        use_existing_np_chunks: Use existing train and val chunks in the `np_chunks_path`.
"""

def __init__(
self,
config: OmegaConf,
data_pipeline_fw: str = "litdata",
np_chunks_path: Optional[str] = None,
+        use_existing_np_chunks: bool = False,
):
"""Initialise the class with configs and set the seed and device as class attributes."""
self.config = config
@@ -97,6 +99,24 @@ def __init__(
self.val_np_chunks_path = (
Path(np_chunks_path) / "val_chunks" if np_chunks_path is not None else None
)
+        self.use_existing_np_chunks = use_existing_np_chunks
+        if self.use_existing_np_chunks:
+            if not (
+                self.train_np_chunks_path.exists()
+                and self.train_np_chunks_path.is_dir()
+                and any(self.train_np_chunks_path.glob("*.npz"))
+            ):
+                raise Exception(
+                    f"There are no numpy chunks in the path: {self.train_np_chunks_path}"
+                )
+            if not (
+                self.val_np_chunks_path.exists()
+                and self.val_np_chunks_path.is_dir()
+                and any(self.val_np_chunks_path.glob("*.npz"))
+            ):
+                raise Exception(
+                    f"There are no numpy chunks in the path: {self.val_np_chunks_path}"
+                )
self.seed = self.config.trainer_config.seed
self.steps_per_epoch = self.config.trainer_config.steps_per_epoch

@@ -237,6 +257,7 @@ def _create_data_loaders_torch_dataset(self):
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.train_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)
self.val_dataset = BottomUpDataset(
labels=val_labels,
@@ -248,6 +269,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.val_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)

elif self.model_type == "centered_instance":
@@ -261,6 +283,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.train_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)
self.val_dataset = CenteredInstanceDataset(
labels=val_labels,
@@ -272,6 +295,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.val_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)

elif self.model_type == "centroid":
@@ -284,6 +308,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.train_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)
self.val_dataset = CentroidDataset(
labels=val_labels,
@@ -294,6 +319,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.val_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)

elif self.model_type == "single_instance":
@@ -306,6 +332,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.train_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)
self.val_dataset = SingleInstanceDataset(
labels=val_labels,
@@ -316,6 +343,7 @@
max_hw=(self.max_height, self.max_width),
np_chunks=self.np_chunks,
np_chunks_path=self.val_np_chunks_path,
+                use_existing_chunks=self.use_existing_np_chunks,
)

else:
@@ -329,14 +357,21 @@
// self.config.trainer_config.train_data_loader.batch_size
)

+        pin_memory = (
+            self.config.trainer_config.train_data_loader.pin_memory
+            if "pin_memory" in self.config.trainer_config.train_data_loader
+            and self.config.trainer_config.train_data_loader.pin_memory is not None
+            else True
+        )
+
# train
self.train_data_loader = CyclerDataLoader(
dataset=self.train_dataset,
steps_per_epoch=self.steps_per_epoch,
shuffle=self.config.trainer_config.train_data_loader.shuffle,
batch_size=self.config.trainer_config.train_data_loader.batch_size,
num_workers=self.config.trainer_config.train_data_loader.num_workers,
-            pin_memory=True,
+            pin_memory=pin_memory,
persistent_workers=(
True
if self.config.trainer_config.train_data_loader.num_workers > 0
Expand All @@ -357,7 +392,7 @@ def _create_data_loaders_torch_dataset(self):
shuffle=False,
batch_size=self.config.trainer_config.val_data_loader.batch_size,
num_workers=self.config.trainer_config.val_data_loader.num_workers,
-            pin_memory=True,
+            pin_memory=pin_memory,
persistent_workers=(
True
if self.config.trainer_config.val_data_loader.num_workers > 0
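On the trainer side, `use_existing_np_chunks` is validated eagerly in `__init__`: both `train_chunks/` and `val_chunks/` under `np_chunks_path` must exist and contain at least one `.npz` file, otherwise an exception is raised before any training starts. The same diff also stops hard-coding `pin_memory=True`: both loaders now honor `trainer_config.train_data_loader.pin_memory` when it is set and non-None, falling back to `True`. A sketch of the intended two-run workflow follows, assuming `config` is an already-built OmegaConf config, that `train()` is the usual entry point, and that chunks are written when the datasets are constructed for training; the path is hypothetical.

from sleap_nn.training.model_trainer import ModelTrainer

# Run 1: generate chunks under np_chunks_path/train_chunks and
# np_chunks_path/val_chunks as part of building the datasets.
trainer = ModelTrainer(
    config,
    data_pipeline_fw="torch_dataset_np_chunks",
    np_chunks_path="/tmp/chunks",  # hypothetical
)
trainer.train()

# Run 2: reuse the chunks; __init__ raises if either directory is
# missing, is not a directory, or contains no .npz files.
trainer = ModelTrainer(
    config,
    data_pipeline_fw="torch_dataset_np_chunks",
    np_chunks_path="/tmp/chunks",
    use_existing_np_chunks=True,
)
trainer.train()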
21 changes: 21 additions & 0 deletions tests/training/test_model_trainer.py
@@ -327,6 +327,27 @@ def test_trainer_torch_dataset(config, tmp_path: str):
model_trainer = ModelTrainer(config, data_pipeline_fw="torch_dataset")
assert model_trainer.dir_path == "."

+    ##### test for reusing np chunks path
+    with pytest.raises(Exception):
+        model_trainer = ModelTrainer(
+            config,
+            data_pipeline_fw="torch_dataset_np_chunks",
+            np_chunks_path=tmp_path,
+            use_existing_np_chunks=True,
+        )
+
+    Path.mkdir(Path(tmp_path) / "train_chunks", parents=True)
+    file_path = Path(tmp_path) / "train_chunks" / "sample.npz"
+    np.savez_compressed(file_path, {1: 10})
+
+    with pytest.raises(Exception):
+        model_trainer = ModelTrainer(
+            config,
+            data_pipeline_fw="torch_dataset_np_chunks",
+            np_chunks_path=tmp_path,
+            use_existing_np_chunks=True,
+        )
+
#####

# # for topdown centered instance model
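The new test covers only the failure paths (no chunks at all, then train chunks without val chunks). A hypothetical happy-path continuation, not part of this PR, could assert that construction succeeds once both directories are populated:

# Hypothetical continuation: once val_chunks/ also holds a .npz file,
# construction with use_existing_np_chunks=True should succeed.
Path.mkdir(Path(tmp_path) / "val_chunks", parents=True)
np.savez_compressed(Path(tmp_path) / "val_chunks" / "sample.npz", {1: 10})

model_trainer = ModelTrainer(
    config,
    data_pipeline_fw="torch_dataset_np_chunks",
    np_chunks_path=tmp_path,
    use_existing_np_chunks=True,
)
assert model_trainer.use_existing_np_chunks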