From 193444cff18777dbfbc8cbb3e424b13343de369d Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 6 Feb 2025 12:09:13 -0800 Subject: [PATCH 1/4] Add option to reuse np chunks --- sleap_nn/data/custom_datasets.py | 27 +++++++++++++++++++++++---- sleap_nn/training/model_trainer.py | 11 +++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sleap_nn/data/custom_datasets.py b/sleap_nn/data/custom_datasets.py index 278a87e1..bd2e286b 100644 --- a/sleap_nn/data/custom_datasets.py +++ b/sleap_nn/data/custom_datasets.py @@ -47,6 +47,7 @@ class BaseDataset(Dataset): np_chunks: If `True`, `.npz` chunks are generated and samples are loaded from these chunks during training. Else, in-memory caching is used. np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used. + use_existing_chunks: Use existing chunks in the `np_chunks_path`. """ def __init__( @@ -58,6 +59,7 @@ def __init__( max_hw: Tuple[Optional[int]] = (None, None), np_chunks: bool = False, np_chunks_path: Optional[str] = None, + use_existing_chunks: bool = False, ) -> None: """Initialize class attributes.""" super().__init__() @@ -70,6 +72,7 @@ def __init__( self.max_instances = get_max_instances(self.labels) self.np_chunks = np_chunks self.np_chunks_path = np_chunks_path + self.use_existing_chunks = use_existing_chunks if self.np_chunks_path is None: self.np_chunks_path = "." path = ( @@ -188,6 +191,7 @@ class BottomUpDataset(BaseDataset): np_chunks: If `True`, `.npz` chunks are generated and samples are loaded from these chunks during training. Else, in-memory caching is used. np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used. + use_existing_chunks: Use existing chunks in the `np_chunks_path`. """ def __init__( @@ -201,6 +205,7 @@ def __init__( max_hw: Tuple[Optional[int]] = (None, None), np_chunks: bool = False, np_chunks_path: Optional[str] = None, + use_existing_chunks: bool = False, ) -> None: """Initialize class attributes.""" super().__init__( @@ -211,12 +216,14 @@ def __init__( max_hw=max_hw, np_chunks=np_chunks, np_chunks_path=np_chunks_path, + use_existing_chunks=use_existing_chunks, ) self.confmap_head_config = confmap_head_config self.pafs_head_config = pafs_head_config self.edge_inds = self.labels.skeletons[0].edge_inds - self._fill_cache() + if not self.use_existing_chunks: + self._fill_cache() def __getitem__(self, index) -> Dict: """Return dict with image, confmaps and pafs for given index.""" @@ -297,6 +304,7 @@ class CenteredInstanceDataset(BaseDataset): confmap_head_config: DictConfig object with all the keys in the `head_config` section. (required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ). crop_hw: Height and width of the crop in pixels. + use_existing_chunks: Use existing chunks in the `np_chunks_path`. Note: If scale is provided for centered-instance model, the images are cropped out from the scaled image with the given crop size. @@ -313,6 +321,7 @@ def __init__( max_hw: Tuple[Optional[int]] = (None, None), np_chunks: bool = False, np_chunks_path: Optional[str] = None, + use_existing_chunks: bool = False, ) -> None: """Initialize class attributes.""" super().__init__( @@ -323,12 +332,14 @@ def __init__( max_hw=max_hw, np_chunks=np_chunks, np_chunks_path=np_chunks_path, + use_existing_chunks=use_existing_chunks, ) self.crop_hw = crop_hw self.confmap_head_config = confmap_head_config self.instance_idx_list = self._get_instance_idx_list() self.cache_lf = [None, None] - self._fill_cache() + if not self.use_existing_chunks: + self._fill_cache() def _fill_cache(self): """Load all samples to cache.""" @@ -532,6 +543,7 @@ class CentroidDataset(BaseDataset): np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used. confmap_head_config: DictConfig object with all the keys in the `head_config` section. (required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ). + use_existing_chunks: Use existing chunks in the `np_chunks_path`. """ def __init__( @@ -544,6 +556,7 @@ def __init__( max_hw: Tuple[Optional[int]] = (None, None), np_chunks: bool = False, np_chunks_path: Optional[str] = None, + use_existing_chunks: bool = False, ) -> None: """Initialize class attributes.""" super().__init__( @@ -554,9 +567,11 @@ def __init__( max_hw=max_hw, np_chunks=np_chunks, np_chunks_path=np_chunks_path, + use_existing_chunks=use_existing_chunks, ) self.confmap_head_config = confmap_head_config - self._fill_cache() + if not self.use_existing_chunks: + self._fill_cache() def _fill_cache(self): """Load all samples to cache.""" @@ -688,6 +703,7 @@ class SingleInstanceDataset(BaseDataset): np_chunks_path: Path to save the `.npz` chunks. If `None`, current working dir is used. confmap_head_config: DictConfig object with all the keys in the `head_config` section. (required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ). + use_existing_chunks: Use existing chunks in the `np_chunks_path`. """ def __init__( @@ -700,6 +716,7 @@ def __init__( max_hw: Tuple[Optional[int]] = (None, None), np_chunks: bool = False, np_chunks_path: Optional[str] = None, + use_existing_chunks: bool = False, ) -> None: """Initialize class attributes.""" super().__init__( @@ -710,9 +727,11 @@ def __init__( max_hw=max_hw, np_chunks=np_chunks, np_chunks_path=np_chunks_path, + use_existing_chunks=use_existing_chunks, ) self.confmap_head_config = confmap_head_config - self._fill_cache() + if not self.use_existing_chunks: + self._fill_cache() def __getitem__(self, index) -> Dict: """Return dict with image and confmaps for instance for given index.""" diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 24eed569..fce90855 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -77,6 +77,7 @@ class ModelTrainer: (iii) trainer_config: trainer configs like accelerator, optimiser params. data_pipeline_fw: Framework to create the data loaders. One of [`litdata`, `torch_dataset`, `torch_dataset_np_chunks`] np_chunks_path: Path to save `.npz` chunks created with `torch_dataset_np_chunks` data pipeline framework. + use_existing_chunks: Use existing train and val chunks in the `np_chunks_path`. """ def __init__( @@ -84,6 +85,7 @@ def __init__( config: OmegaConf, data_pipeline_fw: str = "litdata", np_chunks_path: Optional[str] = None, + use_existing_np_chunks: bool = False, ): """Initialise the class with configs and set the seed and device as class attributes.""" self.config = config @@ -97,6 +99,7 @@ def __init__( self.val_np_chunks_path = ( Path(np_chunks_path) / "val_chunks" if np_chunks_path is not None else None ) + self.use_existing_np_chunks = use_existing_np_chunks self.seed = self.config.trainer_config.seed self.steps_per_epoch = self.config.trainer_config.steps_per_epoch @@ -237,6 +240,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.train_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) self.val_dataset = BottomUpDataset( labels=val_labels, @@ -248,6 +252,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.val_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) elif self.model_type == "centered_instance": @@ -261,6 +266,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.train_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) self.val_dataset = CenteredInstanceDataset( labels=val_labels, @@ -272,6 +278,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.val_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) elif self.model_type == "centroid": @@ -284,6 +291,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.train_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) self.val_dataset = CentroidDataset( labels=val_labels, @@ -294,6 +302,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.val_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) elif self.model_type == "single_instance": @@ -306,6 +315,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.train_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) self.val_dataset = SingleInstanceDataset( labels=val_labels, @@ -316,6 +326,7 @@ def _create_data_loaders_torch_dataset(self): max_hw=(self.max_height, self.max_width), np_chunks=self.np_chunks, np_chunks_path=self.val_np_chunks_path, + use_existing_chunks=self.use_existing_np_chunks, ) else: From 25088dad4545643b5ba73438e048fb4d0d44fde6 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 6 Feb 2025 13:57:23 -0800 Subject: [PATCH 2/4] Raise exception if dir is empty --- sleap_nn/training/model_trainer.py | 28 ++++++++++++++++++++++++++-- tests/training/test_model_trainer.py | 9 +++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index fce90855..4504db98 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -100,6 +100,23 @@ def __init__( Path(np_chunks_path) / "val_chunks" if np_chunks_path is not None else None ) self.use_existing_np_chunks = use_existing_np_chunks + if self.use_existing_np_chunks: + if not ( + self.train_np_chunks_path.exists() + and self.train_np_chunks_path.is_dir() + and any(self.train_np_chunks_path.iterdir()) + ): + raise Exception( + f"There are no numpy chunks in the path: {self.train_np_chunks_path}" + ) + if not ( + self.val_np_chunks_path.exists() + and self.val_np_chunks_path.is_dir() + and any(self.val_np_chunks_path.iterdir()) + ): + raise Exception( + f"There are no numpy chunks in the path: {self.val_np_chunks_path}" + ) self.seed = self.config.trainer_config.seed self.steps_per_epoch = self.config.trainer_config.steps_per_epoch @@ -340,6 +357,13 @@ def _create_data_loaders_torch_dataset(self): // self.config.trainer_config.train_data_loader.batch_size ) + pin_memory = ( + self.config.trainer_config.train_data_loader.pin_memory + if "pin_memory" in self.config.trainer_config.train_data_loader + and self.config.trainer_config.train_data_loader.pin_memory is not None + else True + ) + # train self.train_data_loader = CyclerDataLoader( dataset=self.train_dataset, @@ -347,7 +371,7 @@ def _create_data_loaders_torch_dataset(self): shuffle=self.config.trainer_config.train_data_loader.shuffle, batch_size=self.config.trainer_config.train_data_loader.batch_size, num_workers=self.config.trainer_config.train_data_loader.num_workers, - pin_memory=True, + pin_memory=pin_memory, persistent_workers=( True if self.config.trainer_config.train_data_loader.num_workers > 0 @@ -368,7 +392,7 @@ def _create_data_loaders_torch_dataset(self): shuffle=False, batch_size=self.config.trainer_config.val_data_loader.batch_size, num_workers=self.config.trainer_config.val_data_loader.num_workers, - pin_memory=True, + pin_memory=pin_memory, persistent_workers=( True if self.config.trainer_config.val_data_loader.num_workers > 0 diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 7e10c412..c722b97b 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -327,6 +327,15 @@ def test_trainer_torch_dataset(config, tmp_path: str): model_trainer = ModelTrainer(config, data_pipeline_fw="torch_dataset") assert model_trainer.dir_path == "." + ##### test for reusing np chunks path + with pytest.raises(Exception): + model_trainer = ModelTrainer( + config, + data_pipeline_fw="torch_dataset_np_chunks", + np_chunks_path=tmp_path, + use_existing_np_chunks=True, + ) + ##### # # for topdown centered instance model From 7907febcffb55021e56789bc03b0dc5f39b6723e Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 7 Feb 2025 14:21:02 -0800 Subject: [PATCH 3/4] Fix filename access --- sleap_nn/data/custom_datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sleap_nn/data/custom_datasets.py b/sleap_nn/data/custom_datasets.py index bd2e286b..45516af9 100644 --- a/sleap_nn/data/custom_datasets.py +++ b/sleap_nn/data/custom_datasets.py @@ -228,7 +228,7 @@ def __init__( def __getitem__(self, index) -> Dict: """Return dict with image, confmaps and pafs for given index.""" if self.np_chunks: - ex = np.load(self.cache[index]) + ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz") sample = {} for k, v in ex.items(): if k != "image": @@ -454,7 +454,7 @@ def __len__(self) -> int: def __getitem__(self, index) -> Dict: """Return dict with cropped image and confmaps of instance for given index.""" if self.np_chunks: - ex = np.load(self.cache[index]) + ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz") sample = {} for k, v in ex.items(): if k != "instance_image": @@ -639,7 +639,7 @@ def _fill_cache(self): def __getitem__(self, index) -> Dict: """Return dict with image and confmaps for centroids for given index.""" if self.np_chunks: - ex = np.load(self.cache[index]) + ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz") sample = {} for k, v in ex.items(): if k != "image": @@ -736,7 +736,7 @@ def __init__( def __getitem__(self, index) -> Dict: """Return dict with image and confmaps for instance for given index.""" if self.np_chunks: - ex = np.load(self.cache[index]) + ex = np.load(f"{self.np_chunks_path}/sample_{index}.npz") sample = {} for k, v in ex.items(): if k != "image": From 216bcf1fd28c8529f0f8d972d46fd8fb7163df0a Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 11 Feb 2025 15:57:54 -0800 Subject: [PATCH 4/4] Add test cases --- sleap_nn/training/model_trainer.py | 6 +++--- tests/training/test_model_trainer.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 4504db98..df036af2 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -77,7 +77,7 @@ class ModelTrainer: (iii) trainer_config: trainer configs like accelerator, optimiser params. data_pipeline_fw: Framework to create the data loaders. One of [`litdata`, `torch_dataset`, `torch_dataset_np_chunks`] np_chunks_path: Path to save `.npz` chunks created with `torch_dataset_np_chunks` data pipeline framework. - use_existing_chunks: Use existing train and val chunks in the `np_chunks_path`. + use_existing_np_chunks: Use existing train and val chunks in the `np_chunks_path`. """ def __init__( @@ -104,7 +104,7 @@ def __init__( if not ( self.train_np_chunks_path.exists() and self.train_np_chunks_path.is_dir() - and any(self.train_np_chunks_path.iterdir()) + and any(self.train_np_chunks_path.glob("*.npz")) ): raise Exception( f"There are no numpy chunks in the path: {self.train_np_chunks_path}" @@ -112,7 +112,7 @@ def __init__( if not ( self.val_np_chunks_path.exists() and self.val_np_chunks_path.is_dir() - and any(self.val_np_chunks_path.iterdir()) + and any(self.val_np_chunks_path.glob("*.npz")) ): raise Exception( f"There are no numpy chunks in the path: {self.val_np_chunks_path}" diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index c722b97b..ce307f75 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -328,6 +328,18 @@ def test_trainer_torch_dataset(config, tmp_path: str): assert model_trainer.dir_path == "." ##### test for reusing np chunks path + with pytest.raises(Exception): + model_trainer = ModelTrainer( + config, + data_pipeline_fw="torch_dataset_np_chunks", + np_chunks_path=tmp_path, + use_existing_np_chunks=True, + ) + + Path.mkdir(Path(tmp_path) / "train_chunks", parents=True) + file_path = Path(tmp_path) / "train_chunks" / "sample.npz" + np.savez_compressed(file_path, {1: 10}) + with pytest.raises(Exception): model_trainer = ModelTrainer( config,