
Conversation


@gqcpm gqcpm commented Sep 27, 2024

This is the first PR of #75. Here we want to:

  • Migrate over the attrs config classes from sleap.nn.config, starting with TrainingJobConfig and moving down the hierarchy. These get migrated to a submodule under sleap-nn/sleap_nn/config.
  • Update class definitions to new attrs API.
  • Replace cattr serialization with OmegaConf.
  • Replace the functionality of the oneof decorator with OmegaConf-based routines if possible.
  • Implement cross-field validation for linked attributes.
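For illustration, the exclusivity behavior that the oneof decorator provides could be sketched in plain Python as follows. This is a hypothetical stand-in, not the actual implementation in sleap_nn/config/utils.py, and the BackboneConfig class and its attribute names are invented for the example:

```python
def oneof(cls):
    """Class decorator (sketch): enforce that at most one attribute is set.

    After the wrapped __init__ runs, count non-None attributes; more than
    one set at a time is an error. Records the active one in which_oneof.
    """
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        set_fields = [k for k, v in vars(self).items() if v is not None]
        if len(set_fields) > 1:
            raise ValueError(f"Only one of {set_fields} may be set at a time.")
        self.which_oneof = set_fields[0] if set_fields else None

    cls.__init__ = __init__
    return cls


@oneof
class BackboneConfig:  # hypothetical example class
    def __init__(self, unet=None, convnext=None, swint=None):
        self.unet = unet
        self.convnext = convnext
        self.swint = swint
```

Constructing `BackboneConfig(unet={"filters": 16})` records `which_oneof == "unet"`, while passing two backbones at once raises a ValueError.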

Summary by CodeRabbit

Based on the comprehensive changes, here are the updated release notes:

  • New Features

    • Introduced advanced configuration management for machine learning workflows.
    • Added comprehensive configuration classes for data, model, and training parameters.
    • Implemented robust serialization and validation for configuration settings.
    • Added a utility function to enforce exclusive attribute settings in configuration classes.
    • Introduced a new TrainingJobConfig class for managing training job parameters.
  • Improvements

    • Enhanced configuration flexibility with support for multiple backbone architectures.
    • Added detailed validation for training and augmentation parameters.
    • Improved documentation for configuration classes.
    • Implemented utility functions for configuration handling.
  • Testing

    • Expanded test coverage for configuration classes.
    • Added comprehensive unit tests for configuration validation and initialization, including YAML handling.

These changes provide a more structured and flexible approach to configuring machine learning experiments in the SLEAP neural network library.


coderabbitai bot commented Sep 27, 2024

Walkthrough

This pull request introduces a comprehensive configuration management system for machine learning workflows in the SLEAP-NN project. The changes involve creating multiple configuration classes across several files in the sleap_nn/config directory, focusing on data, model, training, and job configurations. These configurations utilize the attrs library for structured attribute management and validation, providing a robust and extensible approach to specifying training parameters. Additionally, extensive unit tests have been added to ensure the functionality and integrity of these configurations.

Changes

File Change Summary
sleap_nn/config/data_config.py Added configuration classes for data management, including DataConfig, PreprocessingConfig, AugmentationConfig, IntensityConfig, and GeometricConfig.
sleap_nn/config/model_config.py Introduced configuration classes for model architectures, including UNetConfig, ConvNextConfig, SwinTConfig, and various head configurations with enhanced documentation.
sleap_nn/config/trainer_config.py Enhanced configuration classes for training parameters, including DataLoaderConfig, ModelCkptConfig, WandBConfig, OptimizerConfig, LRSchedulerConfig, EarlyStoppingConfig, and TrainerConfig, with comprehensive docstrings.
sleap_nn/config/training_job_config.py Added TrainingJobConfig for managing complete training job configurations with YAML serialization support.
sleap_nn/config/utils.py Introduced oneof decorator for ensuring exclusive attribute settings, along with utility methods for attribute management.
tests/config/test_trainer_config.py Added unit tests for trainer configuration classes, validating default values and customization options.
tests/config/test_data_config.py Added unit tests for data configuration classes, focusing on initialization and validation logic.
tests/config/test_model_config.py Introduced tests for model configuration classes, verifying initialization and error handling.
tests/config/test_training_job_config.py Added tests for the TrainingJobConfig class, ensuring proper YAML handling and validation.
tests/config/test_config_utils.py Introduced tests for the oneof utility function, validating exclusive attribute settings.

Possibly related issues

  • Implement a schema for configs #75: Implement a schema for configs - This PR directly addresses the need for structured configuration management with validation and serialization support.

Possibly related PRs

  • Refactor Augmentation config #67: Both PRs modify the augmentation_config structure and its parameters.
  • Fix augmentation in TopdownConfmaps pipeline #78: Both PRs change how augmentation configurations are handled, including the split between intensity and geometric augmentations.
  • Add minimal pretrained checkpoints for tests and fix PAF grouping interpolation #73: Both PRs restructure data_config parameters, including train_labels_path, val_labels_path, and augmentation_config.

Suggested reviewers

  • talmo

Poem

🐰 Configuration Rabbit's Delight
In SLEAP-NN's config garden bright,
Attrs dance with validation's might,
Classes nested, configs clear,
Machine learning's path now sincere!
Hop, hop, config rabbit goes! 🌟



codecov bot commented Sep 27, 2024

Codecov Report

Attention: Patch coverage is 95.12894% with 17 lines in your changes missing coverage. Please review.

Project coverage is 97.08%. Comparing base (f093ce2) to head (c448afa).
Report is 45 commits behind head on main.

Files with missing lines Patch % Lines
sleap_nn/config/utils.py 76.66% 7 Missing ⚠️
sleap_nn/config/model_config.py 96.69% 4 Missing ⚠️
sleap_nn/config/trainer_config.py 96.87% 3 Missing ⚠️
sleap_nn/config/training_job_config.py 95.00% 2 Missing ⚠️
sleap_nn/config/data_config.py 98.38% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #97      +/-   ##
==========================================
+ Coverage   96.64%   97.08%   +0.43%     
==========================================
  Files          23       45      +22     
  Lines        1818     4489    +2671     
==========================================
+ Hits         1757     4358    +2601     
- Misses         61      131      +70     



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (2)

1-2: Consider adding explicit attribute definitions.

The ModelConfig class is correctly using the @attr.s decorator with auto_attribs=True. However, the class body is empty, which is unusual. Consider adding explicit attribute definitions to improve code clarity and enable IDE autocompletion.

Here's an example of how you could define the attributes:

import attr

@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str

This will make the class structure more explicit and easier to understand at a glance.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


3-9: Docstring looks good, minor formatting suggestion.

The docstring provides clear and informative descriptions of the class and its attributes. Well done!

Consider adding a period at the end of the last line for consistency:

-        base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file
+        base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file.
sleap_nn/config/training_job.py (1)

1-42: Overall assessment and recommendations

The TrainingJobConfig class provides a well-structured and documented foundation for managing training job configurations. However, there are a few improvements needed to make it complete and error-free:

  1. Add the missing imports for attr, DataConfig, and other required classes.
  2. Implement the remaining attributes mentioned in the class docstring.
  3. Consider adding type hints for all attributes to improve code readability and catch potential type-related issues early.
  4. If not already present in your project, consider adding a requirements.txt or setup.py file to specify the attrs library as a dependency.

Once these changes are made, the TrainingJobConfig class will be a robust and well-documented configuration management solution for your training jobs.

🧰 Tools
🪛 Ruff

27-27: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c3be90b and 72b9d7b.

📒 Files selected for processing (3)
  • sleap_nn/config/data_config.py (1 hunks)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

1-1: Undefined name attr

(F821)


11-11: Undefined name LabelsConfig

(F821)


11-11: Undefined name attr

(F821)


11-11: Undefined name LabelsConfig

(F821)


12-12: Undefined name PreprocessingConfig

(F821)


12-12: Undefined name attr

(F821)


12-12: Undefined name PreprocessingConfig

(F821)


13-13: Undefined name InstanceCroppingConfig

(F821)


13-13: Undefined name attr

(F821)


13-13: Undefined name InstanceCroppingConfig

(F821)

sleap_nn/config/model_config.py

1-1: Undefined name attr

(F821)

sleap_nn/config/training_job.py

27-27: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)


42-42: Undefined name attr

(F821)


42-42: Undefined name DataConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (2)

3-9: Well-structured and informative docstring

The class docstring is well-written, providing a clear description of the class purpose and each of its attributes. It follows good documentation practices, which will help other developers understand the purpose and usage of this configuration class.


1-13: Overall well-structured and maintainable code

The DataConfig class is well-designed and follows good coding practices. It uses the attr library effectively for attribute management, has clear and descriptive naming, and includes informative docstrings. The suggested improvements (adding missing imports) are minor and easy to implement. Once these changes are made, the code will be more robust and easier to maintain.


sleap_nn/config/training_job.py (1)

1-25: Excellent documentation!

The file-level docstring provides a clear and comprehensive explanation of the purpose and design principles behind the configuration classes. It effectively communicates the separation of configuration from implementation and highlights the benefits of this approach, including easy serialization and human-readability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)

1-3: Remove unused import

The List type from typing is imported but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


38-51: Approve IntensityConfig class and suggest comment update

The IntensityConfig class is well-structured with appropriate attributes for intensity-related augmentation options. However, there's a comment indicating that this part is still under review.

Once the review of this section is complete, consider removing or updating the following comment:

# still looking over this part and below (function to check if is greater than 0)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 72b9d7b and 80e5915.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


18-18: Undefined name PreprocessingConfig

(F821)


18-18: Undefined name PreprocessingConfig

(F821)


20-20: Undefined name AugmentationConfig

(F821)


28-28: Undefined name Union

(F821)


35-35: Undefined name IntensityConfig

(F821)


36-36: Undefined name GeometricConfig

(F821)

🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)

53-67: Approve GeometricConfig class

The GeometricConfig class is well-structured with appropriate attributes for geometric augmentation options. The default values seem reasonable, and the use of Optional types where appropriate is good practice.


1-67: Summary of review

Overall, the configuration classes in this file are well-structured and provide a comprehensive set of options for data handling, preprocessing, and augmentation in the SLEAP-NN project. The use of attr.s with auto_attribs=True is a good practice for creating classes with less boilerplate code.

Main points to address:

  1. Remove the unused List import.
  2. Fix circular import issues by using string annotations for types within the same file.
  3. Import the Union type from typing.
  4. Update or remove the comment in the IntensityConfig class once the review is complete.

After addressing these points, the file will be in excellent shape for use in the project.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)

1-3: Remove unused import

The List type is imported but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 80e5915 and 186be74.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


18-18: Undefined name PreprocessingConfig

(F821)


18-18: Undefined name PreprocessingConfig

(F821)


20-20: Undefined name AugmentationConfig

(F821)


28-28: Undefined name Union

(F821)


35-35: Undefined name IntensityConfig

(F821)


36-36: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

38-61: Approve IntensityConfig class

The IntensityConfig class is well-structured with appropriate attributes for intensity augmentation options. The inclusion of validator methods for uniform_noise_min and uniform_noise_max is a good practice to ensure valid values.

The class implementation looks good and provides a comprehensive set of options for intensity-based augmentations.


63-77: Approve GeometricConfig class

The GeometricConfig class is well-structured with appropriate attributes for geometric augmentation options. The attribute names are clear and descriptive, and the types are suitable for their intended purposes.

The class implementation provides a comprehensive set of options for geometric augmentations, which should cover most common use cases in image augmentation pipelines.


1-77: Overall assessment: Well-structured configuration classes with minor improvements needed

The sleap_nn/config/data_config.py file introduces a set of well-structured configuration classes for data handling, preprocessing, and augmentation. The use of the attr library and type hints enhances code readability and maintainability.

Key points:

  1. The classes provide comprehensive options for various aspects of data processing and augmentation.
  2. The code follows good practices in terms of class structure and attribute definitions.
  3. Validator methods in the IntensityConfig class add an extra layer of safety.

Minor improvements suggested:

  1. Remove unused import (List).
  2. Add missing import (Union).
  3. Address potential circular import issues with string annotations.

Overall, the file provides a solid foundation for configuring data-related operations in the project. After addressing the minor issues, it will be ready for use.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/data_config.py (1)

1-3: Remove unused import

The List type is imported from typing but not used in the file. Consider removing it to keep the imports clean.

Apply this diff to remove the unused import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 186be74 and 796788d.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


25-25: Undefined name PreprocessingConfig

(F821)


25-25: Undefined name PreprocessingConfig

(F821)


27-27: Undefined name AugmentationConfig

(F821)


46-46: Undefined name Union

(F821)


61-61: Undefined name IntensityConfig

(F821)


62-62: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

64-103: LGTM: IntensityConfig class is well-structured

The IntensityConfig class is well-defined with appropriate attributes and validators for uniform noise parameters. The use of attr.s decorator with auto_attribs=True is a good practice for creating classes with less boilerplate code.


105-138: LGTM: GeometricConfig class is well-structured

The GeometricConfig class is well-defined with appropriate attributes for geometric augmentations. The use of attr.s decorator with auto_attribs=True is consistent with the other classes in this file.


1-138: Overall: Well-structured configuration classes for data handling

This file introduces a set of well-organized configuration classes for managing various aspects of data handling in machine learning workflows. The use of attr.s decorators with auto_attribs=True promotes clean and maintainable code. The classes cover essential aspects such as data sources, preprocessing, and augmentation options.

A few minor improvements have been suggested regarding imports and forward references. Once these are addressed, the file will provide a robust foundation for configuring data-related parameters in the project.


@gqcpm changed the title from "Greg/omegaconf basic func" to "Implement Omegaconfig PR1: basic functionality" on Oct 9, 2024

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/data_config.py (2)

1-3: Optimize imports

Please make the following changes to the import statements:

  1. Remove the unused import of List.
  2. Add Union to the imports from typing.

Apply this diff to optimize the imports:

 import attrs
 from omegaconf import MISSING
-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict, Union

This change will remove the unused import and add the necessary Union type.

🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


94-127: Approve GeometricConfig class and fix typo

The GeometricConfig class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute.

There's a small typo in the docstring for the erase_ratio_min attribute. Please fix it as follows:

-        erase_ration_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.
+        erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. Default: 1.

This change will correct the attribute name in the docstring.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 796788d and 34548b8.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


25-25: Undefined name PreprocessingConfig

(F821)


25-25: Undefined name PreprocessingConfig

(F821)


27-27: Undefined name AugmentationConfig

(F821)


46-46: Undefined name Union

(F821)


61-61: Undefined name IntensityConfig

(F821)


62-62: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

30-48: Approve PreprocessingConfig class and remind about Union import

The PreprocessingConfig class is well-structured with appropriate attributes and default values. The docstring provides clear explanations for each attribute.

Remember to add the Union import as suggested in the previous comment about optimizing imports.

🧰 Tools
🪛 Ruff

46-46: Undefined name Union

(F821)


64-92: Approve IntensityConfig class

The IntensityConfig class is well-structured with appropriate attributes, default values, and validators. The docstring provides clear explanations for each attribute. The use of attrs.field with validators for uniform_noise_min, uniform_noise_max, contrast_min, and contrast_max is a good practice to ensure valid input.


1-127: Overall assessment: Well-structured configuration classes with minor improvements needed

The data_config.py file introduces a set of well-structured configuration classes for managing data parameters in machine learning workflows. The use of the attrs library and detailed docstrings contributes to clean and self-documenting code.

Main points of improvement:

  1. Optimize imports by removing unused ones and adding missing ones.
  2. Use forward references to resolve potential circular import issues.
  3. Fix a minor typo in the GeometricConfig class docstring.

Once these changes are implemented, the file will be in excellent shape, providing a robust foundation for configuring data-related parameters in your project.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (4)
sleap_nn/config/data_config.py (4)

1-3: Remove unused import

The List type is imported from typing but not used in the file. To keep the imports clean, you can remove it.

Apply this change:

- from typing import Optional, Tuple, List, Dict
+ from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


12-35: Approve DataConfig class and suggest forward references

The DataConfig class is well-structured with clear attribute definitions and a comprehensive docstring. However, there are undefined names for PreprocessingConfig and AugmentationConfig.

To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from sleap_nn.config.data_config import PreprocessingConfig, AugmentationConfig

 @attrs.define
 class DataConfig:
     # ... (existing code)
-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

33-33: Undefined name PreprocessingConfig

(F821)


33-33: Undefined name PreprocessingConfig

(F821)


35-35: Undefined name AugmentationConfig

(F821)


38-56: Approve PreprocessingConfig class and add missing import

The PreprocessingConfig class is well-structured with appropriate attributes for preprocessing options. However, the Union type is used but not imported.

Add the following import at the top of the file:

- from typing import Optional, Tuple, Dict
+ from typing import Optional, Tuple, Dict, Union

This will resolve the undefined Union type issue.

🧰 Tools
🪛 Ruff

54-54: Undefined name Union

(F821)


58-70: Approve AugmentationConfig class and use forward references

The AugmentationConfig class is well-structured, but there are undefined names for IntensityConfig and GeometricConfig.

To resolve the undefined names and prevent potential circular imports, use forward references for these classes. Apply this diff:

 @attrs.define
 class AugmentationConfig:
     # ... (existing code)
     random_crop: Optional[Dict[str, Optional[float]]] = None
-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

69-69: Undefined name IntensityConfig

(F821)


70-70: Undefined name GeometricConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 34548b8 and 7ea94b6.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


33-33: Undefined name PreprocessingConfig

(F821)


33-33: Undefined name PreprocessingConfig

(F821)


35-35: Undefined name AugmentationConfig

(F821)


54-54: Undefined name Union

(F821)


69-69: Undefined name IntensityConfig

(F821)


70-70: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

72-100: Approve IntensityConfig class

The IntensityConfig class is well-structured with clear attribute definitions, comprehensive docstring, and appropriate use of validators for certain attributes. This ensures data integrity and provides clear documentation for users.

The implementation looks good and follows best practices.


102-135: Approve GeometricConfig class

The GeometricConfig class is well-structured with clear attribute definitions, comprehensive docstring, and appropriate default values for all attributes. This provides a good starting point for users and clear documentation of the available options.

The implementation looks good and follows best practices.


1-135: Final approval: Well-structured configuration classes with minor improvements needed

Overall, the implementation of the configuration classes in sleap_nn/config/data_config.py is well-done. The classes are clearly structured, with comprehensive docstrings and appropriate use of the attrs library. The suggested improvements mainly revolve around import statements and using forward references to prevent circular imports.

Key points:

  1. Remove the unused List import.
  2. Add the missing Union import.
  3. Use forward references for PreprocessingConfig, AugmentationConfig, IntensityConfig, and GeometricConfig to resolve undefined names and prevent potential circular imports.

After implementing these minor changes, the code will be more robust and maintainable.

🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


33-33: Undefined name PreprocessingConfig

(F821)


33-33: Undefined name PreprocessingConfig

(F821)


35-35: Undefined name AugmentationConfig

(F821)


54-54: Undefined name Union

(F821)


69-69: Undefined name IntensityConfig

(F821)


70-70: Undefined name GeometricConfig

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/data_config.py (3)

1-3: Approve imports with minor suggestion

The imports look good overall. However, there's an unused import that can be removed.

Remove the unused List import:

-from typing import Optional, Tuple, List, Dict
+from typing import Optional, Tuple, Dict
🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


13-36: Approve DataConfig class and suggest forward references

The DataConfig class is well-structured with clear documentation. However, there are undefined names for PreprocessingConfig and AugmentationConfig.

To avoid potential circular imports, use forward references for these classes. Add the following import at the top of the file and modify the class attributes:

from __future__ import annotations

Then update the attributes:

-    preprocessing: PreprocessingConfig = attrs.field(factory=PreprocessingConfig)
-    augmentation_config: Optional[AugmentationConfig] = None
+    preprocessing: 'PreprocessingConfig' = attrs.field(factory=lambda: PreprocessingConfig())
+    augmentation_config: Optional['AugmentationConfig'] = None

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

34-34: Undefined name PreprocessingConfig

(F821)


34-34: Undefined name PreprocessingConfig

(F821)


36-36: Undefined name AugmentationConfig

(F821)


60-72: Approve AugmentationConfig class and suggest forward references

The AugmentationConfig class is well-structured with clear documentation. However, there are undefined names for IntensityConfig and GeometricConfig.

To avoid potential circular imports, use forward references for these classes. Modify the class attributes as follows:

-    intensity: Optional[IntensityConfig] = attrs.field(default=None)
-    geometric: Optional[GeometricConfig] = attrs.field(default=None)
+    intensity: Optional['IntensityConfig'] = attrs.field(default=None)
+    geometric: Optional['GeometricConfig'] = attrs.field(default=None)

This change will resolve the undefined names issue and prevent potential circular imports.

🧰 Tools
🪛 Ruff

71-71: Undefined name IntensityConfig

(F821)


72-72: Undefined name GeometricConfig

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 7ea94b6 and 192399d.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/data_config.py

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


34-34: Undefined name PreprocessingConfig

(F821)


34-34: Undefined name PreprocessingConfig

(F821)


36-36: Undefined name AugmentationConfig

(F821)


55-55: Undefined name Union

(F821)


71-71: Undefined name IntensityConfig

(F821)


72-72: Undefined name GeometricConfig

(F821)

🔇 Additional comments (3)
sleap_nn/config/data_config.py (3)

75-107: Approve IntensityConfig class

The IntensityConfig class is well-structured with clear documentation and appropriate use of attrs validators. All attributes have sensible default values, and the validators help ensure valid input.

The implementation looks good, and no changes are necessary.


110-143: Approve GeometricConfig class

The GeometricConfig class is well-structured with clear and detailed documentation for each attribute. The use of Optional types and default values is appropriate, providing flexibility in configuration.

The implementation looks good, and no changes are necessary.


1-143: Summary and final recommendation

Overall, the data_config.py file is well-structured and provides a comprehensive set of configuration classes for data handling in a machine learning pipeline. The use of the attrs library and detailed documentation for each class and attribute is commendable.

To improve the code quality and prevent potential issues, please implement the following changes:

  1. Remove the unused List import.
  2. Add from __future__ import annotations at the top of the file.
  3. Use forward references for PreprocessingConfig, AugmentationConfig, IntensityConfig, and GeometricConfig in the respective class attributes.
  4. Add from typing import Union to resolve the undefined Union type.

After implementing these changes, the file will be more robust and less prone to circular import issues.

The overall implementation is good, and with these minor adjustments, it will be excellent.

🧰 Tools
🪛 Ruff

3-3: typing.List imported but unused

Remove unused import: typing.List

(F401)


34-34: Undefined name PreprocessingConfig

(F821)


34-34: Undefined name PreprocessingConfig

(F821)


36-36: Undefined name AugmentationConfig

(F821)


55-55: Undefined name Union

(F821)


71-71: Undefined name IntensityConfig

(F821)


72-72: Undefined name GeometricConfig

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)

1-9: Consider adding type hints to class attributes.

While the attributes are well-described in the docstring, adding type hints would improve code clarity and enable better static type checking. Consider updating the class definition to include type information for each attribute. For example:

@attr.s(auto_attribs=True)
class ModelConfig:
    backbone: dict
    heads: dict
    base_checkpoint: str

Replace dict and str with more specific types if applicable. This change will enhance code readability and make it easier for developers to understand the expected types of these configuration parameters.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 192399d and 184d5ea.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

1-1: Undefined name attr

(F821)

🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)

1-9: LGTM: Well-structured class definition and docstring.

The ModelConfig class is well-defined using the @attr.s(auto_attribs=True) decorator, which is appropriate for automatic attribute management. The docstring provides a clear description of the class purpose and its attributes, following the Google style guide format. This promotes good code documentation and maintainability.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


1-9: Overall, good implementation of the ModelConfig class.

The ModelConfig class aligns well with the PR objectives of migrating configuration classes and updating them to use the new attrs API. The implementation is clean and well-documented. With the suggested improvements (adding the import statement and considering type hints), this class will provide a solid foundation for model configuration in the project.

🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)


1-1: ⚠️ Potential issue

Add missing import statement for attr.

The attr module is used in the code but not imported. To resolve the undefined name issue and make the usage of the @attr.s decorator valid, please add the following import statement at the beginning of the file:

import attr
🧰 Tools
🪛 Ruff

1-1: Undefined name attr

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (1)
sleap_nn/config/model_config.py (1)

23-23: Consider using consistent decorators across the file.

The UNetConfig class uses @attrs.define, while the main ModelConfig class uses @attr.s. For consistency, consider using the same decorator style throughout the file. If you're using a newer version of attrs, you might want to update all decorators to @attrs.define. Otherwise, change this to @attr.s(auto_attribs=True) to match the main class.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 184d5ea and c79ea42.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

13-14: SyntaxError: Expected an expression


15-16: SyntaxError: Expected an expression

🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

1-59: Overall structure looks good, with some minor improvements needed.

The ModelConfig class and its nested configuration classes provide a comprehensive and flexible structure for configuring different model architectures. The use of attrs for class definitions is a good choice for reducing boilerplate code.

To improve the file:

  1. Add the missing imports for attr and Enum.
  2. Complete the attribute definitions for pre_trained_weights and backbone_config.
  3. Consider using consistent decorators across all class definitions.

These changes will enhance the code's correctness and consistency.

🧰 Tools
🪛 Ruff

13-14: SyntaxError: Expected an expression


15-16: SyntaxError: Expected an expression


1-1: ⚠️ Potential issue

Add missing import for attr module.

The attr module is used in this file but not imported. Add the following import at the beginning of the file:

import attr

This will resolve the undefined name issue for attr.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)

20-46: LGTM! Consider refactoring error messages.

The methods in the ModelConfig class are well-implemented and provide good initialization and validation logic.

To reduce code duplication in error messages, consider creating a helper method for generating error messages:

def _get_weight_error_message(self, backbone_type, valid_weights):
    return f"Invalid pre-trained weights for {backbone_type}. Must be one of {valid_weights}"

# Then use it in validate_pre_trained_weights:
if self.backbone_type == BackboneType.CONVNEXT:
    if self.pre_trained_weights not in convnext_weights:
        raise ValueError(self._get_weight_error_message("ConvNext", convnext_weights))
elif self.backbone_type == BackboneType.SWINT:
    if self.pre_trained_weights not in swint_weights:
        raise ValueError(self._get_weight_error_message("SwinT", swint_weights))

This refactoring will make the code more maintainable and reduce the risk of inconsistencies in error messages.

🧰 Tools
🪛 Ruff

25-25: Undefined name BackboneType

(F821)


26-26: Undefined name UNetConfig

(F821)


27-27: Undefined name BackboneType

(F821)


28-28: Undefined name ConvNextConfig

(F821)


29-29: Undefined name BackboneType

(F821)


30-30: Undefined name SwinTConfig

(F821)


38-38: Undefined name BackboneType

(F821)


41-41: Undefined name BackboneType

(F821)


44-44: Undefined name BackboneType

(F821)


53-65: LGTM! Consider adding validation for None values.

The UNetConfig class is well-defined with appropriate attributes for UNet configuration.

Consider adding validation for attributes that have None as default value, such as max_stride and stem_stride. You could do this in a post-init method:

@attrs.define
class UNetConfig:
    # ... existing attributes ...

    def __attrs_post_init__(self):
        if self.max_stride is None:
            # Set a default value or raise an error if it's required
            self.max_stride = 16  # Example default value
        if self.stem_stride is None:
            # Set a default value or raise an error if it's required
            self.stem_stride = 1  # Example default value

This ensures that these critical parameters always have valid values.


66-89: LGTM! Consider reordering attributes for consistency.

The ConvNextConfig and SwinTConfig classes are well-defined with appropriate attributes for their respective architectures.

For better consistency between the two classes, consider reordering the attributes in SwinTConfig to match the order in ConvNextConfig as closely as possible. For example:

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"
    arch: dict = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    in_channels: int = 1
    patch_size: list = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: list = attrs.field(factory=lambda: [7, 7])
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

This reordering makes it easier to compare the two configurations at a glance.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c79ea42 and d866e36.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)


25-25: Undefined name BackboneType

(F821)


26-26: Undefined name UNetConfig

(F821)


27-27: Undefined name BackboneType

(F821)


28-28: Undefined name ConvNextConfig

(F821)


29-29: Undefined name BackboneType

(F821)


30-30: Undefined name SwinTConfig

(F821)


38-38: Undefined name BackboneType

(F821)


41-41: Undefined name BackboneType

(F821)


44-44: Undefined name BackboneType

(F821)

🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)

47-50: LGTM! Well-defined enum for backbone types.

The BackboneType enum is correctly defined and provides clear options for the backbone types. This approach enhances type safety and code readability.
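For concreteness, an enum of this shape might look like the following sketch (member values are assumptions, not copied from `sleap_nn/config/model_config.py`):

```python
from enum import Enum


class BackboneType(Enum):
    # Illustrative values; the real enum in model_config.py defines its own.
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"


# Enum membership gives type-safe dispatch over backbone choices:
# a typo like BackboneType("unett") raises ValueError immediately.
assert BackboneType("unet") is BackboneType.UNET
```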


1-89: Overall, excellent implementation with minor suggestions for improvement.

The ModelConfig class and its associated nested classes provide a comprehensive and well-structured configuration system for model architectures. The use of attrs for class definitions and Enum for backbone types demonstrates good coding practices.

Key strengths:

  1. Clear separation of concerns for different backbone types.
  2. Proper use of post-initialization and validation methods.
  3. Effective use of default values and factory functions for complex defaults.

Suggestions for improvement:

  1. Add import for Union from typing.
  2. Refactor error message generation in validate_pre_trained_weights.
  3. Add validation for None values in UNetConfig.
  4. Reorder attributes in SwinTConfig for consistency with ConvNextConfig.

These minor improvements will enhance the overall quality and maintainability of the code.

🧰 Tools
🪛 Ruff

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)


25-25: Undefined name BackboneType

(F821)


26-26: Undefined name UNetConfig

(F821)


27-27: Undefined name BackboneType

(F821)


28-28: Undefined name ConvNextConfig

(F821)


29-29: Undefined name BackboneType

(F821)


30-30: Undefined name SwinTConfig

(F821)


38-38: Undefined name BackboneType

(F821)


41-41: Undefined name BackboneType

(F821)


44-44: Undefined name BackboneType

(F821)


1-18: LGTM! Consider adding import for Union.

The imports and main class definition look good. The Enum import addresses a previous comment. However, there's a minor improvement we can make:

Consider adding the following import at the beginning of the file:

from typing import Union

This will resolve the undefined name issue for Union on line 18.

There might be a circular import issue causing some undefined names. Let's verify this:

If this script returns results, it indicates a potential circular import that needs to be addressed.

🧰 Tools
🪛 Ruff

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (2)
sleap_nn/config/model_config.py (2)

55-67: Consider adding type hints to UNetConfig attributes.

To improve code clarity and maintainability, consider adding type hints to the attributes in the UNetConfig class. For example:

@attrs.define
class UNetConfig:
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: Optional[int] = None
    stem_stride: Optional[int] = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

Don't forget to import Optional from typing if you use it for max_stride and stem_stride.


68-91: Enhance ConvNextConfig and SwinTConfig with type hints and consider using attrs.Factory.

  1. Add type hints to improve code clarity:
@attrs.define
class ConvNextConfig:
    model_type: str = "tiny"
    arch: Dict[str, List[int]] = attrs.field(factory=lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]})
    stem_patch_kernel: int = 4
    stem_patch_stride: int = 2
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True

@attrs.define
class SwinTConfig:
    model_type: str = "tiny"
    arch: Dict[str, Union[int, List[int]]] = attrs.field(factory=lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]})
    patch_size: List[int] = attrs.field(factory=lambda: [4, 4])
    stem_patch_stride: int = 2
    window_size: List[int] = attrs.field(factory=lambda: [7, 7])
    in_channels: int = 1
    kernel_size: int = 3
    filters_rate: float = 1.5
    convs_per_block: int = 2
    up_interpolate: bool = True
  2. Consider using attrs.Factory instead of lambda factories for better readability. Note that Factory wraps the callable and is passed as default=, not factory= (factory= expects a bare callable):
from attrs import Factory

# In ConvNextConfig
arch: Dict[str, List[int]] = attrs.field(default=Factory(lambda: {'depths': [3, 3, 9, 3], 'channels': [96, 192, 384, 768]}))

# In SwinTConfig
arch: Dict[str, Union[int, List[int]]] = attrs.field(default=Factory(lambda: {'embed': 96, 'depths': [2, 2, 6, 2], 'channels': [3, 6, 12, 24]}))
patch_size: List[int] = attrs.field(default=Factory(lambda: [4, 4]))
window_size: List[int] = attrs.field(default=Factory(lambda: [7, 7]))

These changes will improve type checking and make the code more explicit about the expected types for each attribute.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d866e36 and d572730.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)


27-27: Undefined name BackboneType

(F821)


28-28: Undefined name UNetConfig

(F821)


29-29: Undefined name BackboneType

(F821)


30-30: Undefined name ConvNextConfig

(F821)


31-31: Undefined name BackboneType

(F821)


32-32: Undefined name SwinTConfig

(F821)


40-40: Undefined name BackboneType

(F821)


43-43: Undefined name BackboneType

(F821)


46-46: Undefined name BackboneType

(F821)

🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

1-91: Overall assessment: Good implementation with room for minor improvements.

The ModelConfig class and its nested configuration classes provide a well-structured and comprehensive approach to managing model architecture configurations. The use of attrs for class definitions is a good choice, making the code more concise and less error-prone.

Key strengths:

  1. Clear separation of concerns for different backbone types.
  2. Thorough validation of pre-trained weights.
  3. Use of enums for backbone types, improving type safety.

Suggested improvements:

  1. Resolve import and undefined name issues.
  2. Add type hints throughout the file for better code clarity and maintainability.
  3. Minor refactoring of the validate_pre_trained_weights method to reduce code duplication.
  4. Consider using attrs.Factory for mutable default values.

These changes will enhance the overall quality of the code, making it more robust and easier to maintain in the future.

🧰 Tools
🪛 Ruff

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)


27-27: Undefined name BackboneType

(F821)


28-28: Undefined name UNetConfig

(F821)


29-29: Undefined name BackboneType

(F821)


30-30: Undefined name ConvNextConfig

(F821)


31-31: Undefined name BackboneType

(F821)


32-32: Undefined name SwinTConfig

(F821)


40-40: Undefined name BackboneType

(F821)


43-43: Undefined name BackboneType

(F821)


46-46: Undefined name BackboneType

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (3)
sleap_nn/config/model_config.py (3)

1-2: Add missing imports for type hints.

To improve type checking and code clarity, please add the following imports at the beginning of the file:

from typing import Dict, List, Optional, Union

These imports are necessary for the type hints used throughout the file.


56-92: Enhance documentation and type hinting for nested configuration classes.

For the UNetConfig, ConvNextConfig, and SwinTConfig classes:

  1. Add docstrings to explain the purpose of each class and its attributes.
  2. Consider adding type hints to all attributes for better code clarity and type checking.

For example, for the UNetConfig class:

@attrs.define
class UNetConfig:
    """Configuration for UNet backbone.

    Attributes:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the convolutional kernel.
        filters (int): Number of filters in the first layer.
        filters_rate (float): Rate at which the number of filters increases.
        max_stride (Optional[int]): Maximum stride in the network.
        stem_stride (Optional[int]): Stride in the stem of the network.
        middle_block (bool): Whether to include a middle block.
        up_interpolate (bool): Whether to use interpolation for upsampling.
        stacks (int): Number of encoder/decoder stacks.
        convs_per_block (int): Number of convolutions per block.
    """
    in_channels: int = 1
    kernel_size: int = 3
    filters: int = 32
    filters_rate: float = 1.5
    max_stride: Optional[int] = None
    stem_stride: Optional[int] = None
    middle_block: bool = True
    up_interpolate: bool = True
    stacks: int = 3
    convs_per_block: int = 2

Apply similar improvements to ConvNextConfig and SwinTConfig classes.


94-120: Improve consistency and documentation for additional configuration classes.

For the HeadConfig, SingleInstanceConfig, and ConfMapsConfig classes:

  1. Add docstrings to HeadConfig and SingleInstanceConfig explaining their purpose and attributes, similar to ConfMapsConfig.
  2. Consider using a consistent style for optional attributes. For example, in ConfMapsConfig, you could use attrs.field(default=None) instead of Optional[Type] = None for consistency with other classes.

Example for HeadConfig:

@attrs.define
class HeadConfig:
    """Configuration for model heads.

    Attributes:
        head_configs (Dict[str, Optional[Dict]]): A dictionary of head configurations.
            Keys represent head types, and values are their respective configurations.
    """
    head_configs: Dict[str, Optional[Dict]] = attrs.field(
        factory=lambda: {
            "single_instance": None,
            "centroid": None,
            "centered_instance": None,
            "bottomup": None
        }
    )

Apply similar improvements to SingleInstanceConfig.

🧰 Tools
🪛 Ruff

96-96: Undefined name Dict

(F821)


96-96: Undefined name Optional

(F821)


96-96: Undefined name Dict

(F821)


107-107: Undefined name Optional

(F821)


107-107: Undefined name ConfMapsConfig

(F821)


118-118: Undefined name Optional

(F821)


118-118: Undefined name List

(F821)


119-119: Undefined name Optional

(F821)


120-120: Undefined name Optional

(F821)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d572730 and 35b29c4.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType

(F821)


17-17: Undefined name BackboneType

(F821)


18-18: Undefined name Union

(F821)


18-18: Undefined name UNetConfig

(F821)


18-18: Undefined name ConvNextConfig

(F821)


18-18: Undefined name SwinTConfig

(F821)


19-19: Undefined name HeadConfig

(F821)


19-19: Undefined name HeadConfig

(F821)


28-28: Undefined name BackboneType

(F821)


29-29: Undefined name UNetConfig

(F821)


30-30: Undefined name BackboneType

(F821)


31-31: Undefined name ConvNextConfig

(F821)


32-32: Undefined name BackboneType

(F821)


33-33: Undefined name SwinTConfig

(F821)


41-41: Undefined name BackboneType

(F821)


44-44: Undefined name BackboneType

(F821)


47-47: Undefined name BackboneType

(F821)


96-96: Undefined name Dict

(F821)


96-96: Undefined name Optional

(F821)


96-96: Undefined name Dict

(F821)


107-107: Undefined name Optional

(F821)


107-107: Undefined name ConfMapsConfig

(F821)


118-118: Undefined name Optional

(F821)


118-118: Undefined name List

(F821)


119-119: Undefined name Optional

(F821)


120-120: Undefined name Optional

(F821)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 35b29c4 and 5a5dc7b.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_nn/config/model_config.py

17-17: Undefined name BackboneType (F821)
17-17: Undefined name BackboneType (F821)
18-18: Undefined name Union (F821)
18-18: Undefined name UNetConfig (F821)
18-18: Undefined name ConvNextConfig (F821)
18-18: Undefined name SwinTConfig (F821)
19-19: Undefined name HeadConfig (F821)
19-19: Undefined name HeadConfig (F821)
28-28: Undefined name BackboneType (F821)
29-29: Undefined name UNetConfig (F821)
30-30: Undefined name BackboneType (F821)
31-31: Undefined name ConvNextConfig (F821)
32-32: Undefined name BackboneType (F821)
33-33: Undefined name SwinTConfig (F821)
41-41: Undefined name BackboneType (F821)
44-44: Undefined name BackboneType (F821)
47-47: Undefined name BackboneType (F821)
96-96: Undefined name Dict (F821)
96-96: Undefined name Optional (F821)
96-96: Undefined name Dict (F821)
108-108: Undefined name Optional (F821)
108-108: Undefined name SingleInstanceConfMapsConfig (F821)
119-119: Undefined name Optional (F821)
119-119: Undefined name List (F821)
120-120: Undefined name Optional (F821)
121-121: Undefined name Optional (F821)
126-126: Undefined name Optional (F821)
126-126: Undefined name CentroidConfMapsConfig (F821)
137-137: Undefined name Optional (F821)
138-138: Undefined name Optional (F821)
139-139: Undefined name Optional (F821)
144-144: Undefined name Optional (F821)
144-144: Undefined name CenteredInstanceConfMapsConfig (F821)
156-156: Undefined name Optional (F821)
156-156: Undefined name List (F821)
157-157: Undefined name Optional (F821)
158-158: Undefined name Optional (F821)
159-159: Undefined name Optional (F821)

🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

50-92: LGTM: Nested configuration classes.

The nested configuration classes (BackboneType, UNetConfig, ConvNextConfig, and SwinTConfig) are well-structured and use appropriate attrs decorators. The use of attrs.field(factory=lambda: ...) for default dictionaries is a good practice to avoid mutable default arguments.
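The mutable-default pitfall the reviewer is pointing at can be demonstrated with a minimal stdlib sketch (using `dataclasses` rather than attrs, so the class and field names here are illustrative, not the PR's API): a factory builds a fresh dict per instance, whereas a shared literal default would be silently mutated by every instance.

```python
from dataclasses import dataclass, field

@dataclass
class UNetConfig:
    # default_factory creates a new dict for each instance; dataclasses
    # (like attrs) reject a bare mutable default for exactly this reason.
    filters: dict = field(default_factory=lambda: {"initial": 32, "max_stride": 16})

a = UNetConfig()
b = UNetConfig()
a.filters["initial"] = 64
# b is unaffected because each instance received its own dict.
```

The attrs equivalent is `attrs.field(factory=lambda: {...})`, as used in the reviewed code.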


1-160: Overall assessment: Well-structured configuration system with minor improvements needed.

The model_config.py file introduces a comprehensive and well-structured configuration system for a machine learning model. It effectively uses the attrs library for class definitions and provides clear docstrings for attributes. The nested configuration classes for different backbone types and head configurations are well-organized.

To further improve the code:

  1. Add the missing imports for type hints.
  2. Update the backbone_type attribute in ModelConfig to use attrs.field.
  3. Consider simplifying the set_backbone_config and validate_pre_trained_weights methods as suggested.
  4. Fix the naming inconsistency in the centroid class.

These changes will enhance the code's clarity, maintainability, and consistency.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🔭 Outside diff range comments (2)
tests/config/test_training_job_config.py (1)

171-171: 🛠️ Refactor suggestion

Add tests for error cases.

The test suite should include tests for error cases such as invalid file paths and malformed YAML.

import tempfile

import pytest

def test_from_yaml_errors():
    """Test error handling in from_yaml."""
    with pytest.raises(ValueError, match="Failed to parse YAML"):
        TrainingJobConfig.from_yaml("invalid: yaml: :")

def test_load_yaml_errors():
    """Test error handling in load_yaml."""
    with pytest.raises(FileNotFoundError):
        TrainingJobConfig.load_yaml("nonexistent.yaml")

    with tempfile.NamedTemporaryFile(mode='w') as f:
        f.write("invalid: yaml: :")
        f.flush()
        with pytest.raises(ValueError, match="Failed to parse YAML"):
            TrainingJobConfig.load_yaml(f.name)
tests/config/test_model_config.py (1)

58-69: ⚠️ Potential issue

Fix incomplete test implementation.

The test creates a structured configuration but doesn't actually test anything. The config variable is unused.

def test_update_config(default_config):
    """Test updating configuration attributes."""
-    config = OmegaConf.structured(
-        ModelConfig(
-            backbone_type="unet",
-            init_weight="default",
-            pre_trained_weights=None,
-            backbone_config=BackboneConfig(),
-            head_configs=HeadConfig(),
-        )
-    )
+    # Test updating backbone_type
+    default_config.backbone_type = "convnext"
+    assert default_config.backbone_type == "convnext"
+
+    # Test updating pre_trained_weights
+    default_config.pre_trained_weights = "ConvNeXt_Tiny_Weights"
+    assert default_config.pre_trained_weights == "ConvNeXt_Tiny_Weights"
+
+    # Test invalid update
+    with pytest.raises(ValueError, match="Invalid backbone_type"):
+        default_config.backbone_type = "invalid"
🧰 Tools
🪛 Ruff (0.8.2)

60-60: Local variable config is assigned to but never used (F841)

♻️ Duplicate comments (4)
sleap_nn/config/training_job_config.py (2)

61-73: 🛠️ Refactor suggestion

Add error handling for YAML parsing and schema validation.

The method should handle potential errors during YAML parsing and schema validation.

 @classmethod
 def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig":
-    schema = OmegaConf.structured(cls)
-    config = OmegaConf.create(yaml_data)
-    OmegaConf.merge(schema, config)
-    return cls(**OmegaConf.to_container(config, resolve=True))
+    try:
+        schema = OmegaConf.structured(cls)
+        config = OmegaConf.create(yaml_data)
+        OmegaConf.merge(schema, config)
+        return cls(**OmegaConf.to_container(config, resolve=True))
+    except Exception as e:
+        raise ValueError(f"Failed to parse YAML data: {e}")

123-125: ⚠️ Potential issue

Fix parameter mismatch in load_config.

The load_yaml method doesn't accept the load_training_config parameter, which will cause a TypeError.

-    return TrainingJobConfig.load_yaml(
-        filename, load_training_config=load_training_config
-    )
+    if os.path.isdir(filename):
+        config_file = os.path.join(
+            filename,
+            "training_job.yaml" if load_training_config else "initial_config.yaml"
+        )
+        if not os.path.exists(config_file):
+            raise FileNotFoundError(
+                f"Configuration file not found in directory: {config_file}"
+            )
+        filename = config_file
+    return TrainingJobConfig.load_yaml(filename)
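The directory-handling branch suggested above can be exercised on its own; a small sketch, where the helper name `resolve_config_path` and the YAML filenames mirror the suggestion and are assumptions rather than the repo's API:

```python
import os
import tempfile

def resolve_config_path(path: str, load_training_config: bool = True) -> str:
    """Return the YAML file to load, accepting either a file path or a run directory."""
    if os.path.isdir(path):
        name = "training_job.yaml" if load_training_config else "initial_config.yaml"
        candidate = os.path.join(path, name)
        if not os.path.exists(candidate):
            raise FileNotFoundError(
                f"Configuration file not found in directory: {candidate}"
            )
        return candidate
    return path

# Usage: a directory containing training_job.yaml resolves to that file.
with tempfile.TemporaryDirectory() as d:
    cfg = os.path.join(d, "training_job.yaml")
    with open(cfg, "w") as f:
        f.write("data: {}\n")
    resolved = resolve_config_path(d)
```

Keeping this resolution logic separate from `load_yaml` avoids threading an extra parameter through the loader, which is what caused the TypeError noted above.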
tests/config/test_model_config.py (2)

52-55: 🛠️ Refactor suggestion

Fix docstring and add specific error message assertion.

  1. The docstring is incorrectly copied from test_invalid_pre_trained_weights.
  2. Add specific error message assertion.
def test_invalid_backbonetype():
-    """Test validation failure with an invalid pre_trained_weights."""
+    """Test validation failure with an invalid backbone_type."""
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match='backbone_type must be one of "unet", "convnext", "swint"'):
        ModelConfig(backbone_type="net")

40-43: 🛠️ Refactor suggestion

Enhance test coverage and fix None comparison.

  1. Use is operator for None comparison.
  2. Test all attributes of the default configuration.
def test_default_initialization(default_config):
    """Test default initialization of ModelConfig."""
    assert default_config.init_weight == "default"
-    assert default_config.pre_trained_weights == None
+    assert default_config.pre_trained_weights is None
+    assert default_config.backbone_type == "unet"
+    assert isinstance(default_config.backbone_config, BackboneConfig)
+    assert isinstance(default_config.head_configs, HeadConfig)
🧰 Tools
🪛 Ruff (0.8.2)

43-43: Comparison to None should be cond is None (E711)

🧹 Nitpick comments (6)
sleap_nn/config/training_job_config.py (1)

27-35: Clean up unused imports.

Several imports are unused and can be removed to improve code clarity:

-import os
 from attrs import define, field, asdict
 import sleap_nn
 from sleap_nn.config.data_config import DataConfig
 from sleap_nn.config.model_config import ModelConfig
 from sleap_nn.config.trainer_config import TrainerConfig
-import yaml
-from typing import Text, Dict, Any, Optional
+from typing import Text, Optional
 from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff (0.8.2)

27-27: os imported but unused (F401)
28-28: attrs.field imported but unused (F401)
33-33: yaml imported but unused (F401)
34-34: typing.Dict imported but unused (F401)
34-34: typing.Any imported but unused (F401)

tests/config/test_model_config.py (3)

9-25: Clean up unused imports.

Several imported configuration classes are not used in the test file. Remove the following:

  • UNetConfig, ConvNextConfig, SwinTConfig
  • SingleInstanceConfig, CentroidConfig, CenteredInstanceConfig, BottomUpConfig
  • SingleInstanceConfMapsConfig, CentroidConfMapsConfig, CenteredInstanceConfMapsConfig, BottomUpConfMapsConfig, PAFConfig
from sleap_nn.config.model_config import (
    ModelConfig,
    BackboneConfig,
-    UNetConfig,
-    ConvNextConfig,
-    SwinTConfig,
    HeadConfig,
-    SingleInstanceConfig,
-    CentroidConfig,
-    CenteredInstanceConfig,
-    BottomUpConfig,
-    SingleInstanceConfMapsConfig,
-    CentroidConfMapsConfig,
-    CenteredInstanceConfMapsConfig,
-    BottomUpConfMapsConfig,
-    PAFConfig,
)
🧰 Tools
🪛 Ruff (0.8.2)

12-12: sleap_nn.config.model_config.UNetConfig imported but unused (F401)
13-13: sleap_nn.config.model_config.ConvNextConfig imported but unused (F401)
14-14: sleap_nn.config.model_config.SwinTConfig imported but unused (F401)
16-16: sleap_nn.config.model_config.SingleInstanceConfig imported but unused (F401)
17-17: sleap_nn.config.model_config.CentroidConfig imported but unused (F401)
18-18: sleap_nn.config.model_config.CenteredInstanceConfig imported but unused (F401)
19-19: sleap_nn.config.model_config.BottomUpConfig imported but unused (F401)
20-20: sleap_nn.config.model_config.SingleInstanceConfMapsConfig imported but unused (F401)
21-21: sleap_nn.config.model_config.CentroidConfMapsConfig imported but unused (F401)
22-22: sleap_nn.config.model_config.CenteredInstanceConfMapsConfig imported but unused (F401)
23-23: sleap_nn.config.model_config.BottomUpConfMapsConfig imported but unused (F401)
24-24: sleap_nn.config.model_config.PAFConfig imported but unused (F401)


28-37: Enhance fixture documentation.

The fixture's docstring could be more descriptive about the default values.

@pytest.fixture
def default_config():
-    """Fixture for a default ModelConfig instance."""
+    """Fixture for a default ModelConfig instance.
+    
+    Returns:
+        ModelConfig: A configuration instance with:
+            - init_weight: "default"
+            - pre_trained_weights: None
+            - backbone_type: "unet"
+            - backbone_config: Default BackboneConfig
+            - head_configs: Default HeadConfig
+    """

46-49: Add specific error message assertion.

The test should verify the specific error message for better debugging.

def test_invalid_pre_trained_weights():
    """Test validation failure with an invalid pre_trained_weights."""
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Invalid pre-trained weights for ConvNext"):
        ModelConfig(pre_trained_weights="here", backbone_type="unet")
sleap_nn/config/model_config.py (2)

269-272: Use Enum for backbone_type validation.

Consider using an Enum for backbone_type to improve type safety and maintainability.

+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"
+
-    backbone_type: str = field(
+    backbone_type: BackboneType = field(
-        default="unet",
+        default=BackboneType.UNET,
         validator=lambda instance, attr, value: instance.validate_backbone_type(value),
     )
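Validating a string against an Enum, as suggested here, reduces to a lookup that already raises on a miss; a stdlib sketch under the assumption that the enum values match the reviewed strings (the helper name `parse_backbone_type` is illustrative):

```python
from enum import Enum

class BackboneType(str, Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"

def parse_backbone_type(value: str) -> BackboneType:
    try:
        # Enum lookup by value raises ValueError for unknown strings.
        return BackboneType(value)
    except ValueError:
        valid = [m.value for m in BackboneType]
        raise ValueError(f"Invalid backbone_type. Must be one of {valid}")

parse_backbone_type("unet")  # → BackboneType.UNET
```

Subclassing `str` keeps the enum members YAML-serializable as plain strings, which matters for the OmegaConf round-trip.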

283-334: Simplify validation methods.

  1. Use the proposed BackboneType enum.
  2. Move weight lists to class constants.
  3. Simplify validation using a mapping.
+    CONVNEXT_WEIGHTS = [
+        "ConvNeXt_Base_Weights",
+        "ConvNeXt_Tiny_Weights",
+        "ConvNeXt_Small_Weights",
+        "ConvNeXt_Large_Weights",
+    ]
+    SWINT_WEIGHTS = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
+    VALID_WEIGHTS = {
+        BackboneType.CONVNEXT: CONVNEXT_WEIGHTS,
+        BackboneType.SWINT: SWINT_WEIGHTS,
+        BackboneType.UNET: None,
+    }
+
     def validate_pre_trained_weights(self, value):
         if value is None:
             return
 
-        convnext_weights = [...]  # Remove duplicated lists
-        swint_weights = [...]
-
-        if self.backbone_type == "convnext":
-            if value not in convnext_weights:
-                raise ValueError(...)
-        elif self.backbone_type == "swint":
-            if value not in swint_weights:
-                raise ValueError(...)
-        elif self.backbone_type == "unet":
+        allowed_weights = self.VALID_WEIGHTS.get(self.backbone_type)
+        if allowed_weights is None and value is not None:
             raise ValueError("UNet does not support pre-trained weights.")
+        elif allowed_weights and value not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f31f7fa and 52afbc7.

📒 Files selected for processing (4)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job_config.py (1 hunks)
  • tests/config/test_model_config.py (1 hunks)
  • tests/config/test_training_job_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
tests/config/test_model_config.py
12-12: sleap_nn.config.model_config.UNetConfig imported but unused (F401)
13-13: sleap_nn.config.model_config.ConvNextConfig imported but unused (F401)
14-14: sleap_nn.config.model_config.SwinTConfig imported but unused (F401)
16-16: sleap_nn.config.model_config.SingleInstanceConfig imported but unused (F401)
17-17: sleap_nn.config.model_config.CentroidConfig imported but unused (F401)
18-18: sleap_nn.config.model_config.CenteredInstanceConfig imported but unused (F401)
19-19: sleap_nn.config.model_config.BottomUpConfig imported but unused (F401)
20-20: sleap_nn.config.model_config.SingleInstanceConfMapsConfig imported but unused (F401)
21-21: sleap_nn.config.model_config.CentroidConfMapsConfig imported but unused (F401)
22-22: sleap_nn.config.model_config.CenteredInstanceConfMapsConfig imported but unused (F401)
23-23: sleap_nn.config.model_config.BottomUpConfMapsConfig imported but unused (F401)
24-24: sleap_nn.config.model_config.PAFConfig imported but unused (F401)
43-43: Comparison to None should be cond is None (E711)
60-60: Local variable config is assigned to but never used (F841)

sleap_nn/config/model_config.py
8-8: enum.Enum imported but unused (F401)

sleap_nn/config/training_job_config.py
27-27: os imported but unused (F401)
28-28: attrs.field imported but unused (F401)
33-33: yaml imported but unused (F401)
34-34: typing.Dict imported but unused (F401)
34-34: typing.Any imported but unused (F401)

tests/config/test_training_job_config.py
35-35: dataclasses.asdict imported but unused (F401)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (1)
sleap_nn/config/model_config.py (1)

241-255: LGTM!

The BackboneConfig class is well-structured with clear documentation and appropriate use of the @OneOf decorator.
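The @OneOf decorator referenced here enforces that at most one sub-config is set at a time. A minimal stdlib analogue of that cross-field constraint, using a `__post_init__` check (the class and field names mirror the reviewed code but the implementation is a sketch, not the repo's decorator):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class HeadConfig:
    single_instance: Optional[dict] = None
    centroid: Optional[dict] = None
    centered_instance: Optional[dict] = None

    def __post_init__(self):
        # Enforce the "one of" constraint: at most one head may be configured.
        set_fields = [
            name
            for name in ("single_instance", "centroid", "centered_instance")
            if getattr(self, name) is not None
        ]
        if len(set_fields) > 1:
            raise ValueError(f"Only one of these fields may be set, got: {set_fields}")

HeadConfig(centroid={"sigma": 5.0})  # OK: exactly one head configured.
```

With attrs, the same check fits naturally in an `__attrs_post_init__` method or a class-level validator, which is what a decorator like @OneOf wraps up.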

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (4)
sleap_nn/config/model_config.py (2)

326-331: 🛠️ Refactor suggestion

Fix inconsistent docstring.

The docstring mentions configurations that don't match the actual attributes:

  • multi_instance, multi_class_bottomup, and multi_class_topdown are mentioned but not defined
  • bottomup is defined but not mentioned
     Attributes:
         single_instance: An instance of `SingleInstanceConfmapsHeadConfig`.
         centroid: An instance of `CentroidsHeadConfig`.
         centered_instance: An instance of `CenteredInstanceConfmapsHeadConfig`.
-        multi_instance: An instance of `MultiInstanceConfig`.
-        multi_class_bottomup: An instance of `MultiClassBottomUpConfig`.
-        multi_class_topdown: An instance of `MultiClassTopDownConfig`.
+        bottomup: An instance of `BottomUpConfig`.

43-43: 🛠️ Refactor suggestion

Set default value for max_stride.

According to past review comments, max_stride shouldn't be optional and should have a default value of 16.

-    max_stride: Optional[int] = None
+    max_stride: int = 16
sleap_nn/config/training_job_config.py (1)

114-127: ⚠️ Potential issue

Fix load_config function implementation.

The function has incorrect parameter passing and missing file validation.

 def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig:
+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"Configuration file not found: {filename}")
+
+    if os.path.isdir(filename):
+        config_file = os.path.join(
+            filename,
+            "training_job.yaml" if load_training_config else "initial_config.yaml"
+        )
+        if not os.path.exists(config_file):
+            raise FileNotFoundError(
+                f"Configuration file not found in directory: {config_file}"
+            )
+        filename = config_file
+
     return TrainingJobConfig.load_yaml(
-        filename, load_training_config=load_training_config
+        filename
     )
tests/config/test_training_job_config.py (1)

84-90: ⚠️ Potential issue

Fix assertions in test_from_yaml.

The assertions incorrectly treat config attributes as dictionaries instead of proper config objects.

-    assert config.data["train_labels_path"] == sample_config["data"].train_labels_path
-    assert config.data["val_labels_path"] == sample_config["data"].val_labels_path
-    assert config.model["backbone_type"] == sample_config["model"].backbone_type
-    assert (
-        config.trainer["early_stopping"]["patience"]
-        == sample_config["trainer"].early_stopping.patience
-    )
+    assert isinstance(config.data, DataConfig)
+    assert config.data.train_labels_path == sample_config["data"].train_labels_path
+    assert config.data.val_labels_path == sample_config["data"].val_labels_path
+    assert isinstance(config.model, ModelConfig)
+    assert config.model.backbone_type == sample_config["model"].backbone_type
+    assert isinstance(config.trainer, TrainerConfig)
+    assert config.trainer.early_stopping.patience == sample_config["trainer"].early_stopping.patience
🧹 Nitpick comments (9)
sleap_nn/config/model_config.py (6)

7-10: Remove unused import.

The Enum import is not used directly in the code.

-from enum import Enum
🧰 Tools
🪛 Ruff (0.8.2)

8-8: enum.Enum imported but unused (F401)


77-77: Add validation for model_type.

Add validation to ensure model_type is one of the allowed values.

-    model_type: str = "tiny"  # Options: tiny, small, base, large
+    model_type: str = field(
+        default="tiny",
+        validator=attrs.validators.in_(["tiny", "small", "base", "large"])
+    )

114-114: Add validation for model_type.

Add validation to ensure model_type is one of the allowed values.

-    model_type: str = "tiny"  # Options: tiny, small, base
+    model_type: str = field(
+        default="tiny",
+        validator=attrs.validators.in_(["tiny", "small", "base"])
+    )

122-124: Improve type hints for list attributes.

Add type hints for list attributes to specify their element types.

-    patch_size: list = field(factory=lambda: [4, 4])
-    window_size: list = field(factory=lambda: [7, 7])
+    patch_size: List[int] = field(factory=lambda: [4, 4])
+    window_size: List[int] = field(factory=lambda: [7, 7])

376-379: Use enum for backbone_type validation.

Consider using an enum for backbone_type to make the code more maintainable and type-safe.

+class BackboneType(str, Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"

-    backbone_type: str = field(
+    backbone_type: BackboneType = field(
-        default="unet",
+        default=BackboneType.UNET,
         validator=lambda instance, attr, value: instance.validate_backbone_type(value),
     )

     def validate_backbone_type(self, value):
-        valid_types = ["unet", "convnext", "swint"]
-        if value not in valid_types:
-            raise ValueError(f"Invalid backbone_type. Must be one of {valid_types}")
+        if not isinstance(value, BackboneType):
+            raise ValueError(f"Invalid backbone_type. Must be a BackboneType enum.")

Also applies to: 390-398


399-441: Simplify pre_trained_weights validation using a dictionary.

The validation logic can be simplified and made more maintainable using a dictionary mapping.

     def validate_pre_trained_weights(self, value):
         if value is None:
             return

-        convnext_weights = [
-            "ConvNeXt_Base_Weights",
-            "ConvNeXt_Tiny_Weights",
-            "ConvNeXt_Small_Weights",
-            "ConvNeXt_Large_Weights",
-        ]
-        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
-        if self.backbone_type == "convnext":
-            if value not in convnext_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
-                )
-        elif self.backbone_type == "swint":
-            if value not in swint_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
-                )
-        elif self.backbone_type == "unet":
-            raise ValueError("UNet does not support pre-trained weights.")
+        VALID_WEIGHTS = {
+            BackboneType.CONVNEXT: [
+                "ConvNeXt_Base_Weights",
+                "ConvNeXt_Tiny_Weights",
+                "ConvNeXt_Small_Weights",
+                "ConvNeXt_Large_Weights",
+            ],
+            BackboneType.SWINT: [
+                "Swin_T_Weights",
+                "Swin_S_Weights",
+                "Swin_B_Weights"
+            ],
+            BackboneType.UNET: None
+        }
+        
+        allowed_weights = VALID_WEIGHTS.get(self.backbone_type)
+        if allowed_weights is None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif value not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )
sleap_nn/config/training_job_config.py (1)

27-35: Clean up imports.

Several imports are unused and can be removed:

  • os (used in load_config but not implemented correctly)
  • field from attrs (not used)
  • yaml (replaced by OmegaConf)
  • Dict, Any from typing (not used)
-import os
-from attrs import define, field, asdict
+from attrs import define, asdict
 import sleap_nn
 from sleap_nn.config.data_config import DataConfig
 from sleap_nn.config.model_config import ModelConfig
 from sleap_nn.config.trainer_config import TrainerConfig
-import yaml
-from typing import Text, Dict, Any, Optional
+from typing import Text, Optional
 from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff (0.8.2)

27-27: os imported but unused (F401)
28-28: attrs.field imported but unused (F401)
33-33: yaml imported but unused (F401)
34-34: typing.Dict imported but unused (F401)
34-34: typing.Any imported but unused (F401)

tests/config/test_training_job_config.py (2)

35-35: Remove unused import.

The asdict import from dataclasses is not used in the test file.

-from dataclasses import asdict
🧰 Tools
🪛 Ruff (0.8.2)

35-35: dataclasses.asdict imported but unused (F401)


194-194: Remove unused variable assignment.

The config variable is assigned but never used.

-        config = TrainingJobConfig.from_yaml(yaml_data)
+        TrainingJobConfig.from_yaml(yaml_data)
🧰 Tools
🪛 Ruff (0.8.2)

194-194: Local variable config is assigned to but never used (F841)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 52afbc7 and d6c87af.

📒 Files selected for processing (3)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job_config.py (1 hunks)
  • tests/config/test_training_job_config.py (1 hunks)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (2)
sleap_nn/config/model_config.py (2)

1-5: LGTM!

The file-level docstring clearly explains the purpose of these configuration classes.


156-157: Set default values for sigma and output_stride.

According to past review comments, sigma and output_stride shouldn't be optional. Please check the original implementation for default values.

Run the following script to find the default values:

❌ Verification inconclusive

🏁 Script executed:

#!/bin/bash
# Search for default values in the original implementation
rg -A 5 "sigma.*=|output_stride.*=" sleap/nn/config/model.py

Length of output: 127


Missing default values for sigma and output_stride: Manual verification required

  • The search for default values in the file sleap/nn/config/model.py produced no output, and the file wasn’t found.
  • Without an existing reference for default values, please verify manually what the defaults should be (or if they should be removed from the Optional hint) and update sleap_nn/config/model_config.py accordingly.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
sleap_nn/config/model_config.py (3)

380-380: 🛠️ Refactor suggestion

Rename init_weight to init_weights.

For consistency with the codebase, rename the attribute.

-    init_weight: str = "default"
+    init_weights: str = "default"

318-338: 🛠️ Refactor suggestion

Fix inconsistent docstring in HeadConfig.

The docstring mentions configurations that don't match the actual attributes:

  • multi_instance, multi_class_bottomup, and multi_class_topdown are mentioned but not defined
  • bottomup is defined but not mentioned
     Attributes:
         single_instance: An instance of `SingleInstanceConfmapsHeadConfig`.
         centroid: An instance of `CentroidsHeadConfig`.
         centered_instance: An instance of `CenteredInstanceConfmapsHeadConfig`.
-        multi_instance: An instance of `MultiInstanceConfig`.
-        multi_class_bottomup: An instance of `MultiClassBottomUpConfig`.
-        multi_class_topdown: An instance of `MultiClassTopDownConfig`.
+        bottomup: An instance of `BottomUpConfig`.

51-88: 🛠️ Refactor suggestion

Add validation for model_type attribute.

The model_type options are only documented in a comment. Add validation to enforce these options.

-    model_type: str = "tiny"  # Options: tiny, small, base, large
+    model_type: str = field(
+        default="tiny",
+        validator=attrs.validators.in_(["tiny", "small", "base", "large"])
+    )
🧹 Nitpick comments (1)
sleap_nn/config/model_config.py (1)

399-441: Simplify pre-trained weights validation.

Use a dictionary mapping to make the validation more maintainable.

     def validate_pre_trained_weights(self, value):
         if value is None:
             return
 
-        convnext_weights = [
-            "ConvNeXt_Base_Weights",
-            "ConvNeXt_Tiny_Weights",
-            "ConvNeXt_Small_Weights",
-            "ConvNeXt_Large_Weights",
-        ]
-        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
-        if self.backbone_type == "convnext":
-            if value not in convnext_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
-                )
-        elif self.backbone_type == "swint":
-            if value not in swint_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
-                )
-        elif self.backbone_type == "unet":
-            raise ValueError("UNet does not support pre-trained weights.")
+        VALID_WEIGHTS = {
+            BackboneType.CONVNEXT: [
+                "ConvNeXt_Base_Weights",
+                "ConvNeXt_Tiny_Weights",
+                "ConvNeXt_Small_Weights",
+                "ConvNeXt_Large_Weights",
+            ],
+            BackboneType.SWINT: [
+                "Swin_T_Weights",
+                "Swin_S_Weights",
+                "Swin_B_Weights"
+            ],
+            BackboneType.UNET: None
+        }
+        
+        allowed_weights = VALID_WEIGHTS.get(self.backbone_type)
+        if allowed_weights is None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif value not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )
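Pulled out of the class, the mapping-based check behaves as follows; this stdlib sketch assumes a hypothetical `BackboneType` enum mirroring the one proposed elsewhere in this review and reuses the weight names from the diff:

```python
from enum import Enum


class BackboneType(Enum):
    UNET = "unet"
    CONVNEXT = "convnext"
    SWINT = "swint"


# None marks a backbone that accepts no pre-trained weights at all.
VALID_WEIGHTS = {
    BackboneType.CONVNEXT: [
        "ConvNeXt_Base_Weights",
        "ConvNeXt_Tiny_Weights",
        "ConvNeXt_Small_Weights",
        "ConvNeXt_Large_Weights",
    ],
    BackboneType.SWINT: ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"],
    BackboneType.UNET: None,
}


def validate_pre_trained_weights(backbone_type: BackboneType, value) -> None:
    """Raise ValueError unless `value` is valid for the given backbone."""
    if value is None:
        return
    allowed = VALID_WEIGHTS.get(backbone_type)
    if allowed is None:
        raise ValueError(f"{backbone_type.value} does not support pre-trained weights.")
    if value not in allowed:
        raise ValueError(
            f"Invalid pre-trained weights for {backbone_type.value}. "
            f"Must be one of {allowed}"
        )
```

Adding a new backbone then means adding one dictionary entry rather than another `elif` branch.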
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6c87af and dfde616.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
sleap_nn/config/model_config.py

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)

⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Lint
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (3)
sleap_nn/config/model_config.py (3)

1-5: LGTM! Clear and informative docstring.

The docstring effectively communicates the purpose of these configuration classes.


7-10: Keep the Enum import.

While static analysis flags Enum as unused, it's required for class inheritance in the codebase. The import is correctly placed and necessary.

🧰 Tools
🪛 Ruff (0.8.2)

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)


14-49: LGTM! Well-structured UNetConfig class.

The class has comprehensive docstrings and appropriate default values for all attributes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (8)
sleap_nn/config/training_job_config.py (5)

61-74: 🛠️ Refactor suggestion

Complete schema validation and add error handling.

The schema validation is incomplete and error handling is missing.

Apply this diff:

     @classmethod
     def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig":
-        schema = OmegaConf.structured(cls)
-        config = OmegaConf.create(yaml_data)
-        config = OmegaConf.merge(schema, config)
-        OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
-        return config
+        try:
+            schema = OmegaConf.structured(cls)
+            config = OmegaConf.create(yaml_data)
+            config = OmegaConf.merge(schema, config)
+            config_dict = OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
+            return cls(**config_dict)
+        except Exception as e:
+            raise ValueError(f"Failed to parse YAML data: {e}")

93-111: 🛠️ Refactor suggestion

Fix return type hint and add error handling.

The method's return type hint is incorrect, and error handling is missing.

Apply this diff:

-    def to_yaml(self, filename: Optional[Text] = None) -> None:
+    def to_yaml(self, filename: Optional[Text] = None) -> str:
         """Serialize and optionally save the configuration to YAML format.
 
         Args:
             filename: Optional path to save the YAML file to. If not provided,
                      the configuration will only be converted to YAML format.
+        Returns:
+            The YAML-encoded string representation of the configuration.
+        Raises:
+            ValueError: If the configuration cannot be serialized or saved.
         """
-        # Convert attrs objects to nested dictionaries
-        config_dict = asdict(self)
+        try:
+            # Convert attrs objects to nested dictionaries
+            config_dict = asdict(self)
 
-        # Handle any special cases (like enums) that need manual conversion
-        if config_dict.get("model", {}).get("backbone_type"):
-            config_dict["model"]["backbone_type"] = self.model.backbone_type
+            # Handle any special cases (like enums) that need manual conversion
+            if config_dict.get("model", {}).get("backbone_type"):
+                config_dict["model"]["backbone_type"] = self.model.backbone_type
 
-        # Create OmegaConf object and save if filename provided
-        conf = OmegaConf.create(config_dict)
-        if filename is not None:
-            OmegaConf.save(conf, filename)
-        return
+            # Create OmegaConf object and save if filename provided
+            conf = OmegaConf.create(config_dict)
+            yaml_str = OmegaConf.to_yaml(conf)
+            if filename is not None:
+                with open(filename, "w") as f:
+                    f.write(yaml_str)
+            return yaml_str
+        except Exception as e:
+            raise ValueError(f"Failed to serialize or save configuration: {e}")
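The shape of the fix — always return the encoded string, write it only when a filename is given — can be sketched standalone; `json` stands in for OmegaConf's YAML encoding here, and `TinyConfig`/`to_text` are hypothetical names:

```python
import json
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class TinyConfig:
    """Minimal stand-in for a training job configuration."""

    name: str = "job"
    epochs: int = 10


def to_text(config: TinyConfig, filename: Optional[str] = None) -> str:
    """Serialize, optionally save, and always return the encoded string."""
    try:
        encoded = json.dumps(asdict(config), indent=2)
        if filename is not None:
            with open(filename, "w") as f:
                f.write(encoded)
        return encoded
    except Exception as e:
        raise ValueError(f"Failed to serialize or save configuration: {e}")
```

Returning the string in both code paths keeps the in-memory and on-disk representations identical, which makes round-trip tests trivial.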

77-91: 🛠️ Refactor suggestion

Complete schema validation and add error handling.

Similar to from_yaml, the schema validation is incomplete and error handling is missing.

Apply this diff:

     @classmethod
     def load_yaml(cls, filename: Text) -> "TrainingJobConfig":
-        schema = OmegaConf.structured(cls)
-        config = OmegaConf.load(filename)
-        OmegaConf.merge(schema, config)
-        OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
-        return config
+        try:
+            schema = OmegaConf.structured(cls)
+            config = OmegaConf.load(filename)
+            config = OmegaConf.merge(schema, config)
+            config_dict = OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
+            return cls(**config_dict)
+        except Exception as e:
+            raise ValueError(f"Failed to load YAML file {filename}: {e}")

52-54: 🛠️ Refactor suggestion

Update attribute names to match documentation.

The attribute names should be consistent with the documentation in docs/config.md.

Apply this diff:

-    data_config: DataConfig = DataConfig()
-    model_config: ModelConfig = ModelConfig()
-    trainer_config: TrainerConfig = TrainerConfig()
+    data: DataConfig = DataConfig()
+    model: ModelConfig = ModelConfig()
+    trainer: TrainerConfig = TrainerConfig()

114-127: 🛠️ Refactor suggestion

Add validation for configuration loading.

The function should validate the existence of the file and handle directory paths correctly.

Apply this diff:

 def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig:
+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"Configuration file not found: {filename}")
+
+    if os.path.isdir(filename):
+        config_file = os.path.join(
+            filename,
+            "training_job.yaml" if load_training_config else "initial_config.yaml"
+        )
+        if not os.path.exists(config_file):
+            raise FileNotFoundError(
+                f"Configuration file not found in directory: {config_file}"
+            )
+        filename = config_file
+
     return TrainingJobConfig.load_yaml(
-        filename, load_training_config=load_training_config
+        filename
     )
sleap_nn/config/model_config.py (3)

43-43: 🛠️ Refactor suggestion

Set default value for max_stride.

According to past review comments, max_stride should have a default value of 16.

Apply this diff:

-    max_stride: int = 16
+    max_stride: int = 16  # This value is required and cannot be None

374-396: 🛠️ Refactor suggestion

Convert backbone_type to an enum.

Replace string-based validation with an enum for better type safety and maintainability.

Apply this diff:

+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"
+
-    backbone_type: str = field(
+    backbone_type: BackboneType = field(
-        default="unet",
+        default=BackboneType.UNET,
         validator=lambda instance, attr, value: instance.validate_backbone_type(value),
     )

     def validate_backbone_type(self, value):
-        valid_types = ["unet", "convnext", "swint"]
-        if value not in valid_types:
-            raise ValueError(f"Invalid backbone_type. Must be one of {valid_types}")
+        if not isinstance(value, BackboneType):
+            raise ValueError("Invalid backbone_type. Must be a BackboneType enum.")

378-378: 🛠️ Refactor suggestion

Update attribute name to match documentation.

The attribute name should be init_weights instead of init_weight.

Apply this diff:

-    init_weight: str = "default"
+    init_weights: str = "default"
🧹 Nitpick comments (6)
tests/config/test_training_job_config.py (4)

35-35: Remove unused import.

The asdict import is not used in the code.

Apply this diff:

-from dataclasses import asdict
🧰 Tools
🪛 Ruff (0.8.2)

35-35: dataclasses.asdict imported but unused

Remove unused import: dataclasses.asdict

(F401)


82-98: Add type assertions for configuration objects.

The test should verify that the configuration objects are of the correct type.

Apply this diff:

     assert config.name == sample_config["name"]
     assert config.description == sample_config["description"]
+    assert isinstance(config.data_config, DataConfig)
     assert (
         config.data_config.train_labels_path
         == sample_config["data_config"].train_labels_path
     )
     assert (
         config.data_config.val_labels_path
         == sample_config["data_config"].val_labels_path
     )
+    assert isinstance(config.model_config, ModelConfig)
     assert (
         config.model_config.backbone_type == sample_config["model_config"].backbone_type
     )
+    assert isinstance(config.trainer_config, TrainerConfig)
     assert (
         config.trainer_config.early_stopping.patience
         == sample_config["trainer_config"].early_stopping.patience
     )

219-220: Remove unused variable assignment.

The config variable is not used in the test.

Apply this diff:

     with pytest.raises(MissingMandatoryValue):
-        config = TrainingJobConfig.from_yaml(yaml_data)
+        TrainingJobConfig.from_yaml(yaml_data)
🧰 Tools
🪛 Ruff (0.8.2)

220-220: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)


221-221: Add error case tests.

The test suite should include tests for error cases such as invalid file paths and malformed YAML.

Add the following test function:

def test_load_yaml_errors():
    """Test error handling in load_yaml."""
    with pytest.raises(FileNotFoundError):
        TrainingJobConfig.load_yaml("nonexistent.yaml")

    with tempfile.NamedTemporaryFile(mode='w') as f:
        f.write("invalid: yaml: :")
        f.flush()
        with pytest.raises(ValueError, match="Failed to load YAML"):
            TrainingJobConfig.load_yaml(f.name)
sleap_nn/config/model_config.py (2)

8-8: Remove unused import.

The Enum import is not used in the code.

Apply this diff:

-from enum import Enum
🧰 Tools
🪛 Ruff (0.8.2)

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)


419-438: Simplify pre-trained weights validation using a mapping.

The validation logic can be simplified and made more maintainable using a dictionary mapping.

Apply this diff:

-        convnext_weights = [
-            "ConvNeXt_Base_Weights",
-            "ConvNeXt_Tiny_Weights",
-            "ConvNeXt_Small_Weights",
-            "ConvNeXt_Large_Weights",
-        ]
-        swint_weights = ["Swin_T_Weights", "Swin_S_Weights", "Swin_B_Weights"]
-
-        if self.backbone_type == "convnext":
-            if value not in convnext_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
-                )
-        elif self.backbone_type == "swint":
-            if value not in swint_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
-                )
-        elif self.backbone_type == "unet":
-            raise ValueError("UNet does not support pre-trained weights.")
+        VALID_WEIGHTS = {
+            BackboneType.CONVNEXT: [
+                "ConvNeXt_Base_Weights",
+                "ConvNeXt_Tiny_Weights",
+                "ConvNeXt_Small_Weights",
+                "ConvNeXt_Large_Weights",
+            ],
+            BackboneType.SWINT: [
+                "Swin_T_Weights",
+                "Swin_S_Weights",
+                "Swin_B_Weights"
+            ],
+            BackboneType.UNET: None
+        }
+        
+        allowed_weights = VALID_WEIGHTS.get(self.backbone_type)
+        if allowed_weights is None and value is not None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif allowed_weights and value not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dfde616 and 942a4ff.

📒 Files selected for processing (3)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job_config.py (1 hunks)
  • tests/config/test_training_job_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
sleap_nn/config/training_job_config.py

27-27: os imported but unused

Remove unused import: os

(F401)


28-28: attrs.field imported but unused

Remove unused import: attrs.field

(F401)


33-33: yaml imported but unused

Remove unused import: yaml

(F401)


34-34: typing.Dict imported but unused

Remove unused import

(F401)


34-34: typing.Any imported but unused

Remove unused import

(F401)

sleap_nn/config/model_config.py

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)

tests/config/test_training_job_config.py

35-35: dataclasses.asdict imported but unused

Remove unused import: dataclasses.asdict

(F401)


220-220: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)

⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
sleap_nn/config/model_config.py (1)

387-390: 🛠️ Refactor suggestion

Convert backbone_type to an enum for better type safety.

Replace string-based validation with an enum for better type safety and maintainability.

+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"
+
 @define
 class ModelConfig:
-    backbone_type: str = field(
+    backbone_type: BackboneType = field(
-        default="unet",
+        default=BackboneType.UNET,
         validator=lambda instance, attr, value: instance.validate_backbone_type(value),
     )

     def validate_backbone_type(self, value):
-        valid_types = ["unet", "convnext", "swint"]
-        if value not in valid_types:
-            raise ValueError(f"Invalid backbone_type. Must be one of {valid_types}")
+        if not isinstance(value, BackboneType):
+        raise ValueError("Invalid backbone_type. Must be a BackboneType enum.")
🧹 Nitpick comments (3)
sleap_nn/config/model_config.py (3)

77-77: Add validation for model_type in ConvNextConfig.

Use attrs validator to enforce valid model types.

-    model_type: str = "tiny"  # Options: tiny, small, base, large
+    model_type: str = field(
+        default="tiny",
+        validator=attrs.validators.in_(["tiny", "small", "base", "large"])
+    )

146-171: Create a base configuration class for common attributes.

The configuration classes share common attributes (sigma, output_stride). Consider creating a base class to reduce code duplication.

@attrs.define
class BaseConfMapsConfig:
    """Base configuration for confidence maps.

    Attributes:
        sigma: Spread of the Gaussian distribution of the confidence maps.
        output_stride: The stride of the output confidence maps relative to the input image.
        loss_weight: Weight of the loss term during training.
    """
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: Optional[float] = None

@attrs.define
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
    """Single Instance configuration map."""
    part_names: Optional[List[str]] = None

# Similar changes for other configuration classes

Also applies to: 174-201, 204-237, 240-269, 272-300
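The same base-class refactor works with stdlib dataclasses, which share attrs' constraint that fields following defaulted base fields must themselves have defaults (field names mirror the suggestion; defaults are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BaseConfMapsConfig:
    """Shared confidence-map settings."""

    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: Optional[float] = None


@dataclass
class SingleInstanceConfMapsConfig(BaseConfMapsConfig):
    # Subclass fields need defaults too, since they follow defaulted base fields.
    part_names: Optional[List[str]] = None
```

`@attrs.define` applies the identical ordering rule, so the subclass-only `part_names` keeps its default in the attrs version as well.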


440-451: Improve pre-trained weights validation.

Use the BackboneType enum and a mapping for cleaner validation.

-        if self.backbone_type == "convnext":
-            if value not in convnext_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
-                )
-        elif self.backbone_type == "swint":
-            if value not in swint_weights:
-                raise ValueError(
-                    f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
-                )
-        elif self.backbone_type == "unet":
-            raise ValueError("UNet does not support pre-trained weights.")
+        VALID_WEIGHTS = {
+            BackboneType.CONVNEXT: convnext_weights,
+            BackboneType.SWINT: swint_weights,
+            BackboneType.UNET: None
+        }
+        allowed_weights = VALID_WEIGHTS[self.backbone_type]
+        if allowed_weights is None and value is not None:
+            raise ValueError(f"{self.backbone_type.value} does not support pre-trained weights.")
+        elif allowed_weights and value not in allowed_weights:
+            raise ValueError(
+                f"Invalid pre-trained weights for {self.backbone_type.value}. "
+                f"Must be one of {allowed_weights}"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 942a4ff and 60ac298.

📒 Files selected for processing (1)
  • sleap_nn/config/model_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
sleap_nn/config/model_config.py

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🔭 Outside diff range comments (2)
tests/config/test_training_job_config.py (1)

199-221: 🛠️ Refactor suggestion

Fix unused variable and improve error case testing.

The test has an unused variable and could be expanded to cover more error cases.

 def test_missing_attributes(sample_config):
     """Test creating a TrainingJobConfig from a valid YAML string."""
     config_dict = {
         "name": sample_config["name"],
         "description": sample_config["description"],
         "data_config": {
             "provider": sample_config["data_config"].provider,
         },
         "model_config": {
             "backbone_type": sample_config["model_config"].backbone_type,
             "init_weights": sample_config["model_config"].init_weights,
         },
         "trainer_config": {
             "early_stopping": {
                 "patience": sample_config["trainer_config"].early_stopping.patience,
             },
         },
     }
     yaml_data = OmegaConf.to_yaml(config_dict)
 
     with pytest.raises(MissingMandatoryValue):
-        config = TrainingJobConfig.from_yaml(yaml_data)
+        TrainingJobConfig.from_yaml(yaml_data)

+def test_load_yaml_errors():
+    """Test error handling in load_yaml."""
+    with pytest.raises(FileNotFoundError):
+        TrainingJobConfig.load_yaml("nonexistent.yaml")
+
+    with tempfile.NamedTemporaryFile(mode='w') as f:
+        f.write("invalid: yaml: :")
+        f.flush()
+        with pytest.raises(ValueError, match="Failed to load YAML"):
+            TrainingJobConfig.load_yaml(f.name)
🧰 Tools
🪛 Ruff (0.8.2)

220-220: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)

sleap_nn/config/model_config.py (1)

439-451: 🛠️ Refactor suggestion

Use BackboneType enum in validation.

Replace string literals with BackboneType enum values.

-        if self.backbone_type == "convnext":
+        if self.backbone_type == BackboneType.CONVNEXT:
             if value not in convnext_weights:
                 raise ValueError(
                     f"Invalid pre-trained weights for ConvNext. Must be one of {convnext_weights}"
                 )
-        elif self.backbone_type == "swint":
+        elif self.backbone_type == BackboneType.SWINT:
             if value not in swint_weights:
                 raise ValueError(
                     f"Invalid pre-trained weights for SwinT. Must be one of {swint_weights}"
                 )
-        elif self.backbone_type == "unet":
+        elif self.backbone_type == BackboneType.UNET:
             raise ValueError("UNet does not support pre-trained weights.")
♻️ Duplicate comments (3)
tests/config/test_model_config.py (2)

53-54: 🛠️ Refactor suggestion

Fix incorrect docstring.

The docstring is incorrectly copied from test_invalid_pre_trained_weights.

-    """Test validation failure with an invalid pre_trained_weights."""
+    """Test validation failure with an invalid backbone_type."""

43-43: 🛠️ Refactor suggestion

Fix None comparison.

Use the is operator for None comparison.

-    assert default_config.pre_trained_weights == None
+    assert default_config.pre_trained_weights is None
🧰 Tools
🪛 Ruff (0.8.2)

43-43: Comparison to None should be cond is None

Replace with cond is None

(E711)

sleap_nn/config/model_config.py (1)

386-389: 🛠️ Refactor suggestion

Convert backbone_type to an enum.

Replace string-based validation with an enum for better type safety and maintainability.

+class BackboneType(Enum):
+    UNET = "unet"
+    CONVNEXT = "convnext"
+    SWINT = "swint"
+
-    backbone_type: str = field(
+    backbone_type: BackboneType = field(
-        default="unet",
+        default=BackboneType.UNET,
         validator=lambda instance, attr, value: instance.validate_backbone_type(value),
     )
🧹 Nitpick comments (2)
tests/config/test_training_job_config.py (2)

27-36: Remove unused import.

The asdict import from dataclasses is not used in this file.

-from dataclasses import asdict
🧰 Tools
🪛 Ruff (0.8.2)

35-35: dataclasses.asdict imported but unused

Remove unused import: dataclasses.asdict

(F401)


59-98: Simplify assertions using proper config objects.

The test could be more concise by directly comparing config objects instead of individual fields.

-    assert config.name == sample_config["name"]
-    assert config.description == sample_config["description"]
-    assert (
-        config.data_config.train_labels_path
-        == sample_config["data_config"].train_labels_path
-    )
-    assert (
-        config.data_config.val_labels_path
-        == sample_config["data_config"].val_labels_path
-    )
-    assert (
-        config.model_config.backbone_type == sample_config["model_config"].backbone_type
-    )
-    assert (
-        config.trainer_config.early_stopping.patience
-        == sample_config["trainer_config"].early_stopping.patience
-    )
+    assert config.name == sample_config["name"]
+    assert config.description == sample_config["description"]
+    assert config.data_config == sample_config["data_config"]
+    assert config.model_config == sample_config["model_config"]
+    assert config.trainer_config == sample_config["trainer_config"]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 60ac298 and e674426.

📒 Files selected for processing (3)
  • sleap_nn/config/model_config.py (1 hunks)
  • tests/config/test_model_config.py (1 hunks)
  • tests/config/test_training_job_config.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
tests/config/test_model_config.py

12-12: sleap_nn.config.model_config.UNetConfig imported but unused

Remove unused import

(F401)


13-13: sleap_nn.config.model_config.ConvNextConfig imported but unused

Remove unused import

(F401)


16-16: sleap_nn.config.model_config.SingleInstanceConfig imported but unused

Remove unused import

(F401)


17-17: sleap_nn.config.model_config.CentroidConfig imported but unused

Remove unused import

(F401)


18-18: sleap_nn.config.model_config.CenteredInstanceConfig imported but unused

Remove unused import

(F401)


19-19: sleap_nn.config.model_config.BottomUpConfig imported but unused

Remove unused import

(F401)


20-20: sleap_nn.config.model_config.SingleInstanceConfMapsConfig imported but unused

Remove unused import

(F401)


21-21: sleap_nn.config.model_config.CentroidConfMapsConfig imported but unused

Remove unused import

(F401)


22-22: sleap_nn.config.model_config.CenteredInstanceConfMapsConfig imported but unused

Remove unused import

(F401)


23-23: sleap_nn.config.model_config.BottomUpConfMapsConfig imported but unused

Remove unused import

(F401)


24-24: sleap_nn.config.model_config.PAFConfig imported but unused

Remove unused import

(F401)


43-43: Comparison to None should be cond is None

Replace with cond is None

(E711)


60-60: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)


75-75: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)

sleap_nn/config/model_config.py

8-8: enum.Enum imported but unused

Remove unused import: enum.Enum

(F401)

tests/config/test_training_job_config.py

35-35: dataclasses.asdict imported but unused

Remove unused import: dataclasses.asdict

(F401)


220-220: Local variable config is assigned to but never used

Remove assignment to unused variable config

(F841)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (2)
tests/config/test_training_job_config.py (2)

1-25: LGTM! Well-documented purpose and design philosophy.

The docstring clearly explains the purpose of configuration classes, their serialization capabilities, and the separation of concerns between parameter specification and implementation.


38-57: LGTM! Well-structured test fixture.

The fixture provides a comprehensive sample configuration using proper config objects with all required fields.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
sleap_nn/config/training_job_config.py (2)

91-110: ⚠️ Potential issue

Fix return type and add error handling in to_yaml method.

The method has several issues:

  1. Return type hint is incorrect (should be str).
  2. Error handling is missing.
  3. The method returns None instead of the YAML string.
-    def to_yaml(self, filename: Optional[Text] = None) -> None:
+    def to_yaml(self, filename: Optional[Text] = None) -> str:
         """Serialize and optionally save the configuration to YAML format.
 
         Args:
             filename: Optional path to save the YAML file to. If not provided,
                      the configuration will only be converted to YAML format.
+        Returns:
+            The YAML-encoded string representation of the configuration.
+        Raises:
+            ValueError: If the configuration cannot be serialized or saved.
         """
-        # Convert attrs objects to nested dictionaries
-        config_dict = asdict(self)
+        try:
+            # Convert attrs objects to nested dictionaries
+            config_dict = asdict(self)
 
-        # Handle any special cases (like enums) that need manual conversion
-        if config_dict.get("model", {}).get("backbone_type"):
-            config_dict["model"]["backbone_type"] = self.model.backbone_type
+            # Handle any special cases (like enums) that need manual conversion
+            if config_dict.get("model_config", {}).get("backbone_type"):
+                config_dict["model_config"]["backbone_type"] = self.model_config.backbone_type
 
-        # Create OmegaConf object and save if filename provided
-        conf = OmegaConf.create(config_dict)
-        if filename is not None:
-            OmegaConf.save(conf, filename)
-        return
+            # Create OmegaConf object and save if filename provided
+            conf = OmegaConf.create(config_dict)
+            yaml_str = OmegaConf.to_yaml(conf)
+            if filename is not None:
+                with open(filename, "w") as f:
+                    f.write(yaml_str)
+            return yaml_str
+        except Exception as e:
+            raise ValueError(f"Failed to serialize or save configuration: {e}")

112-125: ⚠️ Potential issue

Fix parameter handling and add file validation.

The function has several issues:

  1. The load_training_config parameter is passed to load_yaml but not used.
  2. File existence validation is missing.
-def load_config(filename: Text, load_training_config: bool = True) -> OmegaConf:
+def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig:
     """Load a training job configuration for a model run.
 
     Args:
         filename: Path to a YAML file or directory containing `training_job.yaml`.
         load_training_config: If `True` (the default), prefer `training_job.yaml` over
             `initial_config.yaml` if it is present in the same folder.
 
     Returns:
-        The parsed `OmegaConf`.
+        The parsed `TrainingJobConfig`.
+    Raises:
+        FileNotFoundError: If the configuration file is not found.
     """
-    return TrainingJobConfig.load_yaml(
-        filename, load_training_config=load_training_config
-    )
+    if not os.path.exists(filename):
+        raise FileNotFoundError(f"Configuration file not found: {filename}")
+
+    if os.path.isdir(filename):
+        config_file = os.path.join(
+            filename,
+            "training_job.yaml" if load_training_config else "initial_config.yaml"
+        )
+        if not os.path.exists(config_file):
+            raise FileNotFoundError(
+                f"Configuration file not found in directory: {config_file}"
+            )
+        filename = config_file
+
+    return TrainingJobConfig.load_yaml(filename)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e674426 and c448afa.

📒 Files selected for processing (1)
  • sleap_nn/config/training_job_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint

@gitttt-1234 gitttt-1234 merged commit fda9404 into main Feb 13, 2025
7 checks passed
@coderabbitai coderabbitai bot mentioned this pull request Mar 10, 2025
@coderabbitai coderabbitai bot mentioned this pull request Mar 21, 2025
@gitttt-1234 gitttt-1234 deleted the greg/omegaconf-basic-func branch July 26, 2025 18:16


Development

Successfully merging this pull request may close these issues.

Implement a schema for configs

3 participants