
@gitttt-1234 (Collaborator) commented Oct 23, 2024

This PR adds a trainer_config.lr_scheduler.threshold_mode parameter to the config to set the threshold mode of the learning rate scheduler (either abs or rel).
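For context, the two modes change how the scheduler decides whether the monitored metric improved. Assuming the option is forwarded to PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` with `mode="min"` (an assumption; the wiring is not shown here), the comparison reduces to the pure-Python sketch below. The function name `is_improvement` is hypothetical.

```python
def is_improvement(current: float, best: float, threshold: float, threshold_mode: str) -> bool:
    """Return True if `current` beats `best` for a metric that should decrease.

    Mirrors the mode="min" rule of ReduceLROnPlateau:
      rel: current < best * (1 - threshold)
      abs: current < best - threshold
    """
    if threshold_mode == "rel":
        return current < best * (1.0 - threshold)
    if threshold_mode == "abs":
        return current < best - threshold
    raise ValueError(f"threshold_mode must be 'rel' or 'abs', got {threshold_mode!r}")


# With threshold=1e-07, a tiny drop on a large loss counts as an improvement
# under "abs" but not under "rel", because the relative margin scales with `best`.
best, current = 1000.0, 1000.0 - 1e-05
print(is_improvement(current, best, 1e-07, "abs"))  # True
print(is_improvement(current, best, 1e-07, "rel"))  # False
```

In short, "rel" makes the threshold proportional to the current best value, while "abs" uses the same fixed margin regardless of the loss magnitude.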

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced threshold_mode parameter for enhanced learning rate scheduling in multiple configuration files.
    • Added entity field for Weights and Biases (WandB) tracking in training configurations.
  • Bug Fixes

    • Improved error handling for subprocess execution during data loading.
  • Refactor

    • Updated data loading methods for better modularity and clarity in the model training process.
  • Chores

    • Simplified configuration files by removing unused options.

@coderabbitai (Contributor) bot commented Oct 23, 2024

Walkthrough

The pull request introduces several modifications to configuration files and the model trainer implementation. Key changes include the addition of a threshold_mode parameter to the learning rate scheduler in multiple configuration files and the introduction of an entity field in the WandB section of the docs/config_centroid.yaml file. The model trainer has been updated to enhance data loading and configuration handling, including the removal of an outdated method and improvements to error handling. These changes aim to refine the configuration options and improve the modularity of the training process.

Changes

  • docs/config_bottomup.yaml: Added threshold_mode under lr_scheduler; removed commented-out sections for backbones.
  • docs/config_centroid.yaml: Added entity under wandb and threshold_mode under lr_scheduler.
  • docs/config_topdown_centered_instance.yaml: Added threshold_mode under trainer_config.lr_scheduler.
  • sleap_nn/training/model_trainer.py: Updated _create_data_loaders, removed _get_data_chunks, added crop_hw and chunk_size, modified initialization and error handling.
  • tests/fixtures/datasets.py: Changed entity in wandb config from "team-ucsd" to None; added threshold_mode set to "rel".

Possibly related PRs

Suggested reviewers

  • talmo

🐇 In the meadow, changes bloom,
New configs chase away the gloom.
Thresholds set, entities in sight,
Training models, oh what a delight!
With every hop, we refine our way,
In the world of code, we play all day! 🌼



@gitttt-1234 gitttt-1234 changed the base branch from main to divya/fix-get-chunks October 23, 2024 17:13
@gitttt-1234 gitttt-1234 requested a review from talmo October 23, 2024 17:14
@coderabbitai (Contributor) bot left a comment


Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.

Actionable comments posted: 27

🧹 Outside diff range and nitpick comments (16)
docs/config_bottomup.yaml (1)

Line range hint 1-107: Configuration structure maintains good organization and clarity.

The removal of commented-out alternative backbone configurations (ConvNeXt and Swin) improves maintainability while preserving the clear hierarchical structure of the configuration.

Consider maintaining alternative backbone configurations in separate example files if they might be useful for reference, rather than as comments in the main configuration.

docs/config_topdown_centered_instance.yaml (1)

105-105: LGTM! Consider adding parameter documentation.

The threshold_mode parameter is correctly placed in the lr_scheduler configuration with a valid value of 'abs'.

Consider adding a comment to document that this parameter accepts either 'abs' or 'rel' values and their implications on the learning rate scheduling behavior.

    threshold: 1.0e-07
+   # threshold_mode: Mode for interpreting threshold parameter ('abs' for absolute, 'rel' for relative change)
    threshold_mode: abs
initial_config.yaml (1)

68-69: Improve empty path handling

Empty string defaults for paths could lead to unclear errors. Consider using null as default or providing clear placeholder paths.

   save_ckpt: false
-  save_ckpt_path: ''
+  save_ckpt_path: null  # Set to "./checkpoints" when save_ckpt is true
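A minimal sketch of how a trainer could interpret a null path, assuming the "./checkpoints" fallback from the comment in the diff above. `resolve_ckpt_path` is a hypothetical helper, not part of sleap-nn.

```python
from pathlib import Path
from typing import Optional


def resolve_ckpt_path(
    save_ckpt: bool, save_ckpt_path: Optional[str], default: str = "./checkpoints"
) -> Optional[Path]:
    """Return the checkpoint directory, or None when checkpointing is disabled."""
    if not save_ckpt:
        return None
    # Treat both null and the legacy empty string as "not set".
    if save_ckpt_path is None or save_ckpt_path.strip() == "":
        return Path(default)
    return Path(save_ckpt_path)
```

Accepting the empty string as well as null keeps older configs working while the default is migrated.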
docs/config_centroid.yaml (1)

136-136: Document the valid values for threshold_mode.

While the default value abs is set correctly, please add a comment documenting that the valid values are 'abs' (absolute) or 'rel' (relative) for the threshold_mode parameter.

Apply this diff to add documentation:

    threshold: 1.0e-07
-    threshold_mode: abs
+    threshold_mode: abs  # Valid values: 'abs' (absolute) or 'rel' (relative)
    cooldown: 3
tests/fixtures/datasets.py (1)

Line range hint 156-164: Add missing threshold_mode parameter in lr_scheduler config

The PR's main objective is to add the threshold_mode parameter to the learning rate scheduler configuration, but it's missing in this test fixture. This could lead to inconsistency with other configuration files that have been updated.

Apply this diff to align with other config files:

                "lr_scheduler": {
                    "threshold": 1e-07,
+                   "threshold_mode": "rel",
                    "cooldown": 3,
                    "patience": 5,
                    "factor": 0.5,
                    "min_lr": 1e-08,
                },
sleap_nn/training/get_bin_files.py (1)

1-1: Enhance the module docstring with more details.

The current docstring is too brief. Consider adding more details about:

  • The purpose and use cases of the binary files
  • Required command-line arguments
  • Supported model types
  • Example usage

Here's a suggested improvement:

-"""Function to generate `.bin` files."""
+"""Generate optimized binary files for model training.
+
+This script processes training and validation data to generate optimized .bin files
+for different model types (single_instance, centered_instance, centroid, bottomup).
+
+Example usage:
+    python get_bin_files.py --dir_path /path/to/data \
+                           --model_type single_instance \
+                           --num_workers 4 \
+                           --chunk_size 1000
+
+Required arguments:
+    --dir_path: Directory containing initial_config.yaml
+    --model_type: Type of model (single_instance, centered_instance, centroid, bottomup)
+    --num_workers: Number of worker processes
+    --chunk_size: Size of data chunks
+    --user_instances_only: Filter for user instances
+    --crop_hw: Crop height/width (required for centered_instance)
+"""
tests/data/test_get_data_chunks.py (1)

Line range hint 1-261: Consider reducing test duplication with fixtures

There's significant duplication in the transform initialization and shape validation across all test functions. Consider creating pytest fixtures to reduce this duplication.

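Example refactor (assumes `T` is `torchvision.transforms` and that a `labels` fixture is available, as in the existing tests):

import pytest
import torchvision.transforms as T

from sleap_nn.data.providers import get_max_height_width

@pytest.fixture
def transform():
    return T.ToTensor()

@pytest.fixture
def max_hw(labels):
    return get_max_height_width(labels)

def assert_image_shapes(transform, sample, channels, height, width):
    """Helper to validate image shapes consistently; pass the `transform` fixture in."""
    assert transform(sample["image"]).shape == (channels, height, width)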
sleap_nn/data/instance_cropping.py (1)

50-55: LGTM! Good defensive programming practice.

The added NaN handling improves the robustness of the crop size calculation. The implementation correctly prevents NaN propagation while maintaining the function's contract.

Consider extracting the NaN handling into a helper function for better readability:

+def safe_diff(max_val: float, min_val: float) -> float:
+    """Calculate difference with NaN handling."""
+    diff = max_val - min_val
+    return 0 if np.isnan(diff) else diff

 def find_instance_crop_size(...):
     # ...
     for lf in labels:
         for inst in lf.instances:
             pts = inst.numpy()
             pts *= input_scaling
-            diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
-            diff_x = 0 if np.isnan(diff_x) else diff_x
+            diff_x = safe_diff(np.nanmax(pts[:, 0]), np.nanmin(pts[:, 0]))
             max_length = np.maximum(max_length, diff_x)
-            diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
-            diff_y = 0 if np.isnan(diff_y) else diff_y
+            diff_y = safe_diff(np.nanmax(pts[:, 1]), np.nanmin(pts[:, 1]))
             max_length = np.maximum(max_length, diff_y)
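To see why the NaN guard matters, here is a pure-Python analogue of the per-instance span computation. `instance_span` and `max_instance_length` are hypothetical; they ignore NaN coordinates the way `np.nanmax`/`np.nanmin` do and return 0 for an all-NaN (fully missing) instance instead of letting NaN propagate into the crop size.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def instance_span(points: Sequence[Point], axis: int) -> float:
    """Max minus min along one axis, ignoring NaNs; 0.0 if every value is NaN."""
    values = [p[axis] for p in points if not math.isnan(p[axis])]
    if not values:
        return 0.0
    return max(values) - min(values)


def max_instance_length(instances: List[Sequence[Point]], input_scaling: float = 1.0) -> float:
    """Largest x or y extent over all instances, after scaling."""
    max_length = 0.0
    for pts in instances:
        scaled = [(x * input_scaling, y * input_scaling) for x, y in pts]
        max_length = max(max_length, instance_span(scaled, 0), instance_span(scaled, 1))
    return max_length


nan = float("nan")
instances = [
    [(10.0, 20.0), (50.0, 25.0), (nan, nan)],  # x-span 40, y-span 5
    [(nan, nan), (nan, nan)],                  # all-NaN: contributes 0
]
print(max_instance_length(instances))  # 40.0
```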
sleap_nn/data/get_data_chunks.py (1)

247-251: Minor: Maintain consistent documentation formatting.

There's an extra period after "config file" in the documentation that isn't present in other similar docstrings.

-        data_config: Data-related configuration. (`data_config` section in the config file).
+        data_config: Data-related configuration. (`data_config` section in the config file)
tests/inference/test_topdown.py (1)

216-216: Consider adding test cases with different eff_scale values.

The eff_scale tensor is set to 1.0 in all test cases. Consider adding test cases with different scale values to verify the model's behavior with various scaling factors.

sleap_nn/data/providers.py (1)

31-35: Add error handling and improve documentation.

While the implementation is efficient, consider these improvements:

  1. Add error handling for empty labels/videos
  2. Enhance docstring with Args/Returns sections and edge cases

Consider this improved implementation:

 def get_max_height_width(labels: sio.Labels) -> Tuple[int, int]:
-    """Return `(height, width)` that is the maximum of all videos."""
+    """Calculate the maximum height and width across all videos in the labels.
+
+    Args:
+        labels: sleap_io.Labels object containing video data
+
+    Returns:
+        Tuple[int, int]: Maximum (height, width) across all videos
+
+    Raises:
+        ValueError: If labels contains no videos
+    """
+    if not labels.videos:
+        raise ValueError("Labels object contains no videos")
     return max(video.shape[1] for video in labels.videos), max(
         video.shape[2] for video in labels.videos
     )
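The empty-labels check can be exercised without loading a `.slp` file. The sketch below restates the suggested function without the `sio.Labels` annotation and uses `types.SimpleNamespace` stand-ins that only mimic `labels.videos` and each video's `shape`. It assumes the `(frames, height, width, channels)` shape order implied by `shape[1]`/`shape[2]` in the suggestion.

```python
from types import SimpleNamespace
from typing import Tuple


def get_max_height_width(labels) -> Tuple[int, int]:
    """Return the maximum (height, width) over all videos in `labels`."""
    if not labels.videos:
        raise ValueError("Labels object contains no videos")
    return max(video.shape[1] for video in labels.videos), max(
        video.shape[2] for video in labels.videos
    )


labels = SimpleNamespace(videos=[
    SimpleNamespace(shape=(100, 384, 384, 1)),
    SimpleNamespace(shape=(50, 480, 640, 1)),
])
print(get_max_height_width(labels))  # (480, 640)
```

Height and width are maximized independently, so the result need not match any single video.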
tests/training/test_model_trainer.py (1)

Line range hint 1-507: Add test coverage for the new threshold_mode parameter.

The PR introduces a new threshold_mode parameter for the learning rate scheduler, but there are no test cases verifying this functionality. Please add test cases to:

  1. Verify both abs and rel threshold modes work as expected
  2. Ensure proper error handling for invalid threshold mode values

Here's a suggested test case structure. It assumes invalid values are rejected when ModelTrainer is constructed; if the scheduler is only built later (for example in configure_optimizers), wrap that call in pytest.raises instead:

import copy

import pytest
from omegaconf import OmegaConf

from sleap_nn.training.model_trainer import ModelTrainer


def test_lr_scheduler_threshold_mode(config, tmp_path: str):
    """Test learning rate scheduler with different threshold modes."""
    # Deep-copy so updates to nested keys don't leak back into the shared fixture.
    for mode in ("abs", "rel"):
        cfg = copy.deepcopy(config)
        OmegaConf.update(cfg, "trainer_config.lr_scheduler.threshold_mode", mode)
        trainer = ModelTrainer(cfg)
        # Add assertions to verify the scheduler received `mode`.

    # Test invalid threshold mode
    config_invalid = copy.deepcopy(config)
    OmegaConf.update(
        config_invalid, "trainer_config.lr_scheduler.threshold_mode", "invalid"
    )
    with pytest.raises(ValueError):
        ModelTrainer(config_invalid)
sleap_nn/data/resizing.py (2)

111-112: Clarify the function's behavior in the docstring

The docstring states: "Apply scaling and padding to smaller image to (max_height, max_width) shape." However, the function raises an error if the image dimensions exceed the maximum dimensions. Consider updating the docstring to accurately reflect that images larger than the maximum dimensions will cause an error.


150-152: Return consistent eff_scale_ratio for all cases

Currently, when the image dimensions match the maximum dimensions, the function returns eff_scale_ratio as 1.0. Make sure the returned ratio matches the scaling actually applied in every branch, so that it is 1.0 exactly when the image is not rescaled.
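One way to keep the ratio consistent is to compute it once and derive the resize and padding from it. The sketch below assumes an aspect-preserving resize followed by bottom/right padding; `compute_resize_and_pad` is hypothetical and does not reproduce the actual code in resizing.py.

```python
from typing import Tuple


def compute_resize_and_pad(
    height: int, width: int, max_height: int, max_width: int
) -> Tuple[float, Tuple[int, int], Tuple[int, int]]:
    """Return (eff_scale_ratio, (new_h, new_w), (pad_h, pad_w)) for fitting into the max shape."""
    if height > max_height or width > max_width:
        raise ValueError(
            f"Image ({height}, {width}) exceeds the maximum shape ({max_height}, {max_width})."
        )
    # Largest aspect-preserving scale that still fits; exactly 1.0 when the shapes match.
    eff_scale_ratio = min(max_height / height, max_width / width)
    new_h = min(max_height, round(height * eff_scale_ratio))
    new_w = min(max_width, round(width * eff_scale_ratio))
    return eff_scale_ratio, (new_h, new_w), (max_height - new_h, max_width - new_w)


print(compute_resize_and_pad(384, 384, 384, 384))  # (1.0, (384, 384), (0, 0))
print(compute_resize_and_pad(200, 400, 480, 640))  # (1.6, (320, 640), (160, 0))
```

Because every branch returns the same computed ratio, callers that divide predicted coordinates by it get consistent results whether or not the image was rescaled.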

sleap_nn/training/model_trainer.py (1)

197-199: Use logging instead of printing to standard output

Directly printing to standard output may not be suitable for libraries or production code. Consider using the logging module for better control over logging levels and outputs.

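Apply this diff to replace the print statements with a module-level logger (passing values as arguments defers string formatting until the record is actually emitted):

+ import logging
+ logger = logging.getLogger(__name__)
...
- print("Standard Output:\n", stdout)
- print("Standard Error:\n", stderr)
+ logger.info("Standard Output:\n%s", stdout)
+ logger.error("Standard Error:\n%s", stderr)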

Don't forget to configure the logging level and handlers as appropriate for your application.
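A minimal sketch of the overall pattern, assuming the chunk-generation script is launched through `subprocess.run` (the actual call in model_trainer.py is not shown here). `run_and_log` is a hypothetical helper: it logs stdout at INFO, raises `CalledProcessError` on a non-zero exit instead of continuing silently, and logs stderr from a successful run only as a warning.

```python
import logging
import subprocess
import sys
from typing import List

logger = logging.getLogger(__name__)


def run_and_log(cmd: List[str]) -> str:
    """Run `cmd`, log its output, and raise CalledProcessError on failure."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.stdout:
        logger.info("Standard Output:\n%s", result.stdout)
    if result.returncode != 0:
        logger.error("Standard Error:\n%s", result.stderr)
        raise subprocess.CalledProcessError(
            result.returncode, cmd, output=result.stdout, stderr=result.stderr
        )
    if result.stderr:
        # Many tools write progress bars to stderr, so this is not necessarily an error.
        logger.warning("Standard Error:\n%s", result.stderr)
    return result.stdout


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    run_and_log([sys.executable, "-c", "print('chunks written')"])
```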

sleap_nn/inference/predictors.py (1)

1623-1624: Consider broader support for anchor_ind

In the main function, anchor_ind is assigned to the predictor only if it's an instance of TopDownPredictor. If other predictor types could benefit from anchor_ind, consider implementing a more generic approach or documenting why it's specific to TopDownPredictor.

🛑 Comments failed to post (27)
docs/config_bottomup.yaml (1)

107-107: 💡 Codebase verification

The threshold_mode parameter needs to be added to test configuration files

The verification reveals an inconsistency in the configuration files. While the main configuration files in the docs/ directory (config_bottomup.yaml, config_centroid.yaml, config_topdown_centered_instance.yaml) correctly include the threshold_mode: abs parameter, the test configuration files in tests/assets/ and the root initial_config.yaml are missing this parameter.

Files requiring updates:

  • tests/assets/minimal_instance_centroid/training_config.yaml
  • tests/assets/minimal_instance_centroid/initial_config.yaml
  • tests/assets/minimal_instance_bottomup/training_config.yaml
  • tests/assets/minimal_instance_bottomup/initial_config.yaml
  • tests/assets/minimal_instance/training_config.yaml
  • tests/assets/minimal_instance/initial_config.yaml
  • initial_config.yaml
🔗 Analysis chain

LGTM! The threshold_mode parameter is correctly configured.

The addition of threshold_mode: abs is properly placed under the lr_scheduler section and works in conjunction with the existing threshold parameter. This enhancement allows for more flexible learning rate scheduling by supporting both absolute and relative threshold comparisons.

Let's verify if other configuration files have consistent threshold_mode settings:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent lr_scheduler threshold_mode across config files
# Expected: All config files should have the same structure for lr_scheduler

# Search for lr_scheduler configurations in all yaml files
fd -e yaml -e yml | xargs rg -A 10 "lr_scheduler:"

Length of output: 6627

initial_config.yaml (4)

1-93: ⚠️ Potential issue

Fix inconsistent newline characters

yamllint reports Windows-style line endings (CRLF, \r\n), while its new-lines rule expects Unix-style \n. Convert the file so that every line ends with \n.

Use a text editor that supports Unix-style line endings or run:

dos2unix initial_config.yaml
🧰 Tools
🪛 yamllint

[error] 1-1: wrong new line character: expected \n

(new-lines)
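If dos2unix is not installed (for example on a stock Windows setup), the same normalization takes a few lines of standard-library Python. `to_unix_newlines` is a hypothetical helper that rewrites CRLF and lone CR endings as LF in place.

```python
from pathlib import Path


def to_unix_newlines(path: Path) -> bool:
    """Rewrite CRLF/CR line endings as LF. Returns True if the file changed."""
    data = path.read_bytes()
    fixed = data.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
    if fixed != data:
        path.write_bytes(fixed)
        return True
    return False
```

Usage: `to_unix_newlines(Path("initial_config.yaml"))`. Working on bytes avoids Python's universal-newline translation, which would otherwise hide the CRLF sequences being fixed.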


88-93: ⚠️ Potential issue

Add missing threshold_mode parameter to lr_scheduler

According to the PR objectives, a new threshold_mode parameter should be added to the learning rate scheduler configuration to support both 'abs' and 'rel' modes.

Add the parameter with a default value:

 lr_scheduler:
   threshold: 1.0e-07
+  threshold_mode: 'rel'  # Options: 'rel' or 'abs'
   cooldown: 3
   patience: 5
   factor: 0.5
   min_lr: 1.0e-08
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

  lr_scheduler:
    threshold: 1.0e-07
    threshold_mode: 'rel'  # Options: 'rel' or 'abs'
    cooldown: 3
    patience: 5
    factor: 0.5
    min_lr: 1.0e-08

76-76: ⚠️ Potential issue

Security: Remove API key field from version control

Even though it's empty, the api_key field shouldn't be in version control. Consider moving sensitive fields to environment variables or a separate secrets file.

Assuming the config is loaded with OmegaConf (as the trainer configs are), read the key from the environment with the oc.env resolver. A bare ${WANDB_API_KEY} is resolved as a reference to another config key, not to an environment variable; the null default yields None when the variable is unset:

-    api_key: ''
+    api_key: ${oc.env:WANDB_API_KEY,null}
📝 Committable suggestion

    api_key: ${oc.env:WANDB_API_KEY,null}

3-4: ⚠️ Potential issue

Replace hardcoded Windows paths with environment-agnostic paths

The configuration contains hardcoded Windows paths pointing to test assets. This makes the configuration non-portable and environment-specific.

Consider the following (if the config is loaded with OmegaConf, environment variables need the oc.env resolver; a bare ${DATA_DIR} would be looked up as a config key):

  1. Using relative paths or environment variables
  2. Using forward slashes for cross-platform compatibility
  3. Providing example paths that don't reference test assets
-  train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
-  val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  train_labels_path: ${oc.env:DATA_DIR}/train/labels.slp
+  val_labels_path: ${oc.env:DATA_DIR}/val/labels.slp

Committable suggestion was skipped due to low confidence.
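Existing Windows-style paths can also be normalized in code. As a standard-library sketch, `pathlib.PureWindowsPath` parses the mixed backslash/forward-slash strings above, `as_posix()` renders them with forward slashes, and the file name can be rebased onto a configurable data root (the `/data/sleap` value is a hypothetical DATA_DIR).

```python
from pathlib import PurePosixPath, PureWindowsPath

raw = r"C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp"
win_path = PureWindowsPath(raw)

print(win_path.as_posix())
# C:/Users/TalmoLab/Desktop/Divya/sleap-nn/tests/assets/minimal_instance.pkg.slp
print(win_path.name)  # minimal_instance.pkg.slp

# Keep only the file name and rebase it onto a configurable data root.
data_dir = PurePosixPath("/data/sleap")  # hypothetical DATA_DIR value
print(data_dir / win_path.name)  # /data/sleap/minimal_instance.pkg.slp
```

The Pure* classes only manipulate strings, so this runs identically on Linux, macOS, and Windows without touching the filesystem.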

sleap_nn/inference/single_instance.py (1)

87-89: ⚠️ Potential issue

Add input validation and documentation for eff_scale.

While the scaling implementation looks correct, there are a few improvements needed:

  1. Add validation for the existence and validity of inputs["eff_scale"]
  2. Document the expected shape and format of eff_scale in the method's docstring
  3. Consider handling potential division by zero

Here's a suggested implementation:

     def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Predict confidence maps and infer peak coordinates.
 
         Args:
             inputs: Dictionary with "image" as one of the keys.
+                   "eff_scale": Tensor of shape (batch_size,) containing 
+                               effective scale factors for each sample.
 
         Returns:
             A dictionary of outputs with keys:
 
             `"pred_instance_peaks"`: The predicted peaks for each instance in the batch
                 as a `torch.Tensor` of shape `(samples, nodes, 2)`.
             `"pred_peak_vals"`: The value of the confidence maps at the predicted
                 peaks for each instance in the batch as a `torch.Tensor` of shape
                 `(samples, nodes)`.
+
+        Raises:
+            KeyError: If "eff_scale" is missing from inputs
+            ValueError: If "eff_scale" contains zero values
         """
         # Network forward pass.
         cms = self.torch_model(inputs["image"])
         ...
+        if "eff_scale" not in inputs:
+            raise KeyError("Missing required input: eff_scale")
+        
+        eff_scale = inputs["eff_scale"]
+        if torch.any(eff_scale == 0):
+            raise ValueError("eff_scale contains zero values")
+
         peak_points = peak_points / (
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2)
         ).to(peak_points.device)

Committable suggestion was skipped due to low confidence.
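The division above maps peak coordinates from the resized input back to original-image pixels. A pure-Python sketch of the same arithmetic with the zero check folded in; `rescale_peaks` is hypothetical, and it assumes eff_scale is the per-sample factor the image was multiplied by during preprocessing, which is what dividing by it implies.

```python
import math
from typing import List, Sequence, Tuple

Peak = Tuple[float, float]


def rescale_peaks(peaks: Sequence[Sequence[Peak]], eff_scale: Sequence[float]) -> List[List[Peak]]:
    """Divide each sample's (x, y) peaks by that sample's effective scale."""
    if len(peaks) != len(eff_scale):
        raise ValueError("Expected one eff_scale value per sample")
    out = []
    for sample_peaks, s in zip(peaks, eff_scale):
        if s == 0 or math.isnan(s):
            raise ValueError("eff_scale must be non-zero and not NaN")
        out.append([(x / s, y / s) for x, y in sample_peaks])
    return out


# A peak at (100, 50) in an image downscaled by 0.5 sits at (200, 100) in the original.
print(rescale_peaks([[(100.0, 50.0)]], [0.5]))  # [[(200.0, 100.0)]]
```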

sleap_nn/training/get_bin_files.py (5)

28-38: ⚠️ Potential issue

Add validation for configuration and arguments.

The initialization lacks validation for:

  • Required configuration fields
  • Crop dimensions when using centered_instance model
  • Existence of train/val label paths

Add validation code after loading the config:

from pathlib import Path

from omegaconf import OmegaConf


def validate_config(config, args):
    """Validate configuration and arguments."""
    required_fields = [
        "data_config.train_labels_path",
        "data_config.val_labels_path",
        "model_config.backbone_config.max_stride",
    ]
    
    for field in required_fields:
        if not OmegaConf.select(config, field):
            raise ValueError(f"Missing required config field: {field}")
            
    if args.model_type == "centered_instance" and args.crop_hw is None:
        raise ValueError("crop_hw is required for centered_instance model type")
        
    for path_field in ["train_labels_path", "val_labels_path"]:
        path = Path(OmegaConf.select(config.data_config, path_field))
        if not path.exists():
            raise FileNotFoundError(f"Labels file not found: {path}")

# Add after loading config
validate_config(config, args)
🧰 Tools
🪛 Ruff

33-33: Use not ... instead of False if ... else True

Replace with not ...

(SIM211)


40-143: 🛠️ Refactor suggestion

Refactor repeated optimization code and add error handling.

The code has several areas for improvement:

  1. Repeated ld.optimize calls could be extracted into a helper function
  2. Missing progress tracking
  3. No error handling for the optimization process

Here's a suggested refactoring:

import functools
from pathlib import Path

import litdata as ld


def process_dataset(factory_get_chunks, labels, output_dir, num_workers, chunk_size):
    """Process a dataset using the given chunk factory."""
    try:
        print(f"Processing {len(labels)} samples to {output_dir}...")
        inputs = [(x, labels.videos.index(x.video)) for x in labels]
        ld.optimize(
            fn=factory_get_chunks,
            inputs=inputs,
            output_dir=output_dir,
            num_workers=num_workers,
            chunk_size=chunk_size,
        )
        print(f"Successfully processed {len(inputs)} samples")
    except Exception as e:
        print(f"Error processing dataset: {str(e)}")
        raise

# Usage example for single_instance model type:
if args.model_type == "single_instance":
    factory_get_chunks = functools.partial(
        single_instance_data_chunks,
        data_config=config.data_config,
        user_instances_only=user_instances_only,
        max_hw=(max_height, max_width),
    )
    
    for dataset_name, labels in [("train", train_labels), ("val", val_labels)]:
        output_dir = (Path(args.dir_path) / f"{dataset_name}_chunks").as_posix()
        process_dataset(
            factory_get_chunks,
            labels,
            output_dir,
            args.num_workers,
            args.chunk_size,
        )

144-147: 🛠️ Refactor suggestion

Consider using an enum for model types and validating earlier.

Moving the model type validation earlier and using an enum would improve maintainability.

Here's a suggested improvement:

import argparse
from enum import Enum


class ModelType(Enum):
    SINGLE_INSTANCE = "single_instance"
    CENTERED_INSTANCE = "centered_instance"
    CENTROID = "centroid"
    BOTTOMUP = "bottomup"

    @classmethod
    def from_str(cls, value: str) -> "ModelType":
        try:
            return cls(value)
        except ValueError:
            # argparse only shows custom messages for ArgumentTypeError;
            # a plain ValueError is reported as "invalid from_str value".
            raise argparse.ArgumentTypeError(
                f"{value} is not defined. Please choose one of: "
                f"{', '.join(m.value for m in cls)}"
            ) from None

# In argument parser:
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_type",
    type=ModelType.from_str,
    required=True,
    help=f"Type of model. Choices: {', '.join(m.value for m in ModelType)}",
)

# Then use args.model_type.value in the if-elif chain

33-33: ⚠️ Potential issue

Fix the boolean conversion.

--user_instances_only is declared with type=str, so args.user_instances_only == 0 compares a string with an integer and is always False; the flag is therefore always enabled. bool(args.user_instances_only) has the same problem, because bool("0") is also True. Convert the value through int instead, which works whether the argument arrives as a string or an int:

-    user_instances_only = False if args.user_instances_only == 0 else True
+    user_instances_only = bool(int(args.user_instances_only or 0))
📝 Committable suggestion

    user_instances_only = bool(int(args.user_instances_only or 0))
🧰 Tools
🪛 Ruff

33-33: Use not ... instead of False if ... else True

Replace with not ...

(SIM211)
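If the flag should also accept spellings such as true/false, an explicit converter avoids the bool("0") trap. A hedged sketch: `str_to_bool` is a hypothetical helper used as an argparse `type=` callable, so `--user_instances_only 0` and `1` keep working.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Parse common true/false spellings; reject anything else."""
    normalized = value.strip().lower()
    if normalized in ("1", "true", "yes", "y"):
        return True
    if normalized in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--user_instances_only", type=str_to_bool, default=False)

print(parser.parse_args(["--user_instances_only", "0"]).user_instances_only)  # False
print(parser.parse_args(["--user_instances_only", "true"]).user_instances_only)  # True
```

Raising ArgumentTypeError (rather than ValueError) lets argparse show the custom message in its usage error.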


19-26: ⚠️ Potential issue

Improve argument parsing with help messages and proper types.

The current argument parsing could be enhanced with:

  • Help messages for each argument
  • Proper type validation
  • Required/optional argument specification
  • Default values where appropriate

Here's a suggested improvement:

-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dir_path", type=str)
-    parser.add_argument("--user_instances_only", type=str)
-    parser.add_argument("--model_type", type=str)
-    parser.add_argument("--num_workers", type=int)
-    parser.add_argument("--chunk_size", type=int)
-    parser.add_argument("--crop_hw", type=int, default=None)
+    parser = argparse.ArgumentParser(description="Generate optimized binary files for model training.")
+    parser.add_argument(
+        "--dir_path",
+        type=Path,
+        required=True,
+        help="Directory containing initial_config.yaml",
+    )
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        required=True,
+        choices=["single_instance", "centered_instance", "centroid", "bottomup"],
+        help="Type of model to generate data for",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        required=True,
+        help="Number of worker processes",
+    )
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        required=True,
+        help="Size of data chunks",
+    )
+    parser.add_argument(
+        "--user_instances_only",
+        type=int,
+        choices=[0, 1],
+        default=0,
+        help="Filter for user instances only (0 or 1; type=bool would parse '0' as True)",
+    )
+    parser.add_argument(
+        "--crop_hw",
+        type=int,
+        default=None,
+        help="Crop height/width (required for centered_instance)",
+    )
📝 Committable suggestion

    parser = argparse.ArgumentParser(description="Generate optimized binary files for model training.")
    parser.add_argument(
        "--dir_path",
        type=Path,
        required=True,
        help="Directory containing initial_config.yaml",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        choices=["single_instance", "centered_instance", "centroid", "bottomup"],
        help="Type of model to generate data for",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        required=True,
        help="Number of worker processes",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        required=True,
        help="Size of data chunks",
    )
    parser.add_argument(
        "--user_instances_only",
        type=int,
        choices=[0, 1],
        default=0,
        help="Filter for user instances only (0 or 1; type=bool would parse '0' as True)",
    )
    parser.add_argument(
        "--crop_hw",
        type=int,
        default=None,
        help="Crop height/width (required for centered_instance)",
    )
    args = parser.parse_args()
tests/data/test_streaming_datasets.py (1)

93-93: 🛠️ Refactor suggestion

Consider consolidating duplicate test setup code.

The pattern of calculating max_hw and setting up the partial function is repeated across all test functions. Consider extracting this into a helper function to reduce code duplication.

+def setup_test_data_chunks(labels, dir_path, chunk_func, data_config, **kwargs):
+    """Helper function to set up test data chunks.
+
+    Args:
+        labels: Labels object
+        dir_path: Output directory path
+        chunk_func: Data chunk function to use
+        data_config: Data config passed to the chunk function (pytest
+            fixtures such as `config` are not visible inside plain helpers)
+        **kwargs: Additional arguments for chunk function
+    """
+    max_hw = get_max_height_width(labels)
+    partial_func = functools.partial(
+        chunk_func,
+        data_config=data_config,
+        max_hw=max_hw,
+        **kwargs
+    )
+    ld.optimize(
+        fn=partial_func,
+        inputs=[(x, labels.videos.index(x.video)) for x in labels],
+        output_dir=str(dir_path),
+        chunk_size=4,
+    )
+    return max_hw

 def test_centered_instance_streaming_dataset(minimal_instance, sleap_data_dir, config):
     """Test CenteredInstanceStreamingDataset class."""
     labels = sio.load_slp(minimal_instance)
-    max_hw = get_max_height_width(labels)
     dir_path = Path(sleap_data_dir) / "data_chunks"
-    partial_func = functools.partial(
-        centered_instance_data_chunks,
-        data_config=config.data_config,
-        max_instances=2,
-        crop_size=(160, 160),
-        anchor_ind=0,
-        max_hw=max_hw,
-    )
-    ld.optimize(
-        fn=partial_func,
-        inputs=[(x, labels.videos.index(x.video)) for x in labels],
-        output_dir=str(dir_path),
-        chunk_size=4,
-    )
+    setup_test_data_chunks(
+        labels,
+        dir_path,
+        centered_instance_data_chunks,
+        config.data_config,
+        max_instances=2,
+        crop_size=(160, 160),
+        anchor_ind=0,
+    )

Also applies to: 103-103

sleap_nn/data/get_data_chunks.py (1)

150-152: 🛠️ Refactor suggestion

Consider standardizing the normalization step.

Normalization is applied inline, as an argument to the generate_crops call, while other functions handle normalization differently. Consider standardizing where normalization occurs in the processing pipeline to maintain consistency.

Consider moving the normalization step out of the generate_crops call:

-        res = generate_crops(
-            apply_normalization(sample["image"]), instance, centroid, crop_size
-        )
+        normalized_image = apply_normalization(sample["image"])
+        res = generate_crops(normalized_image, instance, centroid, crop_size)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        normalized_image = apply_normalization(sample["image"])
        res = generate_crops(normalized_image, instance, centroid, crop_size)
tests/inference/test_topdown.py (1)

139-146: 🛠️ Refactor suggestion

Consider extracting duplicated example preparation code into a helper function.

The example preparation code is duplicated across multiple test functions. Consider extracting it into a helper function to improve maintainability and reduce code duplication.

def prepare_test_example(labels, frame_idx=0, stride=2):
    """Helper function to prepare test example with common fields."""
    ex = process_lf(labels[frame_idx], frame_idx, stride)
    ex["image"] = apply_normalization(ex["image"]).unsqueeze(dim=0)
    ex["instances"] = ex["instances"].unsqueeze(dim=0)
    ex["frame_idx"] = ex["frame_idx"].unsqueeze(dim=0)
    ex["video_idx"] = ex["video_idx"].unsqueeze(dim=0)
    ex["orig_size"] = ex["orig_size"].unsqueeze(dim=0)
    ex["eff_scale"] = torch.Tensor([1.0])
    return ex
sleap_nn/data/streaming_datasets.py (2)

70-73: 🛠️ Refactor suggestion

Consider optimizing the image transformation pipeline.

The transformation steps are consistent across all dataset classes but could be improved:

  1. Error handling: The code assumes the input is always a valid PIL image. Consider adding validation.
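  2. Structure: These operations could be combined into a single transform, for example with torchvision's T.Compose (torch.nn.Sequential would require every step, including the unsqueeze, to be an nn.Module).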
  3. Memory efficiency: The operations create multiple intermediate tensors.

Consider creating a reusable transform pipeline to reduce code duplication and improve efficiency:

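import torchvision.transforms as T
from PIL import Image

# apply_normalization is assumed to be imported as in streaming_datasets.py.


def _add_batch_dim(x):
    # Module-level function (not a lambda) keeps the pipeline picklable
    # for multi-worker data loading.
    return x.unsqueeze(0)


class ImageTransformPipeline:
    def __init__(self):
        self.transform = T.Compose([
            T.PILToTensor(),
            _add_batch_dim,
            apply_normalization,
        ])

    def __call__(self, image):
        if not isinstance(image, Image.Image):
            raise ValueError("Expected PIL image")
        return self.transform(image)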

Then use it in each dataset:

- transform = T.PILToTensor()
- ex["image"] = transform(ex["image"])
- ex["image"] = ex["image"].unsqueeze(dim=0)
- ex["image"] = apply_normalization(ex["image"])
+ transform = ImageTransformPipeline()
+ ex["image"] = transform(ex["image"])

Also applies to: 171-174, 272-275, 358-361


70-73: 💡 Codebase verification


Test coverage exists but needs expansion for new transformation pipeline

The existing tests already verify image shapes and basic transformations, but they don't explicitly test the new PILToTensor transformation and normalization steps. Current test coverage:

  • ✓ Image shape verification after transformations
  • ✓ Augmentation pipeline (random crop)
  • ❌ PILToTensor conversion
  • ❌ Normalization validation
  • ❌ Error handling for invalid inputs

Suggested test additions:

  • Add test cases in tests/data/test_streaming_datasets.py for each dataset class to verify:
    • Correct conversion from PIL to tensor format
    • Proper normalization of pixel values (range and distribution)
    • Error handling for invalid image inputs
🔗 Analysis chain

Ensure test coverage for the new transformation pipeline.

The addition of image transformation steps across all dataset classes requires corresponding test updates to verify:

  1. Correct tensor shape and type after transformation
  2. Proper normalization of pixel values
  3. Error handling for invalid inputs
  4. Memory usage patterns

Run the following to check existing test coverage:

Also applies to: 171-174, 272-275, 358-361


Let me gather more information about the test coverage by examining the test file content and checking for any existing image-related tests.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Look for related test files
fd "test.*streaming_datasets.*\.py"

# Check for existing image transformation tests
rg -A 5 "def test.*transform" $(fd "test.*streaming_datasets.*\.py")

Length of output: 390


Script:

#!/bin/bash
# Check content of the test file
cat tests/data/test_streaming_datasets.py

# Look for any image-related test cases
rg -i "test.*image" tests/data/test_streaming_datasets.py

# Look for normalization tests
rg -i "test.*normal" tests/data/test_streaming_datasets.py

Length of output: 8612

sleap_nn/data/resizing.py (2)

118-122: 🛠️ Refactor suggestion

Consider resizing images larger than max dimensions instead of raising an error

Currently, the function raises a ValueError when the image dimensions exceed the maximum allowed dimensions. To enhance the function's robustness, consider automatically scaling down images larger than the maximum dimensions to fit within the desired size.

Apply this diff to modify the function:

 def apply_sizematcher(
     image: torch.Tensor,
     max_height: Optional[int] = None,
     max_width: Optional[int] = None,
 ):
     """Apply scaling and padding to image to (max_height, max_width) shape."""
     img_height, img_width = image.shape[-2:]
     # Set default max dimensions
     if max_height is None:
         max_height = img_height
     if max_width is None:
         max_width = img_width

-    if img_height > max_height:
-        raise ValueError(
-            f"Max height {max_height} should be greater than the current image height: {img_height}"
-        )
-
-    if img_width > max_width:
-        raise ValueError(
-            f"Max width {max_width} should be greater than the current image width: {img_width}"
-        )
-
-    if img_height < max_height or img_width < max_width:
+    if img_height != max_height or img_width != max_width:
         hratio = max_height / img_height
         wratio = max_width / img_width

         if hratio > wratio:
             eff_scale_ratio = wratio
             target_h = int(round(img_height * wratio))
             target_w = int(round(img_width * wratio))
         else:
             eff_scale_ratio = hratio
             target_w = int(round(img_width * hratio))
             target_h = int(round(img_height * hratio))

         image = tvf.resize(image, size=(target_h, target_w))

         pad_height = max_height - target_h
         pad_width = max_width - target_w

         image = F.pad(
             image,
             (0, pad_width, 0, pad_height),
             mode="constant",
         ).to(torch.float32)

         return image, eff_scale_ratio
     else:
         return image, 1.0

126-152: ⚠️ Potential issue

Ensure consistent scaling and padding logic

The current logic only scales and pads images that are smaller than the maximum dimensions. By modifying the condition to include images that are larger, the function will handle all image sizes consistently, scaling and padding as needed.

Apply this diff to adjust the condition:

-    if img_height < max_height or img_width < max_width:
+    if img_height != max_height or img_width != max_width:
         hratio = max_height / img_height
         wratio = max_width / img_width
         # Remaining code unchanged...
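To make the proposed condition concrete, here is a tensor-free sketch of the ratio and padding arithmetic in the suggested apply_sizematcher (the `sizematch_plan` helper is hypothetical, for illustration only). Because the smaller of the two ratios is used, an oversized image is scaled down until both sides fit, and the leftover space is padded on the bottom/right, as `F.pad(image, (0, pad_width, 0, pad_height))` does:

```python
def sizematch_plan(img_height, img_width, max_height, max_width):
    """Return (target_h, target_w, pad_h, pad_w, eff_scale) for a letterbox resize.

    Mirrors the ratio/padding arithmetic of the suggested apply_sizematcher,
    without touching any tensors.
    """
    if img_height == max_height and img_width == max_width:
        return img_height, img_width, 0, 0, 1.0
    hratio = max_height / img_height
    wratio = max_width / img_width
    # Same as the if/else on hratio > wratio: keep the smaller ratio so the
    # resized image fits inside both limits.
    eff_scale = min(hratio, wratio)
    target_h = int(round(img_height * eff_scale))
    target_w = int(round(img_width * eff_scale))
    # Padding fills the remaining rows/columns on the bottom/right.
    return target_h, target_w, max_height - target_h, max_width - target_w, eff_scale


print(sizematch_plan(1200, 800, 600, 600))  # (600, 400, 0, 200, 0.5)
```

For example, a 1200x800 (HxW) image with a 600x600 limit is scaled by 0.5 to 600x400 and padded with 200 columns on the right; a smaller image gets a ratio above 1 and is scaled up instead.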
sleap_nn/inference/topdown.py (1)

69-69: ⚠️ Potential issue

Potential issue with crop_hw defaulting to None

By changing crop_hw to default to None, ensure that the code handles cases where crop_hw is None. In the _generate_crops method, accessing self.crop_hw[0] and self.crop_hw[1] will raise an exception if crop_hw is None.

Apply this diff to set a default value when crop_hw is None:

 def __init__(
     self,
     torch_model: Optional[L.LightningModule] = None,
     output_stride: int = 1,
     peak_threshold: float = 0.0,
     max_instances: Optional[int] = None,
     refinement: Optional[str] = None,
     integral_patch_size: int = 5,
     return_confmaps: bool = False,
     return_crops: bool = False,
     crop_hw: Optional[List[int]] = None,
     input_scale: float = 1.0,
     max_stride: int = 1,
     use_gt_centroids: bool = False,
     anchor_ind: Optional[int] = None,
     **kwargs,
 ):
     """Initialise the model attributes."""
     super().__init__(**kwargs)
+    if crop_hw is None:
+        crop_hw = [160, 160]
     self.torch_model = torch_model
     self.peak_threshold = peak_threshold
     self.refinement = refinement
     self.integral_patch_size = integral_patch_size
     self.output_stride = output_stride
     self.return_confmaps = return_confmaps
     self.max_instances = max_instances
     self.return_crops = return_crops
     self.crop_hw = crop_hw
     self.input_scale = input_scale
     self.max_stride = max_stride
     self.use_gt_centroids = use_gt_centroids
     self.anchor_ind = anchor_ind
sleap_nn/training/model_trainer.py (6)

175-175: ⚠️ Potential issue

Remove unnecessary 'f' prefix from string literal

The string on line 175 does not contain any placeholders, so the f prefix is unnecessary.

Apply this diff to correct the string literal:

- f"sleap_nn.training.get_bin_files",
+ "sleap_nn.training.get_bin_files",
🧰 Tools
🪛 Ruff

175-175: f-string without any placeholders

Remove extraneous f prefix

(F541)


10-10: ⚠️ Potential issue

Remove unused import 'shutil'

The shutil module is imported but not used in the code. Eliminating it will tidy up the imports.

Apply this diff to remove the unused import:

- import shutil

🧰 Tools
🪛 Ruff

10-10: shutil imported but unused

Remove unused import: shutil

(F401)


6-6: ⚠️ Potential issue

Remove unused import 'inspect'

The inspect module is imported but not used anywhere in the code. Removing it will clean up unnecessary imports.

Apply this diff to remove the unused import:

- import inspect

🧰 Tools
🪛 Ruff

6-6: inspect imported but unused

Remove unused import: inspect

(F401)


57-57: ⚠️ Potential issue

Remove unused import 'get_bin_files'

The get_bin_files module from sleap_nn.training is imported but not utilized. Removing it reduces clutter.

Apply this diff to remove the unused import:

- from sleap_nn.training import get_bin_files

🧰 Tools
🪛 Ruff

57-57: sleap_nn.training.get_bin_files imported but unused

Remove unused import: sleap_nn.training.get_bin_files

(F401)


205-205: ⚠️ Potential issue

Use 'raise ... from e' to preserve the exception context

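When raising a new exception inside an except block, raise ... from e records the original exception as the explicit cause, so the traceback reports it as the direct cause of the new error rather than as a second failure "during handling" of the first. This makes the failure chain clearer when debugging.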

Apply this diff to improve exception handling:

- raise Exception(f"Error while creating the `.bin` files... {e}")
+ raise Exception("Error while creating the `.bin` files") from e
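A short, self-contained illustration of what `from e` changes (the function names are hypothetical stand-ins for the `.bin` creation step): the wrapping exception keeps the original error as `__cause__`, and Python's traceback prints "The above exception was the direct cause of the following exception" between the two.

```python
def create_bin_files():
    # Hypothetical stand-in for the step that fails inside the try block.
    raise RuntimeError("worker crashed")


def build_chunks():
    try:
        create_bin_files()
    except RuntimeError as e:
        # `from e` stores the original error as __cause__ (explicit chaining).
        raise Exception("Error while creating the `.bin` files") from e


try:
    build_chunks()
except Exception as err:
    print(repr(err.__cause__))  # RuntimeError('worker crashed')
```

Since the cause is preserved on the exception object, the message no longer needs to interpolate `{e}` itself.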
🧰 Tools
🪛 Ruff

205-205: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


171-191: 🛠️ Refactor suggestion

Simplify subprocess execution using 'subprocess.run'

Since you are running the subprocess and waiting for it to finish, subprocess.run provides a simpler and more reliable interface with built-in error handling.

Refactor the subprocess call as follows:

- def run_subprocess():
-     process = subprocess.Popen(
-         [
-             "python",
-             "-m",
-             "sleap_nn.training.get_bin_files",
-             "--dir_path",
-             f"{self.dir_path}",
-             "--user_instances_only",
-             "1" if self.user_instances_only else "0",
-             "--model_type",
-             f"{self.model_type}",
-             "--num_workers",
-             f"{self.config.trainer_config.train_data_loader.num_workers}",
-             "--chunk_size",
-             f"{self.chunk_size}",
-             "--crop_hw",
-             f"{self.crop_hw}",
-         ],
-         stdout=subprocess.PIPE,
-         stderr=subprocess.PIPE,
-         text=True,
-     )
-     # Use communicate() to read output and avoid hanging
-     stdout, stderr = process.communicate()
-     # Print the logs
-     print("Standard Output:\n", stdout)
-     print("Standard Error:\n", stderr)
+ def run_subprocess():
+     result = subprocess.run(
+         [
+             "python",
+             "-m",
+             "sleap_nn.training.get_bin_files",
+             "--dir_path",
+             self.dir_path,
+             "--user_instances_only",
+             "1" if self.user_instances_only else "0",
+             "--model_type",
+             self.model_type,
+             "--num_workers",
+             str(self.config.trainer_config.train_data_loader.num_workers),
+             "--chunk_size",
+             str(self.chunk_size),
+             "--crop_hw",
+             str(self.crop_hw),
+         ],
+         capture_output=True,
+         text=True,
+         check=True,
+     )
+     # Log the outputs
+     logging.info(f"Standard Output:\n{result.stdout}")
+     logging.error(f"Standard Error:\n{result.stderr}")

This refactoring uses subprocess.run with check=True to automatically raise a CalledProcessError if the subprocess exits with a non-zero status, simplifying error handling.
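A minimal runnable sketch of the same `subprocess.run(..., check=True)` pattern, assuming a trivial inline child script in place of `sleap_nn.training.get_bin_files`. It uses `sys.executable` rather than a bare `"python"` so the child runs under the same interpreter as the parent, and it shows that when `check=True` raises, the captured output is still available on the `CalledProcessError`. Note that any logging placed after the call only runs on success.

```python
import subprocess
import sys


def run_child(code: str) -> str:
    """Run a Python snippet in a child interpreter and return its stdout."""
    result = subprocess.run(
        [sys.executable, "-c", code],  # same interpreter as the parent
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError on a non-zero exit status
    )
    return result.stdout


print(run_child("print('ok')"), end="")  # ok

try:
    run_child("import sys; sys.stderr.write('boom'); sys.exit(3)")
except subprocess.CalledProcessError as exc:
    # Captured streams survive on the exception, so they can still be logged.
    print(exc.returncode, exc.stderr)  # 3 boom
```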

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

175-175: f-string without any placeholders

Remove extraneous f prefix

(F541)

sleap_nn/inference/predictors.py (2)

273-279: 🛠️ Refactor suggestion

Refactor duplicate image color conversion logic

The code for converting images to RGB or grayscale is duplicated in lines 273-279 and 304-307. To improve maintainability and reduce redundancy, consider refactoring this logic into a separate helper function or moving it outside the loop if possible.

Apply this diff to eliminate the duplication:

@@ -273,14 +273,6 @@
     if self.instances_key:
         frame["instances"] = frame["instances"] * eff_scale
-    if self.preprocess_config["is_rgb"] and frame["image"].shape[-3] != 3:
-        frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-    elif not self.preprocess_config["is_rgb"]:
-        frame["image"] = F.rgb_to_grayscale(
-            frame["image"], num_output_channels=1
-        )
     eff_scales.append(torch.tensor(eff_scale))
     imgs.append(frame["image"].unsqueeze(dim=0))
     fidxs.append(frame["frame_idx"])
     vidxs.append(frame["video_idx"])
@@ -304,7 +296,14 @@
         ex["instances"] = instances
+    # Perform image color conversion
+    if self.preprocess_config["is_rgb"] and ex["image"].shape[-3] != 3:
+        ex["image"] = ex["image"].repeat(1, 1, 3, 1, 1)
+    elif not self.preprocess_config["is_rgb"]:
+        ex["image"] = F.rgb_to_grayscale(ex["image"], num_output_channels=1)
+
     if self.preprocess:

Also applies to: 304-307


636-645: 🛠️ Refactor suggestion

Consolidate preprocessing configuration across predictors

Similar code for setting up preprocess_config with max_height and max_width appears in multiple predictor classes (TopDownPredictor, SingleInstancePredictor, BottomUpPredictor). Consider consolidating this logic into a shared method or base class to reduce redundancy and ensure consistency across different predictors.

Also applies to: 657-666, 693-702, 962-971, 990-999, 1289-1298, 1317-1326
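One possible shape for that consolidation, sketched under assumptions: a small mixin with a single builder method. The `PreprocessConfigMixin` and `SingleInstancePredictorSketch` names are hypothetical, and only the `is_rgb`, `max_height` and `max_width` keys mentioned in this review are included; the real `preprocess_config` may carry more.

```python
from typing import Any, Dict, Optional


class PreprocessConfigMixin:
    """Hypothetical shared base for the predictor classes."""

    @staticmethod
    def build_preprocess_config(
        is_rgb: bool,
        max_height: Optional[int] = None,
        max_width: Optional[int] = None,
    ) -> Dict[str, Any]:
        # Single place that assembles the keys each predictor sets up today.
        return {"is_rgb": is_rgb, "max_height": max_height, "max_width": max_width}


class SingleInstancePredictorSketch(PreprocessConfigMixin):
    """Illustrative predictor that reuses the shared builder."""

    def __init__(self, data_config: Dict[str, Any]):
        self.preprocess_config = self.build_preprocess_config(
            is_rgb=data_config["is_rgb"],
            max_height=data_config.get("max_height"),
            max_width=data_config.get("max_width"),
        )


predictor = SingleInstancePredictorSketch({"is_rgb": False, "max_height": 1024})
print(predictor.preprocess_config)
# {'is_rgb': False, 'max_height': 1024, 'max_width': None}
```

A base class method would work equally well; the mixin just keeps the change additive for the three existing predictor classes.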

@codecov

codecov bot commented Oct 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.39%. Comparing base (f093ce2) to head (585bc0e).
Report is 19 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #106      +/-   ##
==========================================
+ Coverage   96.64%   97.39%   +0.74%     
==========================================
  Files          23       38      +15     
  Lines        1818     3681    +1863     
==========================================
+ Hits         1757     3585    +1828     
- Misses         61       96      +35     


@gitttt-1234 gitttt-1234 force-pushed the divya/fix-get-chunks branch 2 times, most recently from 22c7a96 to 105e496 October 30, 2024 19:03
@gitttt-1234 gitttt-1234 force-pushed the divya/add-thresh-mode-option branch from 162e949 to f64b425 October 30, 2024 19:41
@gitttt-1234 gitttt-1234 changed the base branch from divya/fix-get-chunks to main October 30, 2024 19:42
@gitttt-1234 gitttt-1234 force-pushed the divya/add-thresh-mode-option branch from f64b425 to 162e949 October 30, 2024 19:44
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (11)
tests/data/test_streaming_datasets.py (2)

31-35: Consider adding edge case tests for max_hw parameter.

While the test coverage for different scales is good, consider adding test cases for:

  • Edge cases where max_hw is very small or large
  • Cases where image dimensions are smaller than max_hw

Also applies to: 69-75


125-126: Consider adding scale=1.0 test case for consistency.

Other test functions include test cases for both scale=0.5 and scale=1.0. Consider adding a similar second test case here for consistency and better coverage.

sleap_nn/data/get_data_chunks.py (2)

272-276: Enhance documentation for max_hw parameter.

While the documentation explains the parameter well, consider adding an example showing how the values from config and max_hw interact to determine the final dimensions.

Example addition to the docstring:

Example:
    If config has max_height=1000, max_width=None and max_hw=(800, 1200),
    the final maximum dimensions will be (1000, 1200).
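The precedence rule in that docstring example fits in a few lines; this mirrors the `max_height if max_height is not None else max_hw[0]` logic used in the utility function suggested in the following comment (the `resolve_max_dims` name is hypothetical):

```python
from typing import Optional, Tuple


def resolve_max_dims(
    max_height: Optional[int],
    max_width: Optional[int],
    max_hw: Tuple[int, int],
) -> Tuple[int, int]:
    """Config values win; max_hw fills any dimension the config leaves as None."""
    return (
        max_height if max_height is not None else max_hw[0],
        max_width if max_width is not None else max_hw[1],
    )


print(resolve_max_dims(1000, None, (800, 1200)))  # (1000, 1200)
```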

Line range hint 1-318: Consider extracting common image processing logic.

The file shows repeated patterns for image processing across all functions:

  1. Size matching with max_hw
  2. Coordinate scaling
  3. Additional resizing
  4. PIL conversion

Consider extracting these steps into a utility function to reduce code duplication and ensure consistent processing across all data chunk functions.

Example utility function:

def process_image_and_coordinates(
    image: torch.Tensor,
    coordinates: torch.Tensor,
    max_height: Optional[int],
    max_width: Optional[int],
    max_hw: Tuple[int, int],
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Process image and scale coordinates accordingly."""
    # Size matching
    image, eff_scale = apply_sizematcher(
        image,
        max_height=max_height if max_height is not None else max_hw[0],
        max_width=max_width if max_width is not None else max_hw[1],
    )
    coordinates = coordinates * eff_scale

    # Additional resizing
    image, coordinates = apply_resizer(image, coordinates, scale=scale)
    
    # PIL conversion
    transform = T.ToPILImage()
    image = transform(image.squeeze(dim=0))
    
    return image, coordinates, eff_scale
sleap_nn/data/streaming_datasets.py (2)

131-131: Enhance documentation for input_scale parameter.

While the input_scale parameter is correctly implemented, its documentation could be more detailed to explain:

  • The valid range of values
  • How it affects the final crop size
  • Any performance implications

Consider expanding the docstring:

-        input_scale: Resize factor applied to the image. Default: 1.0.
+        input_scale: Scaling factor applied to crop dimensions. Values > 1.0 increase crop size,
+            while values < 1.0 decrease it. Must be positive. Default: 1.0.

Also applies to: 141-141, 153-153


322-324: LGTM: Image preprocessing is consistent with other classes.

The transformation pipeline correctly follows the established pattern.

Consider extracting the common image preprocessing logic into a shared utility function to reduce code duplication across dataset classes. Example:

def preprocess_image(image: PIL.Image, apply_aug: bool = False, aug_config: Optional[DictConfig] = None) -> torch.Tensor:
    """Apply standard image preprocessing pipeline.
    
    Args:
        image: Input PIL image
        apply_aug: Whether to apply augmentation
        aug_config: Augmentation configuration
        
    Returns:
        Preprocessed image tensor
    """
    transform = T.PILToTensor()
    image = transform(image)
    image = image.unsqueeze(dim=0).to(torch.float32)
    
    if apply_aug:
        if "intensity" in aug_config:
            image, _ = apply_intensity_augmentation(image, None, **aug_config.intensity)
        if "geometric" in aug_config:
            image, _ = apply_geometric_augmentation(image, None, **aug_config.geometric)
            
    image = apply_normalization(image)
    return image

Also applies to: 338-338

tests/training/test_model_trainer.py (1)

356-356: LGTM! Consider adding edge cases for RGB configuration.

The RGB configuration changes are well-coordinated between preprocessing and model configuration. However, consider adding test cases for:

  1. Mismatched configurations (is_rgb=True with in_channels=1)
  2. Grayscale to RGB conversion behavior

Also applies to: 362-362

sleap_nn/training/model_trainer.py (3)

92-92: Consider using None as default for crop_hw

Using -1 as a default value for crop_hw could be misleading. Consider using None instead, as it better represents an uninitialized state and matches Python conventions.

-        self.crop_hw = -1
+        self.crop_hw = None
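One knock-on effect to check if crop_hw defaults to None: the trainer forwards it to the get_bin_files subprocess as `str(self.crop_hw)`, and a child parser that declares `--crop_hw` with `type=int` would reject the string "None". A sketch of omitting the flag when it is unset, so the child falls back to `default=None` (the `crop_hw_args` helper is hypothetical):

```python
import argparse
from typing import List, Optional


def crop_hw_args(crop_hw: Optional[int]) -> List[str]:
    """Build the --crop_hw part of the child command line."""
    # Omit the flag when unset so the child parser uses its default of None.
    return [] if crop_hw is None else ["--crop_hw", str(crop_hw)]


# Mirrors the --crop_hw declaration in the child parser shown earlier.
parser = argparse.ArgumentParser()
parser.add_argument("--crop_hw", type=int, default=None)

print(parser.parse_args(crop_hw_args(None)).crop_hw)  # None
print(parser.parse_args(crop_hw_args(160)).crop_hw)  # 160
```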

382-388: Consider atomic file operations for config saving

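If the process is interrupted while a config file is being written, the file can be left truncated. Writing to a temporary file and atomically renaming it into place ensures readers only ever see a complete config (it does not, by itself, serialize concurrent writers).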

Consider using a temporary file and atomic rename operation:

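import os
import tempfile
from omegaconf import OmegaConf

def save_config_atomic(config, filepath):
    # Temp file in the target directory keeps os.replace on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(filepath)))
    os.close(fd)  # close our handle before OmegaConf reopens the path
    try:
        OmegaConf.save(config=config, f=tmp_path)
        os.replace(tmp_path, filepath)  # atomic rename over the target
    except BaseException:
        os.remove(tmp_path)
        raise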

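Remove the unused inspect import; shutil and get_bin_files also look removable

The inspect module is imported but never used in the code. The other two names are referenced in model_trainer.py, but not in a way that requires the import (consistent with Ruff's F401 findings above):

  • shutil appears only in commented-out cleanup code, so Ruff will keep reporting it unless that code is restored
  • get_bin_files is run as a separate process via python -m sleap_nn.training.get_bin_files, which does not require importing it in the trainer

Remove the inspect import; shutil and get_bin_files can likely be removed as well unless the commented-out cleanup is reinstated.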

🔗 Analysis chain

Line range hint 6-57: Clean up unused imports

The following imports are flagged as unused and should be removed if they're not needed elsewhere in the codebase:

  • inspect (line 6)
  • shutil (line 10)
  • sleap_nn.training.get_bin_files (line 57)

Let me analyze the usage of these imports more thoroughly before making a final decision.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify if these imports are used elsewhere
rg "\b(inspect|shutil|get_bin_files)\b" --type py

Length of output: 4076


Script:

#!/bin/bash
# Check for actual usage of inspect and get_bin_files in model_trainer.py
rg -A 5 "\b(inspect|get_bin_files)\b" sleap_nn/training/model_trainer.py

# Check for shutil usage in model_trainer.py
rg -A 5 "\bshutil\." sleap_nn/training/model_trainer.py

Length of output: 1122

🧰 Tools
🪛 Ruff

37-37: torchvision.models.convnext.ConvNeXt_Small_Weights imported but unused

Remove unused import

(F401)


38-38: torchvision.models.convnext.ConvNeXt_Large_Weights imported but unused

Remove unused import

(F401)


40-40: Redefinition of unused sio from line 13

Remove definition: sio

(F811)


42-42: sleap_nn.data.cycler.CyclerIterDataPipe imported but unused

Remove unused import: sleap_nn.data.cycler.CyclerIterDataPipe

(F401)


46-46: sleap_nn.data.get_data_chunks.bottomup_data_chunks imported but unused

Remove unused import

(F401)


47-47: sleap_nn.data.get_data_chunks.centered_instance_data_chunks imported but unused

Remove unused import

(F401)

sleap_nn/inference/predictors.py (1)

1624-1626: Fix indentation for better readability.

The indentation of this block appears to be inconsistent with the surrounding code.

Apply this diff to fix the indentation:

-    if isinstance(predictor, TopDownPredictor):
-        predictor.anchor_ind = anchor_ind
-

+    if isinstance(predictor, TopDownPredictor):
+        predictor.anchor_ind = anchor_ind
+
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 0b964a9 and 162e949.

📒 Files selected for processing (8)
  • sleap_nn/data/get_data_chunks.py (13 hunks)
  • sleap_nn/data/streaming_datasets.py (11 hunks)
  • sleap_nn/inference/predictors.py (16 hunks)
  • sleap_nn/training/get_bin_files.py (1 hunks)
  • sleap_nn/training/model_trainer.py (9 hunks)
  • tests/data/test_streaming_datasets.py (11 hunks)
  • tests/fixtures/datasets.py (2 hunks)
  • tests/training/test_model_trainer.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/fixtures/datasets.py
🧰 Additional context used
🪛 Ruff
sleap_nn/training/get_bin_files.py

34-34: Use not ... instead of False if ... else True

Replace with not ...

(SIM211)

sleap_nn/training/model_trainer.py

6-6: inspect imported but unused

Remove unused import: inspect

(F401)


10-10: shutil imported but unused

Remove unused import: shutil

(F401)


57-57: sleap_nn.training.get_bin_files imported but unused

Remove unused import: sleap_nn.training.get_bin_files

(F401)


174-174: f-string without any placeholders

Remove extraneous f prefix

(F541)


206-206: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)
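To illustrate what B904 asks for, here is a minimal sketch of the two chaining forms. The loader functions are hypothetical and exist only to trigger the except clause: `raise ... from err` keeps the original error as `__cause__`, while `raise ... from None` hides it from the printed traceback.

```python
def load_chunk(path: str) -> bytes:
    """Hypothetical low-level loader that always fails, to trigger the handlers."""
    raise FileNotFoundError(path)


def load_chunk_chained(path: str) -> bytes:
    try:
        return load_chunk(path)
    except FileNotFoundError as err:
        # `from err` stores the original error in __cause__; the traceback then
        # reads "The above exception was the direct cause of the following exception".
        raise RuntimeError(f"Could not load chunk {path!r}") from err


def load_chunk_suppressed(path: str) -> bytes:
    try:
        return load_chunk(path)
    except FileNotFoundError:
        # `from None` sets __suppress_context__ so the original error is hidden
        # from the printed traceback (it is still kept in __context__).
        raise RuntimeError(f"Could not load chunk {path!r}") from None
```

Use `from err` when the underlying failure helps debugging (as with the subprocess error below); reserve `from None` for cases where the inner exception is pure noise.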

🔇 Additional comments (14)
tests/data/test_streaming_datasets.py (2)

13-13: LGTM: Consistent integration of max_hw calculation.

The addition of get_max_height_width import and its consistent usage across all test functions is well-structured. The calculation is properly placed before data processing in each test function.

Also applies to: 25-25, 115-115, 160-160, 247-247
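As a rough illustration of what a max-height/width pass computes, the sketch below takes the per-axis maximum over a set of frame shapes. It is a hypothetical stand-in for get_max_height_width: the real helper's inputs are not shown in this review, so the plain (height, width) tuple signature here is an assumption.

```python
from typing import Iterable, Tuple


def max_height_width(shapes: Iterable[Tuple[int, int]]) -> Tuple[int, int]:
    """Return (max_height, max_width) over an iterable of (height, width) shapes.

    Hypothetical stand-in: it takes plain shape tuples rather than a labels object.
    """
    max_h, max_w = 0, 0
    for h, w in shapes:
        max_h = max(max_h, h)
        max_w = max(max_w, w)
    return max_h, max_w


# Frames from videos with different resolutions (illustrative values).
print(max_height_width([(384, 384), (512, 256), (300, 640)]))  # (512, 640)
```

Note that the two maxima can come from different frames; the result describes a canvas that every frame fits into, which is why it has to be computed once before any per-frame processing.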


169-170: LGTM: Comprehensive test coverage for centroid streaming.

The test properly covers both scaling scenarios and includes augmentation testing. The structure is consistent with other tests and includes proper cleanup.

Also applies to: 202-209

sleap_nn/data/get_data_chunks.py (4)

7-7: LGTM: Import addition aligns with new image transformation requirements.

The addition of torchvision.transforms is appropriate for the new PIL image conversion functionality implemented across all data chunk functions.


161-163: Good: Explicit float32 conversion before crop generation.

The explicit conversion to float32 before crop generation helps prevent potential precision issues. The subsequent resizing and PIL conversion are properly applied to each crop.

Also applies to: 170-176


237-238: LGTM: Proper centroid coordinate scaling.

The implementation correctly scales centroid coordinates using both the effective scale from size matching and the additional scale parameter. The order of operations (scaling → resizing → PIL conversion) is maintained consistently.

Also applies to: 244-251


24-26: Verify the impact of double scaling on instance coordinates.

The implementation looks correct, but we should verify that the combination of eff_scale from apply_sizematcher and the additional scale parameter in apply_resizer maintains proper coordinate scaling for instances.

Let's verify the scaling implementation:

Also applies to: 63-83

✅ Verification successful

Scaling implementation is correct and properly tested

The implementation and tests show that:

  1. apply_resizer correctly scales both images and instance coordinates by the same factor
  2. Test case explicitly verifies that instance coordinates are scaled properly: assert torch.all(instances == ex["instances"] * 2.0)
  3. The scaling operations in apply_sizematcher and apply_resizer are applied in sequence, each transforming the image and the instance coordinates together, so the net coordinate scale is the product of the two factors and stays consistent with the image
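To make the combined scaling concrete, here is a minimal sketch. It assumes the size matcher shrinks each image uniformly to fit within (max_height, max_width), with padding on the bottom/right (which leaves coordinates unchanged), and that the resizer then applies the configured scale. size_match and resize are hypothetical stand-ins that act only on shapes and (x, y) points; they are not the sleap_nn apply_sizematcher/apply_resizer functions.

```python
from typing import List, Tuple

Point = Tuple[float, float]
HW = Tuple[int, int]


def size_match(hw: HW, max_hw: HW, points: List[Point]) -> Tuple[HW, List[Point], float]:
    """Shrink uniformly so the image fits in max_hw (assumed: padding goes bottom/right)."""
    h, w = hw
    max_h, max_w = max_hw
    eff_scale = min(max_h / h, max_w / w, 1.0)  # never upscale in this sketch
    new_hw = (round(h * eff_scale), round(w * eff_scale))
    return new_hw, [(x * eff_scale, y * eff_scale) for x, y in points], eff_scale


def resize(hw: HW, points: List[Point], scale: float) -> Tuple[HW, List[Point]]:
    """Apply the user-configured scale to both the image shape and the points."""
    h, w = hw
    return (round(h * scale), round(w * scale)), [(x * scale, y * scale) for x, y in points]


hw, pts, eff = size_match((1024, 512), (512, 512), [(100.0, 200.0)])
hw, pts = resize(hw, pts, 0.5)
print(hw, pts, eff * 0.5)  # (256, 128) [(25.0, 50.0)] 0.25
```

Because both steps update the image and the points together, the net factor on coordinates is eff_scale * scale (0.25 here), which is what the centroid scaling noted above relies on.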
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any scaling-related issues in tests or related files

# Look for test cases involving scaling
rg -A 5 "scale.*=.*" "tests/data/"

# Check for any reported scaling issues
rg -i "scale.*issue|coordinate.*mismatch" "tests/" "issues/"

Length of output: 5553


Script:

#!/bin/bash
# Let's check the implementation of apply_sizematcher and apply_resizer to verify scaling logic

# Check implementation of size matching and scaling
ast-grep --pattern 'def apply_sizematcher($$$)'

# Check implementation of resizer
ast-grep --pattern 'def apply_resizer($$$)'

# Look for any tests specifically verifying coordinate scaling
rg -B 2 -A 5 "assert.*instances.*scale|assert.*coordinate" "tests/"

Length of output: 163


Script:

#!/bin/bash
# Let's try a different approach to find the scaling implementation and tests

# Search for files containing scaling-related functions
rg -l "apply_sizematcher|apply_resizer" --type py

# Once we find the files, let's look at their content
rg -B 5 -A 10 "def apply_sizematcher|def apply_resizer" --type py

# Check test assertions related to scaling
rg -B 2 -A 5 "assert.*\*.*2\.0|assert.*scale" "tests/data/test_resizing.py"

Length of output: 2588

sleap_nn/data/streaming_datasets.py (4)

8-8: LGTM: Import changes are well-organized.

The new imports for torchvision.transforms and apply_normalization are appropriately placed and support the new image transformation pipeline.

Also applies to: 16-16


65-67: LGTM: Image preprocessing pipeline is well-structured.

The transformation sequence follows best practices:

  1. PIL to Tensor conversion
  2. Proper tensor formatting (batch dimension + float32)
  3. Augmentation (when enabled)
  4. Normalization
  5. Stride padding

Also applies to: 81-81
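The ordering above can be sketched with NumPy standing in for PIL and torch. This is an illustration under assumptions, not the streaming_datasets code: normalization is assumed to be a division by 255, augmentation is left as a placeholder comment, and stride padding is assumed to zero-pad the bottom/right edges up to the next multiple of max_stride.

```python
import numpy as np


def preprocess(image_hwc: np.ndarray, max_stride: int) -> np.ndarray:
    """uint8 (H, W, C) image -> float32 (1, C, H', W') padded to a multiple of max_stride."""
    # 1. Image -> tensor layout: HWC -> CHW (as torchvision's PILToTensor produces).
    x = np.transpose(image_hwc, (2, 0, 1))
    # 2. Add a batch dimension and cast to float32.
    x = x[np.newaxis].astype(np.float32)
    # 3. Augmentation would be applied here when enabled (omitted in this sketch).
    # 4. Normalize to [0, 1] (assumed normalization scheme).
    x = x / 255.0
    # 5. Zero-pad bottom/right so H and W are divisible by max_stride.
    _, _, h, w = x.shape
    pad_h = (-h) % max_stride
    pad_w = (-w) % max_stride
    return np.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)))


out = preprocess(np.full((30, 45, 3), 255, dtype=np.uint8), max_stride=16)
print(out.shape, out.dtype)  # (1, 3, 32, 48) float32
```

Padding last means the backbone always sees stride-divisible inputs while the normalized pixel values of the original region are untouched.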


247-249: LGTM: Consistent image preprocessing implementation.

The transformation pipeline matches the pattern established in other dataset classes, maintaining consistency across the codebase.

Also applies to: 263-263


Line range hint 1-354: Verify PR scope and description.

The changes in this file focus on standardizing image preprocessing across dataset classes, but the PR description mentions adding a threshold mode for the learning rate scheduler. This seems to be a mismatch between the PR objective and the actual changes.

Let's check if there are other files that implement the threshold mode:

sleap_nn/training/model_trainer.py (1)

Line range hint 624-630: LGTM: Learning rate scheduler threshold mode implementation

The addition of the threshold_mode parameter to the learning rate scheduler is well-implemented and aligns with the PR objectives.
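For reference, PyTorch documents ReduceLROnPlateau's threshold_mode as follows. In rel mode the comparison point is best * (1 - threshold) for mode='min' (best * (1 + threshold) for 'max'). In abs mode it is best - threshold (best + threshold for 'max'). The defaults are threshold=1e-4 and threshold_mode='rel'. The pure-Python function below mirrors that rule so the two modes can be compared without a training run; is_improvement is a hypothetical name, not a torch API.

```python
def is_improvement(current: float, best: float, mode: str = "min",
                   threshold_mode: str = "rel", threshold: float = 1e-4) -> bool:
    """Mirror of the improvement test documented for ReduceLROnPlateau."""
    if mode == "min" and threshold_mode == "rel":
        return current < best * (1.0 - threshold)
    if mode == "min" and threshold_mode == "abs":
        return current < best - threshold
    if mode == "max" and threshold_mode == "rel":
        return current > best * (1.0 + threshold)
    if mode == "max" and threshold_mode == "abs":
        return current > best + threshold
    raise ValueError(f"unsupported mode/threshold_mode: {mode!r}/{threshold_mode!r}")


# Small losses: a 5% drop counts in rel mode but not in abs mode.
print(is_improvement(0.0095, 0.01, threshold_mode="rel", threshold=1e-3))  # True
print(is_improvement(0.0095, 0.01, threshold_mode="abs", threshold=1e-3))  # False
```

Near a loss of 100 the verdicts flip (99.95 vs. 100.0 is an improvement in abs mode but not in rel mode). That dependence on the loss magnitude is one reason a user might want to choose the mode explicitly.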

sleap_nn/inference/predictors.py (3)

78-79: LGTM: Clean addition of max dimension parameters.

The new optional parameters max_height and max_width are properly added to the base configuration dictionary, maintaining backward compatibility.


265-279: LGTM: Well-structured image preprocessing pipeline.

The preprocessing pipeline is properly implemented with:

  1. Correct ordering of operations (resize -> color conversion -> normalization)
  2. Modern usage of torchvision.transforms.v2 for color space conversion
  3. Proper handling of effective scale for instance coordinates

427-427: LGTM: Clean implementation of anchor_ind and optional crop_hw.

The changes properly:

  1. Add optional anchor_ind with good fallback to config value
  2. Make crop_hw optional with appropriate fallback
  3. Maintain backward compatibility

Also applies to: 443-456
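The fallback pattern described above can be sketched as follows. CenteredInstanceConfig and resolve_options are hypothetical names, not the predictor's actual attributes. The checks use `is not None` rather than `or` so that an explicit anchor_ind=0 (a valid node index) is not silently replaced by the config value.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class CenteredInstanceConfig:
    """Hypothetical stand-in for the model's preprocessing config."""
    anchor_ind: Optional[int]
    crop_hw: Tuple[int, int]


def resolve_options(config: CenteredInstanceConfig,
                    anchor_ind: Optional[int] = None,
                    crop_hw: Optional[Tuple[int, int]] = None):
    """Explicit arguments win; otherwise fall back to the config values."""
    anchor = anchor_ind if anchor_ind is not None else config.anchor_ind
    hw = crop_hw if crop_hw is not None else config.crop_hw
    return anchor, hw


cfg = CenteredInstanceConfig(anchor_ind=2, crop_hw=(128, 128))
print(resolve_options(cfg, anchor_ind=0))  # (0, (128, 128))
```

Callers that pass neither argument get exactly the previous behavior, which is what keeps the change backward compatible.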

@gitttt-1234 force-pushed the divya/add-thresh-mode-option branch from 162e949 to 62e7406 October 30, 2024 20:09
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
sleap_nn/training/model_trainer.py (2)

Line range hint 206-206: Improve error handling in subprocess execution.

The current error handling could be more specific and provide better context for debugging.

Apply this improvement:

-            raise Exception(f"Error while creating the `.bin` files... {e}")
+            raise Exception("Failed to create .bin files: subprocess execution error") from e

Based on the gathered context, I can now provide a final response:

The commented-out cleanup code should be re-enabled with proper resource management

The Ubuntu test failure, where the test runs for over 6 hours, is related to the subprocess spawned by _create_data_loaders(). The issue isn't with the cleanup code itself but with the data chunk generation process, which uses multiprocessing workers. The cleanup code is actually necessary to prevent resource leaks.

Recommendations:

  • Re-enable the cleanup code as it's essential for proper resource management
  • The actual issue needs to be addressed in the data chunk generation process by:
    • Adding timeouts to the subprocess call
    • Limiting the number of workers in the test configuration
    • Reducing the test data size for this specific test
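One way to add the suggested timeout, sketched under assumptions. The actual command that model_trainer.py spawns is not shown in this review, so the args below are placeholders and run_with_timeout is a hypothetical helper. subprocess.run with timeout= kills the child and re-raises TimeoutExpired once the child has exited. Wrapping both failure modes with `raise ... from err` also addresses the B904 note.

```python
import subprocess
import sys
from typing import Sequence


def run_with_timeout(args: Sequence[str], timeout_s: float) -> subprocess.CompletedProcess:
    """Run a child process, failing fast instead of hanging a CI job."""
    try:
        return subprocess.run(
            list(args), check=True, capture_output=True, text=True, timeout=timeout_s
        )
    except subprocess.TimeoutExpired as err:
        # subprocess.run has already killed and reaped the child at this point.
        raise RuntimeError(f"Subprocess exceeded {timeout_s}s: {args[0]}") from err
    except subprocess.CalledProcessError as err:
        raise RuntimeError(
            f"Subprocess failed with exit code {err.returncode}: {err.stderr}"
        ) from err


# Placeholder command; the real data-chunk job would go here.
print(run_with_timeout([sys.executable, "-c", "print('ok')"], timeout_s=30).stdout)
```

In a test configuration, a timeout measured in minutes rather than the 6-hour CI limit would surface a hang as an ordinary test failure.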
🔗 Analysis chain

Line range hint 447-456: Address the TODO comment about Ubuntu test failure.

The commented-out cleanup code and TODO comment indicate an unresolved issue with the Ubuntu test that's running for over 6 hours.

Let's investigate the cleanup issue:

Would you like help implementing a more robust cleanup mechanism that won't cause test timeouts?


Let me gather more context about the TODO comment and test failures.


Let me check the specific code around the TODO comment.


Let me check the test file to understand the test configuration and the TODO comment better.


Let me check the test file for the specific test that's being skipped and its configuration.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential file handle leaks or long-running processes

# Test: Look for any file operations or long-running processes in the codebase
rg -A 5 "shutil\.(rmtree|copy|move)" 
rg -A 5 "subprocess\.(Popen|run|call)"

Length of output: 11815


Script:

#!/bin/bash
# Look for the TODO comment about Ubuntu test failure
rg -B 5 -A 5 "TODO.*Ubuntu" 

# Check the test file for any timing related configurations
rg -l "test_model_trainer" | xargs cat

# Look for any process spawning or multiprocessing code
ast-grep --pattern 'subprocess.Popen($$$)'
ast-grep --pattern 'multiprocessing.$$$'

Length of output: 25812


Script:

#!/bin/bash
# Look for the code context around line 447-456 in model_trainer.py
rg -B 10 -A 10 "TODO.*Revisit this test later" sleap_nn/training/model_trainer.py

# Check for any multiprocessing configurations in the test
rg -A 5 "num_workers|multiprocessing" sleap_nn/training/model_trainer.py

# Look for the get_bin_files module that's spawned as a subprocess
fd -e py get_bin_files | xargs cat

Length of output: 6666


Script:

#!/bin/bash
# Look for the test file and the specific test that's being skipped
rg -B 5 -A 5 "@pytest.mark.skipif.*Ubuntu" tests/

# Check for any test configurations related to timeouts
rg -A 5 "timeout|duration" tests/

# Look for any cleanup related code in the test
rg -A 5 "cleanup|rmtree" tests/training/test_model_trainer.py

Length of output: 4439


Script:

#!/bin/bash
# Find the specific test file and the skipped test
rg -B 20 -A 20 "reason=\"Flaky test.*Ubuntu" tests/

# Check the model trainer's data loading and subprocess code
rg -B 10 -A 10 "_create_data_loaders" sleap_nn/training/model_trainer.py

# Look for any pytest configuration files
fd -e ini -e cfg -e conf pytest | xargs cat

Length of output: 4734

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 162e949 and 585bc0e.

📒 Files selected for processing (5)
  • docs/config_bottomup.yaml (1 hunks)
  • docs/config_centroid.yaml (1 hunks)
  • docs/config_topdown_centered_instance.yaml (1 hunks)
  • sleap_nn/training/model_trainer.py (1 hunks)
  • tests/fixtures/datasets.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • docs/config_bottomup.yaml
  • docs/config_centroid.yaml
  • docs/config_topdown_centered_instance.yaml
  • tests/fixtures/datasets.py
🔇 Additional comments (1)
sleap_nn/training/model_trainer.py (1)

624-624: LGTM: Learning rate scheduler threshold mode implementation.

The addition of threshold_mode parameter to the ReduceLROnPlateau scheduler is correctly implemented, allowing for both absolute and relative threshold modes as specified in the PR objectives.
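Here is a minimal sketch of how the new option might be validated and forwarded. It assumes the trainer reads trainer_config.lr_scheduler as a plain dict. The key names other than threshold_mode, the "rel" fallback, and the hard-coded mode="min" are assumptions. The forwarded names are real ReduceLROnPlateau keyword arguments, and reduce_lr_kwargs is a hypothetical helper.

```python
from typing import Any, Dict

VALID_THRESHOLD_MODES = ("abs", "rel")


def reduce_lr_kwargs(lr_scheduler_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Translate an lr_scheduler config section into ReduceLROnPlateau kwargs."""
    threshold_mode = lr_scheduler_cfg.get("threshold_mode", "rel")  # torch's default
    if threshold_mode not in VALID_THRESHOLD_MODES:
        raise ValueError(
            f"threshold_mode must be one of {VALID_THRESHOLD_MODES}, got {threshold_mode!r}"
        )
    # Assumption: the monitored metric is a loss, so lower is better.
    kwargs: Dict[str, Any] = {"mode": "min", "threshold_mode": threshold_mode}
    # Forward only keys the user actually set; torch supplies defaults for the rest.
    for key in ("threshold", "patience", "factor", "cooldown", "min_lr"):
        if key in lr_scheduler_cfg:
            kwargs[key] = lr_scheduler_cfg[key]
    return kwargs


print(reduce_lr_kwargs({"threshold_mode": "abs", "threshold": 1e-4, "patience": 5}))
```

The result would then be passed as torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **reduce_lr_kwargs(cfg)). Validating early surfaces a typo such as "relative" when the config is loaded rather than when the scheduler is constructed.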

@gitttt-1234 merged commit 00da7af into main Oct 30, 2024
7 checks passed
@gitttt-1234 deleted the divya/add-thresh-mode-option branch October 30, 2024 21:27
@coderabbitai bot mentioned this pull request Dec 5, 2024
@coderabbitai bot mentioned this pull request Mar 13, 2025
3 participants