Defining hyperparameter spaces, neural architectures, and complex configurations often means writing repetitive boilerplate, dealing with silent validation errors, and struggling to enforce best practices. SpaX is a Pydantic-based configuration framework that eliminates these pain points through declarative, type-safe search space definitions.
Built for ML experimentation but useful anywhere you need robust configuration management, SpaX catches invalid parameter combinations at definition time, enforces constraints automatically, and integrates seamlessly with HPO frameworks like Optuna. Whether you're tuning hyperparameters, exploring architectures, or managing production configs, SpaX reduces bugs and saves time.
What you get:
- Zero boilerplate - One-line migration from Pydantic, automatic space inference from type hints
- Early error detection - Invalid configurations caught at definition time, not during training
- Declarative constraints - Conditional parameters, nested configs, and polymorphic fields, all type-safe
- Seamless HPO - Direct Optuna integration with all features working automatically
- Iterative refinement - Progressive search space narrowing based on experimental results
- Full serialization - Save and load configurations in JSON, YAML, or TOML
Base installation:
pip install spax
With optional dependencies:
# YAML serialization support
pip install spax[yaml]
# TOML serialization support
pip install spax[toml]
# Optuna integration
pip install spax[optuna]
# All optional features
pip install spax[all]
Requirements:
- Python 3.11+
- Pydantic 2.7+
Define your search space once, get validation, sampling, visualization, and HPO integration automatically:
from typing import Literal
from pydantic import Field
import spax as sp
class ModelConfig(sp.Config):
# Automatic inference from type hints
optimizer: Literal["adam", "sgd", "rmsprop"]
use_scheduler: bool
use_dropout: bool
# Pydantic Field constraints work too
batch_size: int = Field(ge=16, le=128)
# Explicit SpaX spaces for full control
learning_rate: float = sp.Float(ge=1e-5, le=1e-1, distribution="log")
# Conditional spaces - parameters that depend on others
num_layers: int = sp.Conditional(
sp.FieldCondition("use_scheduler", sp.EqualsTo(True)),
true=sp.Int(ge=6, le=12), # Deep networks with scheduler
false=sp.Int(ge=2, le=6), # Shallow networks without
)
# dropout_rate only exists when use_dropout=True
dropout_rate: float = sp.Conditional(
sp.FieldCondition("use_dropout", sp.EqualsTo(True)),
true=sp.Float(ge=0.1, le=0.5),
false=0.0,
)
Visualize your search space:
print(ModelConfig.get_tree())
# ModelConfig
# ├─ optimizer: Categorical
# │  ├─ 'adam'
# │  ├─ 'sgd'
# │  └─ 'rmsprop'
# ├─ use_scheduler: Categorical
# │  ├─ True
# │  └─ False
# ├─ use_dropout: Categorical
# │  ├─ True
# │  └─ False
# ├─ batch_size: Int([16, 128], uniform)
# ├─ learning_rate: Float([1e-05, 0.1], log)
# ├─ num_layers: Conditional (if use_scheduler == True)
# │  ├─ true: Int([6, 12], uniform)
# │  └─ false: Int([2, 6], uniform)
# └─ dropout_rate: Conditional (if use_dropout == True)
#    ├─ true: Float([0.1, 0.5], uniform)
#    └─ false: 0.0
Random sampling for testing:
config = ModelConfig.random(seed=42)
print(config)
# ModelConfig(optimizer='rmsprop', use_scheduler=True, use_dropout=True, batch_size=97,
#             learning_rate=2.788e-05, num_layers=11, dropout_rate=0.155815)
Serialization (save/load in multiple formats):
yaml_str = config.model_dump_yaml()
loaded = ModelConfig.model_validate_yaml(yaml_str)
# Also: model_dump_json/toml, model_validate_json/toml
Iterative refinement (narrow the search space based on results):
config_v2 = ModelConfig.random(
seed=42,
override={
"learning_rate": {"ge": 1e-4, "le": 1e-2}, # Focus on promising region
"optimizer": "adam", # Lock to best optimizer
},
)
print(config_v2.learning_rate) # Now in [1e-4, 1e-2] range
# See the full override template as a reference:
print(ModelConfig.get_override_template())
# {
# "optimizer": ["adam", "sgd", "rmsprop"],
# "use_scheduler": ["True", "False"],
# "dropout_rate": {"true": {"ge": 0.1, "le": 0.5}},
# "batch_size": {"ge": 16, "le": 128},
# "learning_rate": {"ge": 1e-05, "le": 0.1},
# "num_layers": {"true": {"ge": 6, "le": 12}, "false": {"ge": 2, "le": 6}},
# "use_dropout": ["True", "False"],
# }
Seamless Optuna integration:
import optuna
def objective(trial: optuna.Trial) -> float:
config = ModelConfig.from_trial(trial)
# Your training logic here
score = ...
return score
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
# Retrieve best config
best = ModelConfig.from_trial(study.best_trial)
# See the parameter names (what optuna.Trial saw):
print(ModelConfig.get_parameter_names())
# [
# "ModelConfig.batch_size",
# "ModelConfig.learning_rate",
# "ModelConfig.optimizer",
# "ModelConfig.use_dropout",
# "ModelConfig.dropout_rate::true_branch",
# "ModelConfig.use_scheduler",
# "ModelConfig.num_layers::true_branch",
# "ModelConfig.num_layers::false_branch",
# ]
What this demonstrates: type-safe configs with automatic inference, conditional parameters, visualization, random sampling, serialization, iterative refinement, and one-line HPO integration - all working together seamlessly.
SpaX infers search spaces from type hints and Pydantic Field constraints with zero extra code:
import spax as sp
from typing import Literal
from pydantic import Field
class InferredConfig(sp.Config):
# Literal β CategoricalSpace
activation: Literal["relu", "gelu", "silu"]
# bool β CategoricalSpace([True, False])
use_norm: bool
# Field with bounds β NumericSpace
hidden_dim: int = Field(gt=64, lt=1024)
learning_rate: float = Field(ge=1e-5, le=1e-1)
When automatic inference isn't enough, use explicit spaces for full control:
class ExplicitConfig(sp.Config):
# Log distribution for learning rates
learning_rate: float = sp.Float(ge=1e-5, le=1e-1, distribution="log")
# Weighted categorical choices
optimizer: str = sp.Categorical(
[
sp.Choice("adam", weight=3.0), # 3x more likely
sp.Choice("sgd", weight=1.0),
sp.Choice("rmsprop", weight=1.0),
]
)
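To sanity-check the weighting, you can draw a batch of random configs and count the outcomes; a rough sketch using the `random(seed=...)` API shown above (exact counts will vary):

```python
from collections import Counter

# Sample 1000 random configs and tally which optimizer was chosen.
counts = Counter(ExplicitConfig.random(seed=i).optimizer for i in range(1000))
print(counts)  # expect 'adam' roughly 3x as frequent as 'sgd' or 'rmsprop'
```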
Define parameters that only exist or change based on other parameters. SpaX handles dependency ordering and validation automatically.
class ConditionalConfig(sp.Config):
use_augmentation: bool
optimizer: str = sp.Categorical(["adam", "sgd"])
# Only exists when use_augmentation=True
aug_strength: float = sp.Conditional(
sp.FieldCondition("use_augmentation", sp.EqualsTo(True)),
true=sp.Float(ge=0.1, le=0.9),
false=0.0,
)
# SGD-specific parameter
momentum: float = sp.Conditional(
sp.FieldCondition("optimizer", sp.EqualsTo("sgd")),
true=sp.Float(ge=0.0, le=0.99),
false=0.0,
)
Available conditions (a short example using a few of these follows the list):
- Equality: `EqualsTo`, `NotEqualsTo`
- Membership: `In`, `NotIn`
- Comparison: `LargerThan`, `SmallerThan` (with an `or_equals` parameter)
- Type checking: `IsInstance`
- Logical: `And`, `Or`, `Not`
- Custom: `Lambda`, `MultiFieldLambdaCondition`
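For example, a sketch combining a membership and a comparison condition (the constructor usage of `In` here is assumed to mirror `EqualsTo` and `LargerThan` as shown elsewhere in this README):

```python
import spax as sp

class MixedConditions(sp.Config):
    optimizer: str = sp.Categorical(["adam", "adamw", "sgd"])
    num_layers: int = sp.Int(ge=2, le=12)
    # Membership condition: weight decay only applies to the Adam family
    weight_decay: float = sp.Conditional(
        sp.FieldCondition("optimizer", sp.In(["adam", "adamw"])),
        true=sp.Float(ge=1e-6, le=1e-2, distribution="log"),
        false=0.0,
    )
    # Comparison condition: only deep models get gradient checkpointing
    use_checkpointing: bool = sp.Conditional(
        sp.FieldCondition("num_layers", sp.LargerThan(8)),
        true=True,
        false=False,
    )
```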
Complex logic with composite conditions:
class AdvancedConditional(sp.Config):
model_size: str = sp.Categorical(["small", "large"])
dataset_size: str = sp.Categorical(["small", "large"])
# Large batch only when BOTH model and dataset are large
batch_size: int = sp.Conditional(
sp.And(
[
sp.FieldCondition("model_size", sp.EqualsTo("large")),
sp.FieldCondition("dataset_size", sp.EqualsTo("large")),
]
),
true=sp.Int(ge=128, le=512),
false=sp.Int(ge=16, le=64),
)
Build complex configurations from smaller, reusable components through nesting, inheritance, and polymorphism.
Nesting - Compose configs from subconfigs:
class OptimizerConfig(sp.Config):
name: str = sp.Categorical(["adam", "sgd"])
learning_rate: float = sp.Float(ge=1e-5, le=1e-2, distribution="log")
class ModelConfig(sp.Config):
num_layers: int = sp.Int(ge=2, le=12)
hidden_dim: int = sp.Int(ge=128, le=512)
class ExperimentConfig(sp.Config):
model: ModelConfig
optimizer: OptimizerConfig
batch_size: int = sp.Int(ge=16, le=128)
# Access nested fields naturally
config = ExperimentConfig.random(seed=42)
print(config.optimizer.learning_rate)
Inheritance - Create specialized variants:
class BaseModel(sp.Config):
num_layers: int = sp.Int(ge=1, le=12)
hidden_dim: int = sp.Int(ge=64, le=512)
class ResNet(BaseModel):
# Add ResNet-specific parameters
use_bottleneck: bool
stride: int = sp.Categorical([1, 2])
# Override parent's hidden_dim with different range
hidden_dim: int = sp.Int(ge=128, le=2048)
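A quick check of the inherited and overridden ranges, using the random sampling API from the quickstart (a sketch; the assertions simply restate the space definition above):

```python
resnet = ResNet.random(seed=0)
# num_layers comes from BaseModel; hidden_dim uses ResNet's overridden range.
assert 1 <= resnet.num_layers <= 12
assert 128 <= resnet.hidden_dim <= 2048
print(resnet.use_bottleneck, resnet.stride)
```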
Polymorphism - Union types for flexible architectures:
class CNNEncoder(sp.Config):
num_conv_layers: int = sp.Int(ge=2, le=8)
kernel_size: int = sp.Categorical([3, 5, 7])
class TransformerEncoder(sp.Config):
num_layers: int = sp.Int(ge=2, le=12)
num_heads: int = sp.Int(ge=4, le=16)
class FlexibleModel(sp.Config):
# Can be either CNN or Transformer!
encoder: CNNEncoder | TransformerEncoder
output_dim: int = sp.Int(ge=10, le=1000)
# SpaX handles type discrimination automatically
config = FlexibleModel.random(seed=42)
if isinstance(config.encoder, TransformerEncoder):
print(f"Transformer with {config.encoder.num_heads} heads")Conditional logic on nested fields using dotted paths:
Conditional logic on nested fields using dotted paths:
class DeepConfig(sp.Config):
model: ModelConfig
# Condition on nested field
use_gradient_checkpointing: bool = sp.Conditional(
sp.FieldCondition("model.num_layers", sp.LargerThan(8)),
true=True,
false=False,
)
One-line integration with Optuna. All SpaX features work automatically:
import optuna
import spax as sp
class HPOConfig(sp.Config):
learning_rate: float = sp.Float(ge=1e-5, le=1e-1, distribution="log")
batch_size: int = sp.Int(ge=16, le=128)
num_layers: int = sp.Int(ge=2, le=12)
def objective(trial: optuna.Trial) -> float:
# One line - that's it!
config = HPOConfig.from_trial(trial)
# Your training code
model = create_model(config)
score = train_and_evaluate(model, config)
return score
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
# Get the best configuration
best_config = HPOConfig.from_trial(study.best_trial)
All SpaX features work automatically with Optuna:
- ✅ Conditional parameters - dependencies handled correctly
- ✅ Nested configs - hierarchical parameter naming prevents conflicts
- ✅ Polymorphic fields - different config types explored automatically
- ✅ Log distributions - passed through to Optuna's samplers (see the short sketch below)
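As a compact illustration of the conditional and log-distribution points, here is a self-contained sketch (`HPOSearch` and the placeholder objective are illustrative names, not part of SpaX):

```python
import optuna
import spax as sp

class HPOSearch(sp.Config):
    use_scheduler: bool
    learning_rate: float = sp.Float(ge=1e-5, le=1e-1, distribution="log")
    # Only searched when a scheduler is enabled
    warmup_steps: int = sp.Conditional(
        sp.FieldCondition("use_scheduler", sp.EqualsTo(True)),
        true=sp.Int(ge=100, le=1000),
        false=0,
    )

def objective(trial: optuna.Trial) -> float:
    config = HPOSearch.from_trial(trial)
    return -config.learning_rate  # placeholder score; plug in real training here

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
print(HPOSearch.get_parameter_names())  # conditional branches show up as separate parameters
```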
Custom samplers via the Sampler interface:
from typing import Any, Literal
from spax.samplers import Sampler
class CustomSampler(Sampler):
def suggest_int(
self,
name: str,
low: int,
high: int,
low_inclusive: bool,
high_inclusive: bool,
distribution: Literal["log", "uniform"],
) -> int:
# Your custom logic
return ...
def suggest_float(
self,
name: str,
low: float,
high: float,
low_inclusive: bool,
high_inclusive: bool,
distribution: Literal["log", "uniform"],
) -> float:
return ...
def suggest_categorical(
self, name: str, choices: list[Any], weights: list[float]
) -> Any:
return ...
config = MyConfig.sample(CustomSampler(), override=...)
Narrow search spaces progressively based on experimental results without modifying your config definition.
class SearchConfig(sp.Config):
learning_rate: float = sp.Float(ge=1e-5, le=1e-1, distribution="log")
num_layers: int = sp.Int(ge=2, le=12)
optimizer: str = sp.Categorical(["adam", "sgd", "rmsprop"])
# Initial broad search
for i in range(100):
config = SearchConfig.random()
score = train(config)
# ... track results
# After analysis: focus on promising regions
override = {
"learning_rate": {"ge": 1e-4, "le": 1e-2}, # Narrow range
"optimizer": "adam", # Fix to best
# num_layers untouched - keep exploring
}
# Refined search
for i in range(100):
config = SearchConfig.random(override=override)
score = train(config)
# Works with Optuna too
config = SearchConfig.from_trial(trial, override=override)
Get the override template as a reference:
template = SearchConfig.get_override_template()
# Save to file for manual editing
import json
with open("override.json", "w") as f:
json.dump(template, f, indent=2)
Save and load configurations in multiple formats. Nested and polymorphic configs are handled automatically with type discriminators:
config = MyConfig.random(seed=42)
# Save
json_str = config.model_dump_json()
yaml_str = config.model_dump_yaml() # requires PyYAML
toml_str = config.model_dump_toml() # requires tomli-w
# Load
loaded = MyConfig.model_validate_json(json_str)
loaded = MyConfig.model_validate_yaml(yaml_str)
loaded = MyConfig.model_validate_toml(toml_str)
# Works with files too
with open("config.yaml", "w") as f:
f.write(config.model_dump_yaml())
with open("config.yaml") as f:
config = MyConfig.model_validate_yaml(f.read())
Implement the Sampler interface for custom optimization algorithms or integration with other HPO frameworks:
from typing import Any, Literal
from spax.samplers import Sampler
class GridSampler(Sampler):
"""Example: Simple grid search sampler."""
def __init__(self, grid_points: dict) -> None:
self.grid_points = grid_points
self.current_idx = 0
self._record = {}
@property
def record(self) -> dict:
return self._record.copy()
def suggest_int(
self,
name: str,
low: int,
high: int,
low_inclusive: bool,
high_inclusive: bool,
distribution: Literal["log", "uniform"],
) -> int:
values = self.grid_points.get(name, [low, high])
value = values[self.current_idx % len(values)]
self._record[name] = value
return value
def suggest_float(
self,
name: str,
low: float,
high: float,
low_inclusive: bool,
high_inclusive: bool,
distribution: Literal["log", "uniform"],
) -> float:
values = self.grid_points.get(name, [low, high])
value = values[self.current_idx % len(values)]
self._record[name] = value
return value
def suggest_categorical(
self, name: str, choices: list[Any], weights: list[float]
) -> Any:
value = choices[self.current_idx % len(choices)]
self._record[name] = value
return value
# Use custom sampler
grid = {
"learning_rate": [1e-4, 1e-3, 1e-2],
"batch_size": [16, 32, 64],
}
sampler = GridSampler(grid)
config = MyConfig.sample(sampler)
# Note: this example never advances current_idx; increment it between samples
# (e.g. sampler.current_idx += 1) to walk the grid.
For complex dependencies that can't be expressed with simple conditions:
class ResourceConfig(sp.Config):
num_gpus: int = sp.Int(ge=1, le=8)
batch_size_per_gpu: int = sp.Int(ge=8, le=128)
# Condition: total batch size must be reasonable
use_gradient_accumulation: bool = sp.Conditional(
sp.MultiFieldLambdaCondition(
["num_gpus", "batch_size_per_gpu"],
lambda data: data["num_gpus"] * data["batch_size_per_gpu"] > 256,
),
true=True,
false=False,
)
accumulation_steps: int = sp.Conditional(
sp.FieldCondition("use_gradient_accumulation", sp.EqualsTo(True)),
true=sp.Int(ge=2, le=8),
false=1,
)
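Sampling resolves the lambda against the already-sampled fields; a quick sketch (the assertions simply restate the definition above):

```python
cfg = ResourceConfig.random(seed=0)
total = cfg.num_gpus * cfg.batch_size_per_gpu
# The multi-field condition decides the flag, and the flag decides the steps.
assert cfg.use_gradient_accumulation == (total > 256)
if cfg.use_gradient_accumulation:
    assert 2 <= cfg.accumulation_steps <= 8
else:
    assert cfg.accumulation_steps == 1
```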
Dotted paths for nested fields:
class AdvancedConfig(sp.Config):
model: ModelConfig
optimizer: OptimizerConfig
# Custom logic across nested fields
use_mixed_precision: bool = sp.Conditional(
sp.MultiFieldLambdaCondition(
["model.num_layers", "optimizer.learning_rate"],
lambda data: (
data["model.num_layers"] > 6
and data["optimizer.learning_rate"] < 1e-3
),
),
true=True,
false=False,
)
Configs can be nested arbitrarily deep for complex systems:
class AttentionConfig(sp.Config):
num_heads: int = sp.Int(ge=1, le=16)
head_dim: int = sp.Int(ge=32, le=128)
class EncoderLayerConfig(sp.Config):
attention: AttentionConfig
feedforward_dim: int = sp.Int(ge=256, le=4096)
class ModelConfig(sp.Config):
encoder: EncoderLayerConfig
num_layers: int = sp.Int(ge=2, le=12)
# Condition on deeply nested field
use_gradient_checkpointing: bool = sp.Conditional(
sp.FieldCondition("encoder.attention.num_heads", sp.LargerThan(8)),
true=True,
false=False,
)
# Access 3 levels deep
config = ModelConfig.random(seed=42)
print(config.encoder.attention.num_heads)
Define absolute bounds once, refine via overrides:
# config_space.py - Define the absolute search space once
class TrainingConfig(sp.Config):
learning_rate: float = sp.Float(ge=1e-6, le=1e-1, distribution="log")
batch_size: int = sp.Int(ge=8, le=512)
num_layers: int = sp.Int(ge=1, le=20)
# ... more parameters
# DON'T modify the space definition for experiments
# Instead, use overrides to narrow ranges
# experiments/phase1_broad.json
{
"learning_rate": {"ge": 1e-5, "le": 1e-2},
"num_layers": {"ge": 2, "le": 12}
}
# experiments/phase2_refined.json
{
"learning_rate": {"ge": 1e-4, "le": 5e-4},
"num_layers": 6,
"batch_size": {"ge": 32, "le": 128}
}
# Load override and use
with open("experiments/phase2_refined.json") as f:
override = json.load(f)
config = TrainingConfig.random(seed=42, override=override)
# Or with Optuna
config = TrainingConfig.from_trial(trial, override=override)
Version your configs:
# Generate deterministic hash of search space
space_hash = TrainingConfig.get_space_hash()
# This hash changes when the search space structure changes
print(f"Search space version: {space_hash[:8]}")
# Include in experiment metadata
new_space_hash = TrainingConfig.get_space_hash(override={...})
print(f"New search space version: {new_space_hash[:8]}")Validate before expensive operations:
import logging
logger = logging.getLogger(__name__)
try:
MyConfig.get_tree(override)
except Exception as e:
# Log failure
logger.error(f"Invalid override: {e}")Jupyter notebooks demonstrating SpaX features with runnable code and explanations:
| Notebook | Description |
|---|---|
| 00 - Quickstart | One-line migration from Pydantic, automatic inference, random sampling, and basic override system |
| 01 - Conditional Parameters | Simple and composite conditions, nested field paths, multi-field lambda conditions |
| 02 - Nested & Modular Configs | Nesting, inheritance, polymorphic fields, and deep hierarchies |
| 03 - Serialization | JSON/YAML/TOML serialization, handling nested configs, error handling, reproducibility workflow |
| 04 - HPO with Optuna | Seamless Optuna integration, conditionals with HPO, nested configs in optimization, complete workflow |
Each notebook is self-contained with explanations, runnable code, real-world use cases, and best practices.
See the examples README for setup instructions, recommended learning paths, and more details.
Experiment Tracking & Visualization:
- Dedicated experiment tracking API
- Rich visualization suite for search space exploration
- Parameter correlation heatmaps and dominance charts
- Score vs parameter value plots
- Automatic search space pruning based on results
Enhanced HPO Features:
- Built-in random search algorithm optimized for large spaces
- Parallel experiment execution API
- Multi-objective optimization support
- Distributed workload management
- Ray Tune integration
- PyTorch Lightning integration for streamlined training
- Comprehensive documentation (Read the Docs)
- Automatic hyperparameter importance analysis
- More search space visualization tools
- Config diff and comparison utilities
- Template library for common ML workflows
- Enhanced error messages and debugging tools
Have a feature request or use case we should prioritize? Open an issue or join the discussion!
Contributions are welcome! SpaX is in active development and there are many ways to help:
Report bugs or request features:
- Open an issue with a clear description
- Include minimal reproducible examples for bugs
- Search existing issues to avoid duplicates
Contribute code:
- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Make your changes with tests and documentation
- Run the test suite: `pytest`
- Check code quality: `black .` and `ruff check .`
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to your branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
Development setup:
# Clone the repository
git clone https://github.com/keyhankamyar/SpaX.git
cd SpaX
# Install in development mode with all dependencies
pip install -e ".[dev,all]"
# Run tests
pytest
# Check code quality
black --check .
ruff check .
mypy spax
Areas where help is appreciated:
- Additional examples and tutorials
- Documentation
- Integration with other HPO frameworks
- Bug reports and fixes
Code guidelines:
- Follow existing code style (Black + Ruff)
- Add type hints for all functions
- Write tests for new features (aim for >90% coverage)
- Update documentation for user-facing changes
- Keep PRs focused and atomic
See CONTRIBUTING.md for more information.
If you use SpaX in your research or projects, please cite it:
@software{spax2025,
author = {Kamyar, Keyhan},
title = {SpaX: Declarative search-space definition & exploration},
year = {2025},
version = {0.2.0},
url = {https://github.com/keyhankamyar/SpaX}
}
See CITATION.cff for more citation formats.
SpaX is released under the MIT License.