Add threshold mode in config for learning rate scheduler #106
Conversation
Walkthrough: The pull request introduces several modifications to configuration files and the model trainer implementation. Key changes include the addition of a `threshold_mode` parameter to the learning rate scheduler configuration.
Caution
Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.
Actionable comments posted: 27
Outside diff range and nitpick comments (16)
docs/config_bottomup.yaml (1)
Line range hint 1-107: Configuration structure maintains good organization and clarity. The removal of commented-out alternative backbone configurations (ConvNeXt and Swin) improves maintainability while preserving the clear hierarchical structure of the configuration.
Consider maintaining alternative backbone configurations in separate example files if they might be useful for reference, rather than as comments in the main configuration.
docs/config_topdown_centered_instance.yaml (1)
105-105: LGTM! Consider adding parameter documentation. The `threshold_mode` parameter is correctly placed in the lr_scheduler configuration with a valid value of 'abs'. Consider adding a comment to document that this parameter accepts either 'abs' or 'rel' values and their implications on the learning rate scheduling behavior.

```diff
  threshold: 1.0e-07
+ # threshold_mode: Mode for interpreting threshold parameter ('abs' for absolute, 'rel' for relative change)
  threshold_mode: abs
```
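For readers unfamiliar with the two modes, the difference is easiest to see in the improvement check that `torch.optim.lr_scheduler.ReduceLROnPlateau` applies (shown here for `mode='min'`). The snippet below is a standalone illustration, not code from this repository:

```python
def is_improvement(value: float, best: float, threshold: float, threshold_mode: str) -> bool:
    """Mirror of the comparison ReduceLROnPlateau uses in 'min' mode."""
    if threshold_mode == "rel":
        # Relative: the new value must beat `best` by a fraction of `best`.
        return value < best * (1.0 - threshold)
    # Absolute: the new value must beat `best` by a fixed amount.
    return value < best - threshold

# With a loss around 10.0 and threshold=0.05, the two modes behave differently:
print(is_improvement(9.8, 10.0, 0.05, "rel"))  # False: must drop below 9.5
print(is_improvement(9.8, 10.0, 0.05, "abs"))  # True: only needs to drop below 9.95
```

So 'rel' scales the required improvement with the magnitude of the monitored metric, while 'abs' keeps it fixed.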
initial_config.yaml (1)
68-69: Improve empty path handling. Empty string defaults for paths could lead to unclear errors. Consider using null as default or providing clear placeholder paths.

```diff
  save_ckpt: false
- save_ckpt_path: ''
+ save_ckpt_path: null  # Set to "./checkpoints" when save_ckpt is true
```
docs/config_centroid.yaml (1)
136-136: Document the valid values for threshold_mode. While the default value `abs` is set correctly, please add a comment documenting that the valid values are 'abs' (absolute) or 'rel' (relative) for the `threshold_mode` parameter. Apply this diff to add documentation:

```diff
  threshold: 1.0e-07
- threshold_mode: abs
+ threshold_mode: abs  # Valid values: 'abs' (absolute) or 'rel' (relative)
  cooldown: 3
```
tests/fixtures/datasets.py (1)
Line range hint 156-164: Add missing threshold_mode parameter in lr_scheduler config. The PR's main objective is to add the `threshold_mode` parameter to the learning rate scheduler configuration, but it's missing in this test fixture. This could lead to inconsistency with other configuration files that have been updated. Apply this diff to align with other config files:

```diff
  "lr_scheduler": {
      "threshold": 1e-07,
+     "threshold_mode": "rel",
      "cooldown": 3,
      "patience": 5,
      "factor": 0.5,
      "min_lr": 1e-08,
  },
```
sleap_nn/training/get_bin_files.py (1)
1-1: Enhance the module docstring with more details. The current docstring is too brief. Consider adding more details about:
- The purpose and use cases of the binary files
- Required command-line arguments
- Supported model types
- Example usage
Here's a suggested improvement:
-"""Function to generate `.bin` files.""" +"""Generate optimized binary files for model training. + +This script processes training and validation data to generate optimized .bin files +for different model types (single_instance, centered_instance, centroid, bottomup). + +Example usage: + python get_bin_files.py --dir_path /path/to/data \ + --model_type single_instance \ + --num_workers 4 \ + --chunk_size 1000 + +Required arguments: + --dir_path: Directory containing initial_config.yaml + --model_type: Type of model (single_instance, centered_instance, centroid, bottomup) + --num_workers: Number of worker processes + --chunk_size: Size of data chunks + --user_instances_only: Filter for user instances + --crop_hw: Crop height/width (required for centered_instance) +"""tests/data/test_get_data_chunks.py (1)
Line range hint 1-261: Consider reducing test duplication with fixtures. There's significant duplication in the transform initialization and shape validation across all test functions. Consider creating pytest fixtures to reduce this duplication.
Example refactor:
```python
import pytest

@pytest.fixture
def transform():
    return T.ToTensor()

@pytest.fixture
def max_hw(labels):
    return get_max_height_width(labels)

def assert_image_shapes(sample, channels, height, width):
    """Helper to validate image shapes consistently."""
    assert transform(sample["image"]).shape == (channels, height, width)
```
sleap_nn/data/instance_cropping.py (1)
50-55: LGTM! Good defensive programming practice. The added NaN handling improves the robustness of the crop size calculation. The implementation correctly prevents NaN propagation while maintaining the function's contract.
Consider extracting the NaN handling into a helper function for better readability:
```diff
+def safe_diff(max_val: float, min_val: float) -> float:
+    """Calculate difference with NaN handling."""
+    diff = max_val - min_val
+    return 0 if np.isnan(diff) else diff

 def find_instance_crop_size(...):
     # ...
     for lf in labels:
         for inst in lf.instances:
             pts = inst.numpy()
             pts *= input_scaling

-            diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
-            diff_x = 0 if np.isnan(diff_x) else diff_x
+            diff_x = safe_diff(np.nanmax(pts[:, 0]), np.nanmin(pts[:, 0]))
             max_length = np.maximum(max_length, diff_x)

-            diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
-            diff_y = 0 if np.isnan(diff_y) else diff_y
+            diff_y = safe_diff(np.nanmax(pts[:, 1]), np.nanmin(pts[:, 1]))
             max_length = np.maximum(max_length, diff_y)
```
sleap_nn/data/get_data_chunks.py (1)
247-251: Minor: Maintain consistent documentation formatting. There's an extra period after "config file" in the documentation that isn't present in other similar docstrings.

```diff
- data_config: Data-related configuration. (`data_config` section in the config file).
+ data_config: Data-related configuration. (`data_config` section in the config file)
```
tests/inference/test_topdown.py (1)
216-216: Consider adding test cases with different `eff_scale` values. The `eff_scale` tensor is set to 1.0 in all test cases. Consider adding test cases with different scale values to verify the model's behavior with various scaling factors.
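As background for such tests, the rescaling the forward pass applies is a broadcasted division of the predicted peaks by the per-sample effective scale (the `peak_points / eff_scale` expression quoted later in this review). A minimal, model-free illustration:

```python
import torch

# Peaks predicted on an image that was scaled by 0.5 map back to the
# original resolution by dividing by the effective scale.
peak_points = torch.tensor([[[50.0, 80.0], [120.0, 40.0]]])  # (samples, nodes, 2)
eff_scale = torch.tensor([0.5])                              # (samples,)
rescaled = peak_points / eff_scale.unsqueeze(1).unsqueeze(2)
print(rescaled)  # tensor([[[100., 160.], [240., 80.]]])
```

Parametrizing the tests over several scale values would exercise exactly this mapping.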
sleap_nn/data/providers.py (1)
31-35: Add error handling and improve documentation. While the implementation is efficient, consider these improvements:
- Add error handling for empty labels/videos
- Enhance docstring with Args/Returns sections and edge cases
Consider this improved implementation:
def get_max_height_width(labels: sio.Labels) -> Tuple[int, int]: - """Return `(height, width)` that is the maximum of all videos.""" + """Calculate the maximum height and width across all videos in the labels. + + Args: + labels: sleap_io.Labels object containing video data + + Returns: + Tuple[int, int]: Maximum (height, width) across all videos + + Raises: + ValueError: If labels contains no videos + """ + if not labels.videos: + raise ValueError("Labels object contains no videos") return max(video.shape[1] for video in labels.videos), max( video.shape[2] for video in labels.videos )tests/training/test_model_trainer.py (1)
Line range hint 1-507: Add test coverage for the new threshold_mode parameter. The PR introduces a new `threshold_mode` parameter for the learning rate scheduler, but there are no test cases verifying this functionality. Please add test cases to:
- Verify both `abs` and `rel` threshold modes work as expected
- Ensure proper error handling for invalid threshold mode values
Here's a suggested test case structure:
def test_lr_scheduler_threshold_mode(config, tmp_path: str): """Test learning rate scheduler with different threshold modes.""" # Test absolute threshold mode config_abs = config.copy() OmegaConf.update(config_abs, "trainer_config.lr_scheduler.threshold_mode", "abs") trainer_abs = ModelTrainer(config_abs) # Add assertions to verify absolute threshold behavior # Test relative threshold mode config_rel = config.copy() OmegaConf.update(config_rel, "trainer_config.lr_scheduler.threshold_mode", "rel") trainer_rel = ModelTrainer(config_rel) # Add assertions to verify relative threshold behavior # Test invalid threshold mode config_invalid = config.copy() OmegaConf.update(config_invalid, "trainer_config.lr_scheduler.threshold_mode", "invalid") with pytest.raises(ValueError, match="threshold_mode must be 'rel' or 'abs'"): trainer_invalid = ModelTrainer(config_invalid)sleap_nn/data/resizing.py (2)
111-112: Clarify the function's behavior in the docstring. The docstring states: "Apply scaling and padding to smaller image to (max_height, max_width) shape." However, the function raises an error if the image dimensions exceed the maximum dimensions. Consider updating the docstring to accurately reflect that images larger than the maximum dimensions will cause an error.
150-152: Return consistent `eff_scale_ratio` for all cases. Currently, when the image dimensions match the maximum dimensions, the function returns `eff_scale_ratio` as 1.0. Ensure that this value accurately reflects any scaling applied, including when no scaling is needed.
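A small usage sketch of how the returned ratio ties into downstream coordinate handling, based on the `apply_sizematcher(image, max_height, max_width)` signature and return values quoted later in this review (this assumes the current behavior, where smaller images are scaled up to fit):

```python
import torch
from sleap_nn.data.resizing import apply_sizematcher

# A 384x512 frame padded/scaled into a 768x1024 canvas: both dimensions fit
# with the same ratio, so eff_scale should come back as 2.0 and predicted
# coordinates can be mapped back to the original frame via `coords / eff_scale`.
image = torch.rand(1, 1, 384, 512)
resized, eff_scale = apply_sizematcher(image, max_height=768, max_width=1024)
print(resized.shape, eff_scale)  # torch.Size([1, 1, 768, 1024]) 2.0
```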
sleap_nn/training/model_trainer.py (1)
197-199: Use logging instead of printing to standard output. Directly printing to standard output may not be suitable for libraries or production code. Consider using the `logging` module for better control over logging levels and outputs. Apply this diff to replace the print calls:
```diff
+ import logging
  ...
- print("Standard Output:\n", stdout)
- print("Standard Error:\n", stderr)
+ logging.info(f"Standard Output:\n{stdout}")
+ logging.error(f"Standard Error:\n{stderr}")
```
Don't forget to configure the logging level and handlers as appropriate for your application.
sleap_nn/inference/predictors.py (1)
1623-1624: Consider broader support for `anchor_ind`. In the `main` function, `anchor_ind` is assigned to the predictor only if it's an instance of `TopDownPredictor`. If other predictor types could benefit from `anchor_ind`, consider implementing a more generic approach or documenting why it's specific to `TopDownPredictor`.
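One possible shape for that more generic approach (a sketch only; it assumes other predictor classes would also expose an `anchor_ind` attribute, which is not the case today):

```python
# Instead of special-casing TopDownPredictor in `main`, assign the attribute
# to any predictor that declares it:
if anchor_ind is not None and hasattr(predictor, "anchor_ind"):
    predictor.anchor_ind = anchor_ind
```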
Comments failed to post (27)
docs/config_bottomup.yaml (1)
107-107: Codebase verification
The `threshold_mode` parameter needs to be added to test configuration files. The verification reveals an inconsistency in the configuration files. While the main configuration files in the docs/ directory (config_bottomup.yaml, config_centroid.yaml, config_topdown_centered_instance.yaml) correctly include the `threshold_mode: abs` parameter, the test configuration files in tests/assets/ and the root initial_config.yaml are missing this parameter.
Files requiring updates:
- tests/assets/minimal_instance_centroid/training_config.yaml
- tests/assets/minimal_instance_centroid/initial_config.yaml
- tests/assets/minimal_instance_bottomup/training_config.yaml
- tests/assets/minimal_instance_bottomup/initial_config.yaml
- tests/assets/minimal_instance/training_config.yaml
- tests/assets/minimal_instance/initial_config.yaml
- initial_config.yaml
Analysis chain
LGTM! The threshold_mode parameter is correctly configured. The addition of `threshold_mode: abs` is properly placed under the `lr_scheduler` section and works in conjunction with the existing `threshold` parameter. This enhancement allows for more flexible learning rate scheduling by supporting both absolute and relative threshold comparisons. Let's verify if other configuration files have consistent threshold_mode settings:
Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Check for consistent lr_scheduler threshold_mode across config files
# Expected: All config files should have the same structure for lr_scheduler

# Search for lr_scheduler configurations in all yaml files
fd -e yaml -e yml | xargs rg -A 10 "lr_scheduler:"
```
Length of output: 6627
initial_config.yaml (4)
1-93: Potential issue: Fix inconsistent newline characters
The file uses incorrect newline characters. Ensure consistent use of Unix-style newlines (\n) throughout the file.
Use a text editor that supports Unix-style line endings or run:
```bash
dos2unix initial_config.yaml
```
Tools
yamllint
[error] 1-1: wrong new line character: expected \n (new-lines)
88-93: Potential issue: Add missing `threshold_mode` parameter to lr_scheduler
According to the PR objectives, a new `threshold_mode` parameter should be added to the learning rate scheduler configuration to support both 'abs' and 'rel' modes. Add the parameter with a default value:

```diff
  lr_scheduler:
    threshold: 1.0e-07
+   threshold_mode: 'rel'  # Options: 'rel' or 'abs'
    cooldown: 3
    patience: 5
    factor: 0.5
    min_lr: 1.0e-08
```
76-76: Potential issue (security): Remove API key field from version control
Even though it's empty, the `api_key` field shouldn't be in version control. Consider moving sensitive fields to environment variables or a separate secrets file. Replace with an environment variable:

```diff
- api_key: ''
+ api_key: ${WANDB_API_KEY}
```
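If the key is resolved in Python rather than in the YAML, one low-effort option is reading it from the environment at runtime. This is a sketch; note that wandb itself also picks up the `WANDB_API_KEY` environment variable when the config field is left unset:

```python
import os

# Resolve the key from the environment instead of committing it to the config.
api_key = os.environ.get("WANDB_API_KEY", "")
if not api_key:
    raise RuntimeError("Set the WANDB_API_KEY environment variable before training.")
```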
3-4: Potential issue: Replace hardcoded Windows paths with environment-agnostic paths
The configuration contains hardcoded Windows paths pointing to test assets. This makes the configuration non-portable and environment-specific.
Consider:
- Using relative paths or environment variables
- Using forward slashes for cross-platform compatibility
- Providing example paths that don't reference test assets
```diff
- train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
- val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+ train_labels_path: ${DATA_DIR}/train/labels.slp
+ val_labels_path: ${DATA_DIR}/val/labels.slp
```
Committable suggestion was skipped due to low confidence.
sleap_nn/inference/single_instance.py (1)
87-89: Potential issue: Add input validation and documentation for eff_scale.
While the scaling implementation looks correct, there are a few improvements needed:
- Add validation for the existence and validity of `inputs["eff_scale"]`
- Document the expected shape and format of `eff_scale` in the method's docstring
- Consider handling potential division by zero
Here's a suggested implementation:
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Predict confidence maps and infer peak coordinates. Args: inputs: Dictionary with "image" as one of the keys. + "eff_scale": Tensor of shape (batch_size,) containing + effective scale factors for each sample. Returns: A dictionary of outputs with keys: `"pred_instance_peaks"`: The predicted peaks for each instance in the batch as a `torch.Tensor` of shape `(samples, nodes, 2)`. `"pred_peak_vals"`: The value of the confidence maps at the predicted peaks for each instance in the batch as a `torch.Tensor` of shape `(samples, nodes)`. + + Raises: + KeyError: If "eff_scale" is missing from inputs + ValueError: If "eff_scale" contains zero values """ # Network forward pass. cms = self.torch_model(inputs["image"]) ... + if "eff_scale" not in inputs: + raise KeyError("Missing required input: eff_scale") + + eff_scale = inputs["eff_scale"] + if torch.any(eff_scale == 0): + raise ValueError("eff_scale contains zero values") + peak_points = peak_points / ( inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2) ).to(peak_points.device)Committable suggestion was skipped due to low confidence.
sleap_nn/training/get_bin_files.py (5)
28-38: Potential issue: Add validation for configuration and arguments.
The initialization lacks validation for:
- Required configuration fields
- Crop dimensions when using centered_instance model
- Existence of train/val label paths
Add validation code after loading the config:
def validate_config(config, args): """Validate configuration and arguments.""" required_fields = [ "data_config.train_labels_path", "data_config.val_labels_path", "model_config.backbone_config.max_stride", ] for field in required_fields: if not OmegaConf.select(config, field): raise ValueError(f"Missing required config field: {field}") if args.model_type == "centered_instance" and args.crop_hw is None: raise ValueError("crop_hw is required for centered_instance model type") for path_field in ["train_labels_path", "val_labels_path"]: path = Path(OmegaConf.select(config.data_config, path_field)) if not path.exists(): raise FileNotFoundError(f"Labels file not found: {path}") # Add after loading config validate_config(config, args)π§° Tools
Ruff
33-33: Use `not ...` instead of `False if ... else True`. Replace with `not ...` (SIM211)
40-143: Refactor suggestion: Refactor repeated optimization code and add error handling.
The code has several areas for improvement:
- Repeated `ld.optimize` calls could be extracted into a helper function
- Missing progress tracking
- No error handling for the optimization process
Here's a suggested refactoring:
def process_dataset(factory_get_chunks, labels, output_dir, num_workers, chunk_size): """Process a dataset using the given chunk factory.""" try: print(f"Processing {len(labels)} samples to {output_dir}...") inputs = [(x, labels.videos.index(x.video)) for x in labels] ld.optimize( fn=factory_get_chunks, inputs=inputs, output_dir=output_dir, num_workers=num_workers, chunk_size=chunk_size, ) print(f"Successfully processed {len(inputs)} samples") except Exception as e: print(f"Error processing dataset: {str(e)}") raise # Usage example for single_instance model type: if args.model_type == "single_instance": factory_get_chunks = functools.partial( single_instance_data_chunks, data_config=config.data_config, user_instances_only=user_instances_only, max_hw=(max_height, max_width), ) for dataset_name, labels in [("train", train_labels), ("val", val_labels)]: output_dir = (Path(args.dir_path) / f"{dataset_name}_chunks").as_posix() process_dataset( factory_get_chunks, labels, output_dir, args.num_workers, args.chunk_size, )
144-147: Refactor suggestion: Consider using an enum for model types and validating earlier.
Moving the model type validation earlier and using an enum would improve maintainability.
Here's a suggested improvement:
from enum import Enum, auto class ModelType(Enum): SINGLE_INSTANCE = "single_instance" CENTERED_INSTANCE = "centered_instance" CENTROID = "centroid" BOTTOMUP = "bottomup" @classmethod def from_str(cls, value: str) -> "ModelType": try: return cls(value) except ValueError: raise ValueError( f"{value} is not defined. Please choose one of: " f"{', '.join(m.value for m in cls)}" ) # In argument parser: parser.add_argument( "--model_type", type=ModelType.from_str, required=True, help=f"Type of model. Choices: {', '.join(m.value for m in ModelType)}", ) # Then use args.model_type.value in the if-elif chain
33-33: Potential issue: Simplify the boolean conversion.
The current boolean conversion is unnecessarily complex.

```diff
- user_instances_only = False if args.user_instances_only == 0 else True
+ user_instances_only = bool(args.user_instances_only)
```
Ruff
33-33: Use `not ...` instead of `False if ... else True`. Replace with `not ...` (SIM211)
19-26: Potential issue: Improve argument parsing with help messages and proper types.
The current argument parsing could be enhanced with:
- Help messages for each argument
- Proper type validation
- Required/optional argument specification
- Default values where appropriate
Here's a suggested improvement:
- parser = argparse.ArgumentParser() - parser.add_argument("--dir_path", type=str) - parser.add_argument("--user_instances_only", type=str) - parser.add_argument("--model_type", type=str) - parser.add_argument("--num_workers", type=int) - parser.add_argument("--chunk_size", type=int) - parser.add_argument("--crop_hw", type=int, default=None) + parser = argparse.ArgumentParser(description="Generate optimized binary files for model training.") + parser.add_argument( + "--dir_path", + type=Path, + required=True, + help="Directory containing initial_config.yaml", + ) + parser.add_argument( + "--model_type", + type=str, + required=True, + choices=["single_instance", "centered_instance", "centroid", "bottomup"], + help="Type of model to generate data for", + ) + parser.add_argument( + "--num_workers", + type=int, + required=True, + help="Number of worker processes", + ) + parser.add_argument( + "--chunk_size", + type=int, + required=True, + help="Size of data chunks", + ) + parser.add_argument( + "--user_instances_only", + type=bool, + default=False, + help="Filter for user instances only", + ) + parser.add_argument( + "--crop_hw", + type=int, + default=None, + help="Crop height/width (required for centered_instance)", + )π Committable suggestion
tests/data/test_streaming_datasets.py (1)
93-93: π οΈ Refactor suggestion
Consider consolidating duplicate test setup code.
The pattern of calculating
max_hwand setting up the partial function is repeated across all test functions. Consider extracting this into a helper function to reduce code duplication.+def setup_test_data_chunks(labels, dir_path, chunk_func, **kwargs): + """Helper function to set up test data chunks. + + Args: + labels: Labels object + dir_path: Output directory path + chunk_func: Data chunk function to use + **kwargs: Additional arguments for chunk function + """ + max_hw = get_max_height_width(labels) + partial_func = functools.partial( + chunk_func, + data_config=config.data_config, + max_hw=max_hw, + **kwargs + ) + ld.optimize( + fn=partial_func, + inputs=[(x, labels.videos.index(x.video)) for x in labels], + output_dir=str(dir_path), + chunk_size=4, + ) + return max_hw def test_centered_instance_streaming_dataset(minimal_instance, sleap_data_dir, config): """Test CenteredInstanceStreamingDataset class.""" labels = sio.load_slp(minimal_instance) - max_hw = get_max_height_width(labels) dir_path = Path(sleap_data_dir) / "data_chunks" - partial_func = functools.partial( - centered_instance_data_chunks, - data_config=config.data_config, - max_instances=2, - crop_size=(160, 160), - anchor_ind=0, - max_hw=max_hw, - ) - ld.optimize( - fn=partial_func, - inputs=[(x, labels.videos.index(x.video)) for x in labels], - output_dir=str(dir_path), - chunk_size=4, - ) + setup_test_data_chunks( + labels, + dir_path, + centered_instance_data_chunks, + max_instances=2, + crop_size=(160, 160), + anchor_ind=0 + )Also applies to: 103-103
sleap_nn/data/get_data_chunks.py (1)
150-152: π οΈ Refactor suggestion
Consider standardizing the normalization step.
The normalization is applied inside
generate_crops, while other functions handle normalization differently. Consider standardizing where normalization occurs in the processing pipeline to maintain consistency.Consider moving the normalization step outside
generate_crops:- res = generate_crops( - apply_normalization(sample["image"]), instance, centroid, crop_size - ) + normalized_image = apply_normalization(sample["image"]) + res = generate_crops(normalized_image, instance, centroid, crop_size)π Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.normalized_image = apply_normalization(sample["image"]) res = generate_crops(normalized_image, instance, centroid, crop_size)tests/inference/test_topdown.py (1)
139-146: π οΈ Refactor suggestion
Consider extracting duplicated example preparation code into a helper function.
The example preparation code is duplicated across multiple test functions. Consider extracting it into a helper function to improve maintainability and reduce code duplication.
def prepare_test_example(labels, frame_idx=0, stride=2): """Helper function to prepare test example with common fields.""" ex = process_lf(labels[frame_idx], frame_idx, stride) ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0) ex["instances"] = ex["instances"].unsqueeze(dim=0) ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0) ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0) ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0) ex["eff_scale"] = torch.Tensor([1.0]) return exsleap_nn/data/streaming_datasets.py (2)
70-73: π οΈ Refactor suggestion
Consider optimizing the image transformation pipeline.
The transformation steps are consistent across all dataset classes but could be improved:
- Error handling: The code assumes the input is always a valid PIL image. Consider adding validation.
- Performance: These operations could be combined into a single transform using
torch.nn.Sequential.- Memory efficiency: The operations create multiple intermediate tensors.
Consider creating a reusable transform pipeline to reduce code duplication and improve efficiency:
class ImageTransformPipeline: def __init__(self): self.transform = T.Compose([ T.PILToTensor(), lambda x: x.unsqueeze(0), apply_normalization ]) def __call__(self, image): if not hasattr(image, 'mode'): # Basic PIL image check raise ValueError("Expected PIL image") return self.transform(image)Then use it in each dataset:
- transform = T.PILToTensor() - ex["image"] = transform(ex["image"]) - ex["image"] = ex["image"].unsqueeze(dim=0) - ex["image"] = apply_normalization(ex["image"]) + transform = ImageTransformPipeline() + ex["image"] = transform(ex["image"])Also applies to: 171-174, 272-275, 358-361
70-73: π‘ Codebase verification
After analyzing the test file content, I can now generate the final response.
Test coverage exists but needs expansion for new transformation pipeline
The existing tests already verify image shapes and basic transformations, but they don't explicitly test the new PILToTensor transformation and normalization steps. Current test coverage:
- β Image shape verification after transformations
- β Augmentation pipeline (random crop)
- β PILToTensor conversion
- β Normalization validation
- β Error handling for invalid inputs
Suggested test additions:
- Add test cases in
tests/data/test_streaming_datasets.pyfor each dataset class to verify:
- Correct conversion from PIL to tensor format
- Proper normalization of pixel values (range and distribution)
- Error handling for invalid image inputs
π Analysis chain
Ensure test coverage for the new transformation pipeline.
The addition of image transformation steps across all dataset classes requires corresponding test updates to verify:
- Correct tensor shape and type after transformation
- Proper normalization of pixel values
- Error handling for invalid inputs
- Memory usage patterns
Run the following to check existing test coverage:
Also applies to: 171-174, 272-275, 358-361
Let me gather more information about the test coverage by examining the test file content and checking for any existing image-related tests.
π Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Look for related test files fd "test.*streaming_datasets.*\.py" # Check for existing image transformation tests rg -A 5 "def test.*transform" $(fd "test.*streaming_datasets.*\.py")Length of output: 390
Script:
#!/bin/bash # Check content of the test file cat tests/data/test_streaming_datasets.py # Look for any image-related test cases rg -i "test.*image" tests/data/test_streaming_datasets.py # Look for normalization tests rg -i "test.*normal" tests/data/test_streaming_datasets.pyLength of output: 8612
sleap_nn/data/resizing.py (2)
118-122: π οΈ Refactor suggestion
Consider resizing images larger than max dimensions instead of raising an error
Currently, the function raises a
ValueErrorwhen the image dimensions exceed the maximum allowed dimensions. To enhance the function's robustness, consider automatically scaling down images larger than the maximum dimensions to fit within the desired size.Apply this diff to modify the function:
def apply_sizematcher( image: torch.Tensor, max_height: Optional[int] = None, max_width: Optional[int] = None, ): """Apply scaling and padding to image to (max_height, max_width) shape.""" img_height, img_width = image.shape[-2:] # Set default max dimensions if max_height is None: max_height = img_height if max_width is None: max_width = img_width - if img_height > max_height: - raise ValueError( - f"Max height {max_height} should be greater than the current image height: {img_height}" - ) - - if img_width > max_width: - raise ValueError( - f"Max width {max_width} should be greater than the current image width: {img_width}" - ) - - if img_height < max_height or img_width < max_width: + if img_height != max_height or img_width != max_width: hratio = max_height / img_height wratio = max_width / img_width if hratio > wratio: eff_scale_ratio = wratio target_h = int(round(img_height * wratio)) target_w = int(round(img_width * wratio)) else: eff_scale_ratio = hratio target_w = int(round(img_width * hratio)) target_h = int(round(img_height * hratio)) image = tvf.resize(image, size=(target_h, target_w)) pad_height = max_height - target_h pad_width = max_width - target_w image = F.pad( image, (0, pad_width, 0, pad_height), mode="constant", ).to(torch.float32) return image, eff_scale_ratio else: return image, 1.0π Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.if img_height != max_height or img_width != max_width: hratio = max_height / img_height wratio = max_width / img_width if hratio > wratio: eff_scale_ratio = wratio target_h = int(round(img_height * wratio)) target_w = int(round(img_width * wratio)) else: eff_scale_ratio = hratio target_w = int(round(img_width * hratio)) target_h = int(round(img_height * hratio)) image = tvf.resize(image, size=(target_h, target_w)) pad_height = max_height - target_h pad_width = max_width - target_w image = F.pad( image, (0, pad_width, 0, pad_height), mode="constant", ).to(torch.float32) return image, eff_scale_ratio else: return image, 1.0
126-152:
β οΈ Potential issueEnsure consistent scaling and padding logic
The current logic only scales and pads images that are smaller than the maximum dimensions. By modifying the condition to include images that are larger, the function will handle all image sizes consistently, scaling and padding as needed.
Apply this diff to adjust the condition:
- if img_height < max_height or img_width < max_width: + if img_height != max_height or img_width != max_width: hratio = max_height / img_height wratio = max_width / img_width # Remaining code unchanged...π Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.if img_height != max_height or img_width != max_width: hratio = max_height / img_height wratio = max_width / img_width if hratio > wratio: eff_scale_ratio = wratio target_h = int(round(img_height * wratio)) target_w = int(round(img_width * wratio)) else: eff_scale_ratio = hratio target_w = int(round(img_width * hratio)) target_h = int(round(img_height * hratio)) image = tvf.resize(image, size=(target_h, target_w)) pad_height = max_height - target_h pad_width = max_width - target_w image = F.pad( image, (0, pad_width, 0, pad_height), mode="constant", ).to(torch.float32) return image, eff_scale_ratio else: return image, 1.0sleap_nn/inference/topdown.py (1)
69-69:
β οΈ Potential issuePotential issue with
crop_hwdefaulting toNoneBy changing
crop_hwto default toNone, ensure that the code handles cases wherecrop_hwisNone. In the_generate_cropsmethod, accessingself.crop_hw[0]andself.crop_hw[1]will raise an exception ifcrop_hwisNone.Apply this diff to set a default value when
crop_hwisNone:def __init__( self, torch_model: Optional[L.LightningModule] = None, output_stride: int = 1, peak_threshold: float = 0.0, max_instances: Optional[int] = None, refinement: Optional[str] = None, integral_patch_size: int = 5, return_confmaps: bool = False, return_crops: bool = False, crop_hw: Optional[List[int]] = None, input_scale: float = 1.0, max_stride: int = 1, use_gt_centroids: bool = False, anchor_ind: Optional[int] = None, **kwargs, ): """Initialise the model attributes.""" super().__init__(**kwargs) + if crop_hw is None: + crop_hw = [160, 160] self.torch_model = torch_model self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size self.output_stride = output_stride self.return_confmaps = return_confmaps self.max_instances = max_instances self.return_crops = return_crops self.crop_hw = crop_hw self.input_scale = input_scale self.max_stride = max_stride self.use_gt_centroids = use_gt_centroids self.anchor_ind = anchor_indπ Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.def __init__( self, torch_model: Optional[L.LightningModule] = None, output_stride: int = 1, peak_threshold: float = 0.0, max_instances: Optional[int] = None, refinement: Optional[str] = None, integral_patch_size: int = 5, return_confmaps: bool = False, return_crops: bool = False, crop_hw: Optional[List[int]] = None, input_scale: float = 1.0, max_stride: int = 1, use_gt_centroids: bool = False, anchor_ind: Optional[int] = None, **kwargs, ): """Initialise the model attributes.""" super().__init__(**kwargs) if crop_hw is None: crop_hw = [160, 160] self.torch_model = torch_model self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size self.output_stride = output_stride self.return_confmaps = return_confmaps self.max_instances = max_instances self.return_crops = return_crops self.crop_hw = crop_hw self.input_scale = input_scale self.max_stride = max_stride self.use_gt_centroids = use_gt_centroids self.anchor_ind = anchor_indsleap_nn/training/model_trainer.py (6)
175-175:
β οΈ Potential issueRemove unnecessary 'f' prefix from string literal
The string on line 175 does not contain any placeholders, so the
fprefix is unnecessary.Apply this diff to correct the string literal:
- f"sleap_nn.training.get_bin_files", + "sleap_nn.training.get_bin_files",π Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements."sleap_nn.training.get_bin_files",π§° Tools
πͺ Ruff
175-175: f-string without any placeholders
Remove extraneous
fprefix(F541)
10-10:
β οΈ Potential issueRemove unused import 'shutil'
The
shutilmodule is imported but not used in the code. Eliminating it will tidy up the imports.Apply this diff to remove the unused import:
- import shutilπ Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.π§° Tools
πͺ Ruff
10-10:
shutilimported but unusedRemove unused import:
shutil(F401)
6-6:
β οΈ Potential issueRemove unused import 'inspect'
The
inspectmodule is imported but not used anywhere in the code. Removing it will clean up unnecessary imports.Apply this diff to remove the unused import:
- import inspectπ Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.π§° Tools
πͺ Ruff
6-6:
inspectimported but unusedRemove unused import:
inspect(F401)
57-57:
β οΈ Potential issueRemove unused import 'get_bin_files'
The
get_bin_filesmodule fromsleap_nn.trainingis imported but not utilized. Removing it reduces clutter.Apply this diff to remove the unused import:
- from sleap_nn.training import get_bin_filesπ Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.π§° Tools
πͺ Ruff
57-57:
sleap_nn.training.get_bin_filesimported but unusedRemove unused import:
sleap_nn.training.get_bin_files(F401)
205-205:
β οΈ Potential issueUse 'raise ... from e' to preserve the exception context
When re-raising exceptions, using
raise ... from epreserves the original traceback, aiding in debugging.Apply this diff to improve exception handling:
- raise Exception(f"Error while creating the `.bin` files... {e}") + raise Exception("Error while creating the `.bin` files") from eπ Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.raise Exception("Error while creating the `.bin` files") from eπ§° Tools
πͺ Ruff
205-205: Within an
exceptclause, raise exceptions withraise ... from errorraise ... from Noneto distinguish them from errors in exception handling(B904)
171-191: π οΈ Refactor suggestion
Simplify subprocess execution using 'subprocess.run'
Since you are running the subprocess and waiting for it to finish,
subprocess.runprovides a simpler and more reliable interface with built-in error handling.Refactor the subprocess call as follows:
- def run_subprocess(): - process = subprocess.Popen( - [ - "python", - "-m", - "sleap_nn.training.get_bin_files", - "--dir_path", - f"{self.dir_path}", - "--user_instances_only", - "1" if self.user_instances_only else "0", - "--model_type", - f"{self.model_type}", - "--num_workers", - f"{self.config.trainer_config.train_data_loader.num_workers}", - "--chunk_size", - f"{self.chunk_size}", - "--crop_hw", - f"{self.crop_hw}", - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - # Use communicate() to read output and avoid hanging - stdout, stderr = process.communicate() - # Print the logs - print("Standard Output:\n", stdout) - print("Standard Error:\n", stderr) + def run_subprocess(): + result = subprocess.run( + [ + "python", + "-m", + "sleap_nn.training.get_bin_files", + "--dir_path", + self.dir_path, + "--user_instances_only", + "1" if self.user_instances_only else "0", + "--model_type", + self.model_type, + "--num_workers", + str(self.config.trainer_config.train_data_loader.num_workers), + "--chunk_size", + str(self.chunk_size), + "--crop_hw", + str(self.crop_hw), + ], + capture_output=True, + text=True, + check=True, + ) + # Log the outputs + logging.info(f"Standard Output:\n{result.stdout}") + logging.error(f"Standard Error:\n{result.stderr}")This refactoring uses
subprocess.runwithcheck=Trueto automatically raise aCalledProcessErrorif the subprocess exits with a non-zero status, simplifying error handling.Committable suggestion was skipped due to low confidence.
π§° Tools
πͺ Ruff
175-175: f-string without any placeholders
Remove extraneous
fprefix(F541)
sleap_nn/inference/predictors.py (2)
273-279: π οΈ Refactor suggestion
Refactor duplicate image color conversion logic
The code for converting images to RGB or grayscale is duplicated in lines 273-279 and 304-307. To improve maintainability and reduce redundancy, consider refactoring this logic into a separate helper function or moving it outside the loop if possible.
Apply this diff to eliminate the duplication:
@@ -273,14 +273,6 @@ if self.instances_key: frame["instances"] = frame["instances"] * eff_scale - if self.preprocess_config["is_rgb"] and frame["image"].shape[-3] != 3: - frame["image"] = frame["image"].repeat(1, 3, 1, 1) - elif not self.preprocess_config["is_rgb"]: - frame["image"] = F.rgb_to_grayscale( - frame["image"], num_output_channels=1 - ) - eff_scales.append(torch.tensor(eff_scale)) imgs.append(frame["image"].unsqueeze(dim=0)) fidxs.append(frame["frame_idx"]) vidxs.append(frame["video_idx"]) @@ -304,7 +296,14 @@ ex["instances"] = instances + # Perform image color conversion + if self.preprocess_config["is_rgb"] and ex["image"].shape[-3] != 3: + ex["image"] = ex["image"].repeat(1, 1, 3, 1, 1) + elif not self.preprocess_config["is_rgb"]: + ex["image"] = F.rgb_to_grayscale(ex["image"], num_output_channels=1) + if self.preprocess:Also applies to: 304-307
636-645: π οΈ Refactor suggestion
Consolidate preprocessing configuration across predictors
Similar code for setting up
preprocess_configwithmax_heightandmax_widthappears in multiple predictor classes (TopDownPredictor,SingleInstancePredictor,BottomUpPredictor). Consider consolidating this logic into a shared method or base class to reduce redundancy and ensure consistency across different predictors.Also applies to: 657-666, 693-702, 962-971, 990-999, 1289-1298, 1317-1326
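A sketch of what such a shared helper could look like. The function name and field set are illustrative, taken from the keys discussed in this review rather than the actual predictor API:

```python
from typing import Any, Dict, Optional


def build_preprocess_config(
    scale: float,
    is_rgb: bool,
    max_stride: int,
    max_height: Optional[int] = None,
    max_width: Optional[int] = None,
) -> Dict[str, Any]:
    """Build the preprocessing config shared by the different predictor classes."""
    return {
        "scale": scale,
        "is_rgb": is_rgb,
        "max_stride": max_stride,
        "max_height": max_height,
        "max_width": max_width,
    }
```

Whichever constructor currently builds the dictionary in each predictor would then call this helper once instead of repeating the literal.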
Codecov Report: All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #106 +/- ##
==========================================
+ Coverage 96.64% 97.39% +0.74%
==========================================
Files 23 38 +15
Lines 1818 3681 +1863
==========================================
+ Hits 1757 3585 +1828
- Misses 61 96 +35
View full report in Codecov by Sentry.
Force-pushed 22c7a96 to 105e496
Force-pushed 162e949 to f64b425
Force-pushed f64b425 to 162e949
Actionable comments posted: 6
Outside diff range and nitpick comments (11)
tests/data/test_streaming_datasets.py (2)
31-35: Consider adding edge case tests for max_hw parameter. While the test coverage for different scales is good, consider adding test cases for:
- Edge cases where max_hw is very small or large
- Cases where image dimensions are smaller than max_hw
Also applies to: 69-75
125-126: Consider adding scale=1.0 test case for consistency. Other test functions include test cases for both scale=0.5 and scale=1.0. Consider adding a similar second test case here for consistency and better coverage.
sleap_nn/data/get_data_chunks.py (2)
272-276: Enhance documentation for max_hw parameter. While the documentation explains the parameter well, consider adding an example showing how the values from config and max_hw interact to determine the final dimensions.
Example addition to the docstring:
Example: If config has max_height=1000, max_width=None and max_hw=(800, 1200), the final maximum dimensions will be (1000, 1200).
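The fallback itself is a pair of conditional picks; a tiny self-contained illustration of the example above (variable names are illustrative):

```python
# From the config (None means "not set"):
max_height, max_width = 1000, None
# Computed from the videos:
max_hw = (800, 1200)

final_h = max_height if max_height is not None else max_hw[0]  # 1000
final_w = max_width if max_width is not None else max_hw[1]    # 1200
print(final_h, final_w)  # 1000 1200
```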
Line range hint 1-318: Consider extracting common image processing logic. The file shows repeated patterns for image processing across all functions:
- Size matching with max_hw
- Coordinate scaling
- Additional resizing
- PIL conversion
Consider extracting these steps into a utility function to reduce code duplication and ensure consistent processing across all data chunk functions.
Example utility function:
def process_image_and_coordinates( image: torch.Tensor, coordinates: torch.Tensor, max_height: Optional[int], max_width: Optional[int], max_hw: Tuple[int, int], scale: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, float]: """Process image and scale coordinates accordingly.""" # Size matching image, eff_scale = apply_sizematcher( image, max_height=max_height if max_height is not None else max_hw[0], max_width=max_width if max_width is not None else max_hw[1], ) coordinates = coordinates * eff_scale # Additional resizing image, coordinates = apply_resizer(image, coordinates, scale=scale) # PIL conversion transform = T.ToPILImage() image = transform(image.squeeze(dim=0)) return image, coordinates, eff_scalesleap_nn/data/streaming_datasets.py (2)
131-131: Enhance documentation for input_scale parameter. While the `input_scale` parameter is correctly implemented, its documentation could be more detailed to explain:
- The valid range of values
- How it affects the final crop size
- Any performance implications
Consider expanding the docstring:
- input_scale: Resize factor applied to the image. Default: 1.0. + input_scale: Scaling factor applied to crop dimensions. Values > 1.0 increase crop size, + while values < 1.0 decrease it. Must be positive. Default: 1.0.Also applies to: 141-141, 153-153
322-324: LGTM: Image preprocessing is consistent with other classes. The transformation pipeline correctly follows the established pattern.
Consider extracting the common image preprocessing logic into a shared utility function to reduce code duplication across dataset classes. Example:
def preprocess_image(image: PIL.Image, apply_aug: bool = False, aug_config: Optional[DictConfig] = None) -> torch.Tensor: """Apply standard image preprocessing pipeline. Args: image: Input PIL image apply_aug: Whether to apply augmentation aug_config: Augmentation configuration Returns: Preprocessed image tensor """ transform = T.PILToTensor() image = transform(image) image = image.unsqueeze(dim=0).to(torch.float32) if apply_aug: if "intensity" in aug_config: image, _ = apply_intensity_augmentation(image, None, **aug_config.intensity) if "geometric" in aug_config: image, _ = apply_geometric_augmentation(image, None, **aug_config.geometric) image = apply_normalization(image) return imageAlso applies to: 338-338
tests/training/test_model_trainer.py (1)
356-356: LGTM! Consider adding edge cases for RGB configuration. The RGB configuration changes are well-coordinated between preprocessing and model configuration. However, consider adding test cases for:
- Mismatched configurations (is_rgb=True with in_channels=1)
- Grayscale to RGB conversion behavior
Also applies to: 362-362
sleap_nn/training/model_trainer.py (3)
92-92: Consider using None as default for crop_hw. Using -1 as a default value for `crop_hw` could be misleading. Consider using `None` instead, as it better represents an uninitialized state and matches Python conventions.

```diff
- self.crop_hw = -1
+ self.crop_hw = None
```
382-388: Consider atomic file operations for config saving. The config saving operations could potentially fail if there are concurrent writes. Consider using atomic file operations to ensure config files are written safely.
Consider using a temporary file and atomic rename operation:
import tempfile import os def save_config_atomic(config, filepath): with tempfile.NamedTemporaryFile(mode='w', delete=False) as tf: OmegaConf.save(config=config, f=tf.name) os.replace(tf.name, filepath)
Remove unused `inspect` import and keep other imports
The `inspect` module is imported but never used in the code. However, both `shutil` and `get_bin_files` are actively used:
- `shutil` is used in commented-out cleanup code and is commonly used across tests
- `get_bin_files` is used in subprocess calls for data processing
Remove only the `inspect` import to maintain clean imports while preserving required functionality.
Analysis chain
Line range hint 6-57: Clean up unused imports. The following imports are flagged as unused and should be removed if they're not needed elsewhere in the codebase:
- `inspect` (line 6)
- `shutil` (line 10)
- `sleap_nn.training.get_bin_files` (line 57)
Let me analyze the usage of these imports more thoroughly before making a final decision.
π Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Verify if these imports are used elsewhere
rg "\b(inspect|shutil|get_bin_files)\b" --type py
```
Length of output: 4076
Script:
#!/bin/bash # Check for actual usage of inspect and get_bin_files in model_trainer.py rg -A 5 "\b(inspect|get_bin_files)\b" sleap_nn/training/model_trainer.py # Check for shutil usage in model_trainer.py rg -A 5 "\bshutil\." sleap_nn/training/model_trainer.pyLength of output: 1122
π§° Tools
πͺ Ruff
37-37:
torchvision.models.convnext.ConvNeXt_Small_Weightsimported but unusedRemove unused import
(F401)
38-38:
torchvision.models.convnext.ConvNeXt_Large_Weightsimported but unusedRemove unused import
(F401)
40-40: Redefinition of unused
siofrom line 13Remove definition:
sio(F811)
42-42:
sleap_nn.data.cycler.CyclerIterDataPipeimported but unusedRemove unused import:
sleap_nn.data.cycler.CyclerIterDataPipe(F401)
46-46:
sleap_nn.data.get_data_chunks.bottomup_data_chunksimported but unusedRemove unused import
(F401)
47-47:
sleap_nn.data.get_data_chunks.centered_instance_data_chunksimported but unusedRemove unused import
(F401)
sleap_nn/inference/predictors.py (1)
1624-1626: Fix indentation for better readability. The indentation of this block appears to be inconsistent with the surrounding code.
Apply this diff to fix the indentation:

```diff
-        if isinstance(predictor, TopDownPredictor):
-            predictor.anchor_ind = anchor_ind
-
+    if isinstance(predictor, TopDownPredictor):
+        predictor.anchor_ind = anchor_ind
+
```
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (8)
- sleap_nn/data/get_data_chunks.py (13 hunks)
- sleap_nn/data/streaming_datasets.py (11 hunks)
- sleap_nn/inference/predictors.py (16 hunks)
- sleap_nn/training/get_bin_files.py (1 hunks)
- sleap_nn/training/model_trainer.py (9 hunks)
- tests/data/test_streaming_datasets.py (11 hunks)
- tests/fixtures/datasets.py (2 hunks)
- tests/training/test_model_trainer.py (8 hunks)
Files skipped from review as they are similar to previous changes (1)
- tests/fixtures/datasets.py
Additional context used
Ruff
sleap_nn/training/get_bin_files.py
34-34: Use not ... instead of False if ... else True
Replace with not ...
(SIM211)
sleap_nn/training/model_trainer.py
6-6: inspect imported but unused
Remove unused import: inspect
(F401)
10-10: shutil imported but unused
Remove unused import: shutil
(F401)
57-57: sleap_nn.training.get_bin_files imported but unused
Remove unused import: sleap_nn.training.get_bin_files
(F401)
174-174: f-string without any placeholders
Remove extraneous f prefix
(F541)
206-206: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
Additional comments (14)
tests/data/test_streaming_datasets.py (2)
13-13: LGTM: Consistent integration of max_hw calculation.
The addition of get_max_height_width import and its consistent usage across all test functions is well-structured. The calculation is properly placed before data processing in each test function.
Also applies to: 25-25, 115-115, 160-160, 247-247
169-170: LGTM: Comprehensive test coverage for centroid streaming.
The test properly covers both scaling scenarios and includes augmentation testing. The structure is consistent with other tests and includes proper cleanup.
Also applies to: 202-209
sleap_nn/data/get_data_chunks.py (4)
7-7: LGTM: Import addition aligns with new image transformation requirements.
The addition of torchvision.transforms is appropriate for the new PIL image conversion functionality implemented across all data chunk functions.
161-163: Good: Explicit float32 conversion before crop generation.
The explicit conversion to float32 before crop generation helps prevent potential precision issues. The subsequent resizing and PIL conversion are properly applied to each crop.
Also applies to: 170-176
237-238: LGTM: Proper centroid coordinate scaling.
The implementation correctly scales centroid coordinates using both the effective scale from size matching and the additional scale parameter. The order of operations (scaling β resizing β PIL conversion) is maintained consistently.
Also applies to: 244-251
24-26: Verify the impact of double scaling on instance coordinates.
The implementation looks correct, but we should verify that the combination of eff_scale from apply_sizematcher and the additional scale parameter in apply_resizer maintains proper coordinate scaling for instances.
Let's verify the scaling implementation:
Also applies to: 63-83
β Verification successful
Scaling implementation is correct and properly tested
The implementation and tests show that:
apply_resizercorrectly scales both images and instance coordinates by the same factor- Test case explicitly verifies that instance coordinates are scaled properly:
assert torch.all(instances == ex["instances"] * 2.0) - The scaling operations in
apply_sizematcherandapply_resizerare independent and not compounded - they are applied in sequence with proper coordinate transformation
π Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any scaling-related issues in tests or related files
# Look for test cases involving scaling
rg -A 5 "scale.*=.*" "tests/data/"
# Check for any reported scaling issues
rg -i "scale.*issue|coordinate.*mismatch" "tests/" "issues/"
Length of output: 5553
Script:
#!/bin/bash
# Let's check the implementation of apply_sizematcher and apply_resizer to verify scaling logic
# Check implementation of size matching and scaling
ast-grep --pattern 'def apply_sizematcher($$$)'
# Check implementation of resizer
ast-grep --pattern 'def apply_resizer($$$)'
# Look for any tests specifically verifying coordinate scaling
rg -B 2 -A 5 "assert.*instances.*scale|assert.*coordinate" "tests/"
Length of output: 163
Script:
#!/bin/bash
# Let's try a different approach to find the scaling implementation and tests
# Search for files containing scaling-related functions
rg -l "apply_sizematcher|apply_resizer" --type py
# Once we find the files, let's look at their content
rg -B 5 -A 10 "def apply_sizematcher|def apply_resizer" --type py
# Check test assertions related to scaling
rg -B 2 -A 5 "assert.*\*.*2\.0|assert.*scale" "tests/data/test_resizing.py"
Length of output: 2588
sleap_nn/data/streaming_datasets.py (4)
8-8: LGTM: Import changes are well-organized.
The new imports for torchvision.transforms and apply_normalization are appropriately placed and support the new image transformation pipeline.
Also applies to: 16-16
65-67: LGTM: Image preprocessing pipeline is well-structured.
The transformation sequence follows best practices:
- PIL to Tensor conversion
- Proper tensor formatting (batch dimension + float32)
- Augmentation (when enabled)
- Normalization
- Stride padding
Also applies to: 81-81
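A self-contained sketch of that sequence; the padding helper is a simplified stand-in and the stride value is an assumption:

```python
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import functional as TF

def pad_to_stride_sketch(img: torch.Tensor, max_stride: int = 32) -> torch.Tensor:
    """Pad height and width up to the next multiple of max_stride."""
    h, w = img.shape[-2:]
    return F.pad(img, (0, (-w) % max_stride, 0, (-h) % max_stride))

def getitem_sketch(pil_image: Image.Image) -> torch.Tensor:
    img = TF.pil_to_tensor(pil_image)         # PIL -> uint8 tensor (C, H, W)
    img = img.unsqueeze(0).to(torch.float32)  # add batch dimension, cast to float32
    # (augmentation would run here when enabled)
    img = img / 255.0                         # stand-in for apply_normalization
    return pad_to_stride_sketch(img)          # pad so H and W divide the max stride

example = getitem_sketch(Image.new("L", (100, 120)))
assert example.shape[-2] % 32 == 0 and example.shape[-1] % 32 == 0
```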
247-249: LGTM: Consistent image preprocessing implementation.
The transformation pipeline matches the pattern established in other dataset classes, maintaining consistency across the codebase.
Also applies to: 263-263
Line range hint 1-354: Verify PR scope and description.
The changes in this file focus on standardizing image preprocessing across dataset classes, but the PR description mentions adding a threshold mode for the learning rate scheduler. This seems to be a mismatch between the PR objective and the actual changes.
Let's check if there are other files that implement the threshold mode:
sleap_nn/training/model_trainer.py (1)
Line range hint 624-630: LGTM: Learning rate scheduler threshold mode implementation
The addition of the threshold_mode parameter to the learning rate scheduler is well-implemented and aligns with the PR objectives.
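A minimal sketch of how that config value maps onto PyTorch's scheduler; the optimizer and the specific factor/patience numbers are illustrative rather than the repository's defaults:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# threshold_mode="abs" treats `threshold` as an absolute change in the monitored
# metric; threshold_mode="rel" treats it as a fraction of the current best value.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=5,
    threshold=1e-7,
    threshold_mode="abs",  # value read from trainer_config.lr_scheduler.threshold_mode
    cooldown=3,
)

for epoch in range(3):
    val_loss = 1.0 / (epoch + 1)  # dummy validation loss
    scheduler.step(val_loss)
```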
sleap_nn/inference/predictors.py (3)
78-79: LGTM: Clean addition of max dimension parameters.
The new optional parameters max_height and max_width are properly added to the base configuration dictionary, maintaining backward compatibility.
265-279: LGTM: Well-structured image preprocessing pipeline.
The preprocessing pipeline is properly implemented with:
- Correct ordering of operations (resize -> color conversion -> normalization)
- Modern usage of torchvision.transforms.v2 for color space conversion
- Proper handling of effective scale for instance coordinates
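A hedged sketch of that order of operations; the grayscale direction and the dtype handling below are assumptions for illustration, and `torchvision.transforms.v2` requires a recent torchvision:

```python
import torch
import torch.nn.functional as F
from torchvision.transforms import v2

def predictor_preprocess_sketch(image: torch.Tensor, instances: torch.Tensor, scale: float):
    image = image.to(torch.float32)
    # 1. Resize, applying the same factor to the keypoint coordinates.
    image = F.interpolate(image, scale_factor=scale, mode="bilinear", align_corners=False)
    # 2. Color-space conversion with torchvision.transforms.v2.
    image = v2.Grayscale(num_output_channels=1)(image)
    # 3. Normalize to [0, 1].
    image = image / 255.0
    return image, instances * scale

img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)
pts = torch.tensor([[[8.0, 8.0]]])
out_img, out_pts = predictor_preprocess_sketch(img, pts, scale=0.5)
```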
427-427: LGTM: Clean implementation of anchor_ind and optional crop_hw.
The changes properly:
- Add optional anchor_ind with good fallback to config value
- Make crop_hw optional with appropriate fallback
- Maintain backward compatibility
Also applies to: 443-456
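A small sketch of that fallback pattern; the parameter names are placeholders rather than the predictor's actual attributes:

```python
from typing import Optional, Tuple

def resolve_crop_settings(
    anchor_ind: Optional[int],
    crop_hw: Optional[Tuple[int, int]],
    config_anchor_ind: int,
    config_crop_hw: Tuple[int, int],
) -> Tuple[int, Tuple[int, int]]:
    """Fall back to the training-config values when the caller passes None."""
    if anchor_ind is None:
        anchor_ind = config_anchor_ind
    if crop_hw is None:
        crop_hw = config_crop_hw
    return anchor_ind, crop_hw

# Omitting both arguments reproduces the old, non-optional behavior.
print(resolve_crop_settings(None, None, config_anchor_ind=0, config_crop_hw=(160, 160)))
```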
162e949 to 62e7406
Actionable comments posted: 0
Outside diff range and nitpick comments (2)
sleap_nn/training/model_trainer.py (2)
Line range hint
206-206: Improve error handling in subprocess execution. The current error handling could be more specific and provide better context for debugging.
Apply this improvement:
- raise Exception(f"Error while creating the `.bin` files... {e}")
+ raise Exception("Failed to create .bin files: subprocess execution error") from e
Based on the gathered context, I can now provide a final response:
The commented-out cleanup code should be re-enabled with proper resource management
The test failure on Ubuntu is related to the subprocess spawned by `_create_data_loaders()`, which runs for over 6 hours. The issue isn't with the cleanup code itself, but rather with the data chunk generation process that's using multiprocessing workers. The cleanup code is actually necessary to prevent resource leaks.

Recommendations:
- Re-enable the cleanup code as it's essential for proper resource management
- The actual issue needs to be addressed in the data chunk generation process by:
Adding timeouts to the subprocess call (see the sketch after this list)
- Limiting the number of workers in the test configuration
- Reducing the test data size for this specific test
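For the timeout suggestion above, a hedged sketch (the module path and the 30-minute budget are assumptions, not the project's actual invocation):

```python
import subprocess

cmd = ["python", "-m", "sleap_nn.training.get_bin_files"]  # module path assumed

try:
    # A hard timeout turns a hung data-chunk generation run into a fast,
    # visible failure instead of a multi-hour CI job.
    subprocess.run(cmd, check=True, timeout=30 * 60)
except subprocess.TimeoutExpired as err:
    raise RuntimeError("Data chunk generation timed out") from err
```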
Analysis chain
Line range hint
447-456: Address the TODO comment about Ubuntu test failure. The commented-out cleanup code and TODO comment indicate an unresolved issue with the Ubuntu test that's running for over 6 hours.
Let's investigate the cleanup issue:
Would you like help implementing a more robust cleanup mechanism that won't cause test timeouts?
Let me gather more context about the TODO comment and test failures.
Let me check the specific code around the TODO comment.
Let me check the test file to understand the test configuration and the TODO comment better.
Let me check the test file for the specific test that's being skipped and its configuration.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for potential file handle leaks or long-running processes
# Test: Look for any file operations or long-running processes in the codebase
rg -A 5 "shutil\.(rmtree|copy|move)"
rg -A 5 "subprocess\.(Popen|run|call)"
Length of output: 11815
Script:
#!/bin/bash
# Look for the TODO comment about Ubuntu test failure
rg -B 5 -A 5 "TODO.*Ubuntu"
# Check the test file for any timing related configurations
rg -l "test_model_trainer" | xargs cat
# Look for any process spawning or multiprocessing code
ast-grep --pattern 'subprocess.Popen($$$)'
ast-grep --pattern 'multiprocessing.$$$'
Length of output: 25812
Script:
#!/bin/bash
# Look for the code context around line 447-456 in model_trainer.py
rg -B 10 -A 10 "TODO.*Revisit this test later" sleap_nn/training/model_trainer.py
# Check for any multiprocessing configurations in the test
rg -A 5 "num_workers|multiprocessing" sleap_nn/training/model_trainer.py
# Look for the get_bin_files module that's spawned as a subprocess
fd -e py get_bin_files | xargs cat
Length of output: 6666
Script:
#!/bin/bash
# Look for the test file and the specific test that's being skipped
rg -B 5 -A 5 "@pytest.mark.skipif.*Ubuntu" tests/
# Check for any test configurations related to timeouts
rg -A 5 "timeout|duration" tests/
# Look for any cleanup related code in the test
rg -A 5 "cleanup|rmtree" tests/training/test_model_trainer.py
Length of output: 4439
Script:
#!/bin/bash
# Find the specific test file and the skipped test
rg -B 20 -A 20 "reason=\"Flaky test.*Ubuntu" tests/
# Check the model trainer's data loading and subprocess code
rg -B 10 -A 10 "_create_data_loaders" sleap_nn/training/model_trainer.py
# Look for any pytest configuration files
fd -e ini -e cfg -e conf pytest | xargs cat
Length of output: 4734
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (5)

- docs/config_bottomup.yaml (1 hunks)
- docs/config_centroid.yaml (1 hunks)
- docs/config_topdown_centered_instance.yaml (1 hunks)
- sleap_nn/training/model_trainer.py (1 hunks)
- tests/fixtures/datasets.py (1 hunks)

Files skipped from review as they are similar to previous changes (4)
- docs/config_bottomup.yaml
- docs/config_centroid.yaml
- docs/config_topdown_centered_instance.yaml
- tests/fixtures/datasets.py
Additional comments (1)
sleap_nn/training/model_trainer.py (1)
624-624: LGTM: Learning rate scheduler threshold mode implementation.
The addition of the `threshold_mode` parameter to the ReduceLROnPlateau scheduler is correctly implemented, allowing for both absolute and relative threshold modes as specified in the PR objectives.
This PR adds a parameter `trainer_config.lr_scheduler.threshold_mode` to the config to specify the threshold mode for the learning rate scheduler (either `abs` or `rel`).

Summary by CodeRabbit
Release Notes
New Features
- `threshold_mode` parameter for enhanced learning rate scheduling in multiple configuration files.
- `entity` field for Weights and Biases (WandB) tracking in training configurations.

Bug Fixes
Refactor
Chores