diff --git a/orca_python/model_selection/__init__.py b/orca_python/model_selection/__init__.py new file mode 100644 index 0000000..f4e5069 --- /dev/null +++ b/orca_python/model_selection/__init__.py @@ -0,0 +1,18 @@ +"""Model selection and estimator loading utilities.""" + +from .loaders import get_classifier_by_name, load_classifier +from .validation import ( + check_for_random_state, + is_ensemble, + is_searchcv, + prepare_param_grid, +) + +__all__ = [ + "get_classifier_by_name", + "load_classifier", + "check_for_random_state", + "is_ensemble", + "is_searchcv", + "prepare_param_grid", +] diff --git a/orca_python/model_selection/loaders.py b/orca_python/model_selection/loaders.py new file mode 100644 index 0000000..7dc0568 --- /dev/null +++ b/orca_python/model_selection/loaders.py @@ -0,0 +1,158 @@ +"""Model selection and estimator loading utilities.""" + +from importlib import import_module + +from sklearn.model_selection import GridSearchCV, StratifiedKFold + +from orca_python.metrics.utils import load_metric_as_scorer +from orca_python.model_selection.validation import ( + is_searchcv, + prepare_param_grid, +) + +_ORCA_CLASSIFIERS = { + "NNOP": "orca_python.classifiers.NNOP", + "NNPOM": "orca_python.classifiers.NNPOM", + "OrdinalDecomposition": "orca_python.classifiers.OrdinalDecomposition", + "REDSVM": "orca_python.classifiers.REDSVM", + "SVOREX": "orca_python.classifiers.SVOREX", +} + +_SKLEARN_CLASSIFIERS = { + "SVC": "sklearn.svm.SVC", + "LogisticRegression": "sklearn.linear_model.LogisticRegression", + "RandomForestClassifier": "sklearn.ensemble.RandomForestClassifier", +} + +_CLASSIFIERS = {**_ORCA_CLASSIFIERS, **_SKLEARN_CLASSIFIERS} + + +def get_classifier_by_name(classifier_name): + """Return a classifier not instantiated matching a given input name. + + Parameters + ---------- + classifier_name : str + Name of the classification algorithm being employed. + + Returns + ------- + classifier : object + Returns a classifier, either from a scikit-learn module, or from a + module of this framework. + + Raises + ------ + ValueError + If an unknown classifier name is provided. + + Examples + -------- + >>> get_classifier_by_name("SVOREX") + + >>> get_classifier_by_name("REDSVM") + + >>> get_classifier_by_name("SVC") + + + """ + if classifier_name not in _CLASSIFIERS: + raise ValueError(f"Unknown classifier '{classifier_name}'.") + + module_path, class_name = _CLASSIFIERS[classifier_name].rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) + + +def load_classifier( + classifier_name, + random_state=None, + n_jobs=1, + cv_n_folds=3, + cv_metric="mae", + param_grid=None, +): + """Return a fully configured classifier, optionally with cross-validation. + + This function loads a classifier, configures its parameters, and optionally + sets up cross-validation if multiple parameter values are provided. + + Parameters + ---------- + classifier_name : str + Name of the classification algorithm being employed. + + random_state : int, RandomState instance or None, optional (default=None) + Seed for reproducible randomization in model training and probability + estimation. + + n_jobs : int, optional (default=1) + Number of parallel processing cores for computational tasks. + + cv_n_folds : int, optional (default=3) + Number of folds for cross-validation (only used if applicable). + + cv_metric : str or callable, optional (default="mae") + Evaluation metric for cross-validation performance assessment. + + param_grid : dict or None, optional (default=None) + Hyperparameter grid. If multiple values are given, cross-validation will be applied. + + Returns + ------- + classifier : object + The initialized classifier object, optionally wrapped in GridSearchCV. + + Raises + ------ + ValueError + If an unknown classifier name is provided or if an invalid parameter + is specified for the classifier. + + Examples + -------- + >>> from orca_python.model_selection import load_classifier + >>> clf = load_classifier("SVC", random_state=0) + >>> clf + SVC() + >>> clf_cv = load_classifier("SVC", random_state=0, param_grid={"C": [0.1, 1.0]}) + >>> clf_cv.__class__.__name__ + 'GridSearchCV' + + """ + classifier_cls = get_classifier_by_name(classifier_name) + + if param_grid is None: + return classifier_cls() + + param_grid = prepare_param_grid(classifier_cls, param_grid, random_state) + + if is_searchcv(param_grid): + scorer = ( + load_metric_as_scorer(cv_metric) + if isinstance(cv_metric, str) + else cv_metric + ) + cv = StratifiedKFold( + n_splits=cv_n_folds, shuffle=True, random_state=random_state + ) + + return GridSearchCV( + estimator=classifier_cls(), + param_grid=param_grid, + scoring=scorer, + n_jobs=n_jobs, + cv=cv, + error_score="raise", + ) + + try: + classifier = classifier_cls(**param_grid) + classifier.assigned_params_ = param_grid + return classifier + except TypeError as e: + invalid_param = str(e).split("'")[1] + raise ValueError( + f"Invalid parameter '{invalid_param}' for classifier" + f" '{classifier_name}'." + ) diff --git a/orca_python/model_selection/tests/__init__.py b/orca_python/model_selection/tests/__init__.py new file mode 100644 index 0000000..5d589f0 --- /dev/null +++ b/orca_python/model_selection/tests/__init__.py @@ -0,0 +1,3 @@ +"""Tests for model selection module.""" + +__all__ = [] diff --git a/orca_python/model_selection/tests/test_loaders.py b/orca_python/model_selection/tests/test_loaders.py new file mode 100644 index 0000000..3c91a64 --- /dev/null +++ b/orca_python/model_selection/tests/test_loaders.py @@ -0,0 +1,123 @@ +"Tests for model selection and estimator loading utilities." + +import pytest +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import GridSearchCV +from sklearn.svm import SVC + +from orca_python.classifiers import NNOP, NNPOM, REDSVM, SVOREX, OrdinalDecomposition +from orca_python.metrics import load_metric_as_scorer +from orca_python.model_selection import get_classifier_by_name, load_classifier +from orca_python.testing import TEST_RANDOM_STATE + + +def test_get_classifier_by_name_correct(): + """Test that get_classifier_by_name returns the correct classifier.""" + # ORCA classifiers + assert get_classifier_by_name("NNOP") == NNOP + assert get_classifier_by_name("NNPOM") == NNPOM + assert get_classifier_by_name("OrdinalDecomposition") == OrdinalDecomposition + assert get_classifier_by_name("REDSVM") == REDSVM + assert get_classifier_by_name("SVOREX") == SVOREX + + # Scikit-learn classifiers + assert get_classifier_by_name("SVC") == SVC + assert get_classifier_by_name("LogisticRegression") == LogisticRegression + + +def test_get_classifier_by_name_incorrect(): + """Test that get_classifier_by_name raises ValueError for unknown classifiers.""" + with pytest.raises(ValueError, match="Unknown classifier 'RandomForest'"): + get_classifier_by_name("RandomForest") + + with pytest.raises(ValueError, match="Unknown classifier 'SVR'"): + get_classifier_by_name("SVR") + + +def test_load_classifier_without_parameters(): + """Test that load_classifier correctly instantiates classifiers without + parameters.""" + assert isinstance(load_classifier("NNOP"), NNOP) + assert isinstance(load_classifier("NNPOM"), NNPOM) + assert isinstance(load_classifier("OrdinalDecomposition"), OrdinalDecomposition) + assert isinstance(load_classifier("REDSVM"), REDSVM) + assert isinstance(load_classifier("SVOREX"), SVOREX) + assert isinstance(load_classifier("SVC"), SVC) + assert isinstance(load_classifier("LogisticRegression"), LogisticRegression) + + +def test_load_classifier_with_parameters(): + """Test that load_classifier correctly instantiates classifiers with + parameters.""" + param_grid = { + "epsilon_init": 0.5, + "n_hidden": 10, + "max_iter": 500, + "lambda_value": 0.01, + } + classifier = load_classifier("NNPOM", param_grid=param_grid) + assert isinstance(classifier, NNPOM) + assert classifier.epsilon_init == 0.5 + assert classifier.n_hidden == 10 + assert classifier.max_iter == 500 + assert classifier.lambda_value == 0.01 + + +def test_load_classifier_with_searchcv(): + """Test that load_classifier correctly returns a GridSearchCV when param_grid has multiple values.""" + param_grid = {"C": [0.1, 1.0], "probability": "True"} + + classifier = load_classifier( + "SVC", + param_grid=param_grid, + random_state=TEST_RANDOM_STATE, + cv_n_folds=5, + cv_metric="mae", + n_jobs=1, + ) + + expected_param_grid = { + "C": [0.1, 1.0], + "probability": ["True"], + "random_state": [TEST_RANDOM_STATE], + } + + assert isinstance(classifier, GridSearchCV) + assert classifier.cv.n_splits == 5 + assert classifier.param_grid == expected_param_grid + + +def test_load_classifier_with_ensemble_method(): + """Test that load_classifier correctly handles ensemble methods.""" + param_grid = { + "dtype": "ordered_partitions", + "decision_method": "frank_hall", + "base_classifier": "SVC", + "parameters": { + "C": [0.01, 0.1, 1, 10], + "gamma": [0.01, 0.1, 1, 10], + "probability": ["True"], + }, + } + classifier = load_classifier( + classifier_name="OrdinalDecomposition", + param_grid=param_grid, + n_jobs=10, + cv_n_folds=3, + cv_metric=load_metric_as_scorer("mae"), + random_state=TEST_RANDOM_STATE, + ) + assert isinstance(classifier, GridSearchCV) + assert classifier.param_grid["decision_method"] == [param_grid["decision_method"]] + assert classifier.param_grid["base_classifier"] == [param_grid["base_classifier"]] + for params in classifier.param_grid["parameters"]: + assert params["random_state"] == TEST_RANDOM_STATE + assert classifier.cv.n_splits == 3 + + +def test_load_classifier_with_invalid_param(): + """Test that load_classifier raises error with invalid parameter key.""" + error_msg = "Invalid parameter 'T' for classifier 'SVC'." + + with pytest.raises(ValueError, match=error_msg): + load_classifier(classifier_name="SVC", param_grid={"T": 0.1}) diff --git a/orca_python/model_selection/tests/test_validation.py b/orca_python/model_selection/tests/test_validation.py new file mode 100644 index 0000000..ac40192 --- /dev/null +++ b/orca_python/model_selection/tests/test_validation.py @@ -0,0 +1,337 @@ +"""Tests for parameter grid preparation and validation utilities.""" + +import pytest +from sklearn.linear_model import LogisticRegression +from sklearn.svm import SVC + +from orca_python.classifiers import ( + NNOP, + NNPOM, + REDSVM, + SVOREX, + OrdinalDecomposition, +) +from orca_python.model_selection.validation import ( + _add_random_state, + _normalize_param_grid, + _prepare_parameters_for_ensemble, + check_for_random_state, + is_ensemble, + is_searchcv, + prepare_param_grid, +) +from orca_python.testing import TEST_RANDOM_STATE + + +@pytest.mark.parametrize( + "estimator, expected", + [ + (None, False), + (NNOP, False), + (NNPOM, False), + (OrdinalDecomposition, False), + (REDSVM, False), + (SVOREX, False), + (SVC, True), + (LogisticRegression, True), + ], +) +def test_check_for_random_state(estimator, expected): + """Test that check_for_random_state correctly identifies classifiers that + use random state.""" + assert check_for_random_state(estimator) is expected + + +@pytest.mark.parametrize( + "param_grid, expected", + [ + ({"C": 0.1}, False), + ({"C": [0.1, 1]}, False), + ({"base_classifier": "SVC"}, True), + ( + { + "base_classifier": "SVC", + "parameters": {"C": [1], "gamma": [0.1]}, + }, + True, + ), + ], +) +def test_is_ensemble(param_grid, expected): + """Test that is_ensemble correctly identifies ensemble methods with + various parameter configurations.""" + assert is_ensemble(param_grid) is expected + + +@pytest.mark.parametrize( + "param_grid, expected", + [ + ({"C": 0.1}, False), + ({"C": [0.1, 1]}, True), + ({"C": [0.1], "kernel": "linear"}, False), + ({"kernel_type": 0, "c": [0.1, 1], "k": [0.1, 1], "t": 0.001}, True), + ( + { + "base_classifier": "SVC", + "parameters": {"C": 0.01, "gamma": 0.01}, + }, + False, + ), + ( + { + "base_classifier": "SVC", + "parameters": {"C": [0.01, 1], "gamma": [0.01, 1]}, + }, + True, + ), + ( + { + "base_classifier": "SVC", + "dtype": ["ordered", "unordered"], + "parameters": None, + }, + True, + ), + ], +) +def test_is_searchcv(param_grid, expected): + """Test that is_searchcv correctly identifies cross-validation cases with + various parameter configurations.""" + assert is_searchcv(param_grid) is expected + + +def test_is_searchcv_invalid_input(): + """Test that is_searchcv returns False when param_grid is not a dictionary.""" + assert not is_searchcv(None) + assert not is_searchcv("not a dict") + assert not is_searchcv(["param", "values"]) + + +@pytest.mark.parametrize( + "param_grid, expected", + [ + ( + { + "base_classifier": "SVC", + "parameters": [{"C": 1, "gamma": 0.1}, {"C": 10, "gamma": 1.0}], + }, + True, + ), + ( + { + "base_classifier": "SVC", + "parameters": [{"C": 1, "gamma": 0.1}], + }, + False, + ), + ( + { + "base_classifier": "SVC", + "parameters": None, + "dtype": ["ordered", "unordered"], + }, + True, + ), + ( + { + "base_classifier": "SVC", + "parameters": [{"C": 1}], + "dtype": ["ordered", "unordered"], + }, + True, + ), + ], +) +def test_is_searchcv_with_ensemble_variants(param_grid, expected): + """Test that is_searchcv detects cross-validation when ensemble variants are used.""" + assert is_searchcv(param_grid) is expected + + +def test_prepare_param_grid_no_cv(): + """Test that prepare_param_grid correctly handles non-cross-validation case.""" + param_grid = {"C": 0.1} + result = prepare_param_grid(SVC, param_grid, TEST_RANDOM_STATE) + assert result == {"C": 0.1, "random_state": TEST_RANDOM_STATE} + + +def test_prepare_param_grid_with_cv(): + """Test that prepare_param_grid correctly handles cross-validation case.""" + param_grid = {"C": [0.1, 1.0]} + result = prepare_param_grid(SVC, param_grid, TEST_RANDOM_STATE) + assert result == {"C": [0.1, 1.0], "random_state": [TEST_RANDOM_STATE]} + + +def test_prepare_parame_grid_mixed(): + """Test that prepare_param_grid correctly handles mixed single and multiple + parameters.""" + param_grid = {"C": [0.1, 1.0], "gamma": 0.1} + result = prepare_param_grid(SVC, param_grid, TEST_RANDOM_STATE) + assert result == { + "C": [0.1, 1.0], + "gamma": [0.1], + "random_state": [TEST_RANDOM_STATE], + } + + +def test_prepare_param_grid_without_random_state(): + """Test that prepare_param_grid correctly handles cases without specifying a + random seed.""" + param_grid = {"C": [0.1, 1.0]} + result = prepare_param_grid(SVC, param_grid) + assert "random_state" in result + + +def test_prepare_param_grid_simple_ensemble(): + """Test that prepare_param_grid correctly handles simple ensemble methods.""" + param_grid = { + "dtype": "OrderedPartitions", + "base_classifier": "SVC", + "parameters": {"C": [1], "gamma": [1]}, + } + result = prepare_param_grid(OrdinalDecomposition, param_grid, TEST_RANDOM_STATE) + assert result["parameters"] == { + "C": 1, + "gamma": 1, + "random_state": TEST_RANDOM_STATE, + } + + +def test_prepare_param_grid_cv_ensemble(): + """Test that prepare_param_grid correctly handles ensemble methods with + cross-validation for base_classifier.""" + param_grid = { + "dtype": "OrderedPartitions", + "base_classifier": "SVC", + "parameters": {"C": [1, 10], "gamma": [1, 10], "probability": ["True"]}, + } + prepared_params = prepare_param_grid( + OrdinalDecomposition, param_grid, TEST_RANDOM_STATE + ) + expected_params = { + "dtype": ["OrderedPartitions"], + "base_classifier": ["SVC"], + "parameters": [ + { + "C": 1, + "gamma": 1, + "probability": True, + "random_state": TEST_RANDOM_STATE, + }, + { + "C": 1, + "gamma": 10, + "probability": True, + "random_state": TEST_RANDOM_STATE, + }, + { + "C": 10, + "gamma": 1, + "probability": True, + "random_state": TEST_RANDOM_STATE, + }, + { + "C": 10, + "gamma": 10, + "probability": True, + "random_state": TEST_RANDOM_STATE, + }, + ], + } + assert prepared_params == expected_params + + +def test_prepare_param_grid_invalid_input(): + """Test that prepare_param_grid raises error with invalid input.""" + with pytest.raises(ValueError, match="param_grid must be a dictionary"): + prepare_param_grid(None, "not a dict") + + +def test_add_random_state(): + """Test that _add_random_state adds random_state if missing.""" + param_grid = {"C": 1.0} + updated = _add_random_state(SVC, param_grid.copy(), TEST_RANDOM_STATE) + assert updated["random_state"] == TEST_RANDOM_STATE + + param_grid = {"C": 1.0, "random_state": 999} + updated = _add_random_state(SVC, param_grid.copy(), TEST_RANDOM_STATE) + assert updated["random_state"] == 999 + + param_grid = {"C": 1.0} + updated = _add_random_state(NNOP, param_grid.copy(), TEST_RANDOM_STATE) + assert "random_state" not in updated + + +def test_normalize_param_grid(): + """Test that _normalize_param_grid wraps all scalar values into lists.""" + param_grid = {"C": 1.0, "kernel": "linear"} + normalized = _normalize_param_grid(param_grid) + assert normalized == {"C": [1.0], "kernel": ["linear"]} + + param_grid = {"C": [0.1, 1.0], "gamma": [0.01]} + normalized = _normalize_param_grid(param_grid) + assert normalized == param_grid + + +def test_prepare_parameters_for_ensemble(): + """Test that _prepare_parameters_for_ensemble correctly prepares parameters + for ensemble methods.""" + param_grid = { + "base_classifier": "SVC", + "parameters": {"C": [1, 10], "gamma": [0.1, 1]}, + } + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + assert isinstance(result["parameters"], list) + assert len(result["parameters"]) == 4 + assert all("random_state" in d for d in result["parameters"]) + + param_grid = {"C": [1, 10], "gamma": [0.1, 1]} + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + assert result == param_grid + + +def test_prepare_parameters_for_ensemble_adds_random_state(): + """Test that _prepare_parameters_for_ensemble adds random_state if supported.""" + param_grid = { + "base_classifier": "SVC", + "parameters": {"C": [1], "gamma": [0.1]}, + } + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + + assert all("random_state" in d for d in result["parameters"]) + assert all(d["random_state"] == TEST_RANDOM_STATE for d in result["parameters"]) + + param_grid = { + "base_classifier": "LogisticRegression", + "parameters": {"C": [1.0], "penalty": ["l2"]}, + } + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + + assert isinstance(result["parameters"], list) + for params in result["parameters"]: + assert "random_state" in params + assert params["random_state"] == TEST_RANDOM_STATE + + param_grid = { + "base_classifier": "SVOREX", + "parameters": {"C": ["0.1"], "kappa": [0.001]}, + } + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + + assert all("random_state" not in d for d in result["parameters"]) + + +def test_prepare_parameters_for_ensemble_literal_eval_fallback(): + """Test that _prepare_parameters_for_ensemble ignores ValueError in literal_eval + and leaves strings.""" + param_grid = { + "base_classifier": "SVC", + "parameters": { + "C": [1], + "gamma": [0.1], + "kernel": ["linear"], + }, + } + + result = _prepare_parameters_for_ensemble(param_grid, TEST_RANDOM_STATE) + assert isinstance(result["parameters"][0]["kernel"], str) + assert result["parameters"][0]["kernel"] == "linear" diff --git a/orca_python/model_selection/validation.py b/orca_python/model_selection/validation.py new file mode 100644 index 0000000..c225c8f --- /dev/null +++ b/orca_python/model_selection/validation.py @@ -0,0 +1,285 @@ +"""Parameter grid preparation and validation utilities for model selection.""" + +from ast import literal_eval +from copy import deepcopy +from itertools import product + +import numpy as np + + +def check_for_random_state(estimator): + """Check if the estimator accepts a random_state parameter. + + Parameters + ---------- + estimator : object + The estimator class to check. + + Returns + ------- + bool + True if the estimator accepts random_state parameter, False otherwise. + + Examples + -------- + >>> from sklearn.svm import SVC, SVR + >>> from sklearn.linear_model import LogisticRegression + >>> check_for_random_state(SVC) + True + >>> check_for_random_state(SVR) + False + >>> check_for_random_state(LogisticRegression) + True + + """ + try: + estimator(random_state=0) + return True + except (TypeError, ValueError): + return False + + +def is_ensemble(param_grid): + """Check if the given parameters correspond to an ensemble method. + + Parameters + ---------- + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + Returns + ------- + bool + True if the parameters correspond to an ensemble method, False + otherwise. + + Examples + -------- + >>> is_ensemble({"base_classifier": "SVC"}) + True + >>> is_ensemble({"base_classifier": "SVC", "parameters": {"C": [0.1, 0.2]}}) + True + >>> is_ensemble({"C": 0.1}) + False + + """ + return isinstance(param_grid, dict) and "base_classifier" in param_grid + + +def is_searchcv(param_grid): + """Check if the given parameters require cross-validation search. + + Parameters + ---------- + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + Returns + ------- + bool + True if cross-validation is required, False otherwise. + + """ + if not isinstance(param_grid, dict): + return False + + has_search_params = any( + isinstance(value, list) and len(value) > 1 for value in param_grid.values() + ) + + if is_ensemble(param_grid) and "parameters" in param_grid: + base_params = param_grid["parameters"] + + if isinstance(base_params, dict): + base_has_search = any( + isinstance(value, list) and len(value) > 1 + for value in base_params.values() + ) + return has_search_params or base_has_search + else: + if isinstance(base_params, list): + return has_search_params or len(base_params) > 1 + else: + return has_search_params + + return has_search_params + + +def prepare_param_grid(estimator, param_grid, random_state=None): + """This function processes parameter grids to ensure compatibility with + scikit-learn's GridSearchCV and handles special cases like ensemble methods and + random state injection. + + Parameters + ---------- + estimator : object + The estimator to prepare parameters for. + + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + random_state : int, RandomState instance or None, optional (default=None) + Seed for reproducible randomization in model training and probability + estimation. + + Returns + ------- + dict + Prepared parameters dictionary. + + Raises + ------ + ValueError + If param_grid is not a dictionary. + + Examples + -------- + >>> from orca_python.model_selection import get_classifier_by_name + >>> estimator = get_classifier_by_name("OrdinalDecomposition") + >>> param_grid = { + ... "dtype": "ordered_partitions", + ... "decision_method": "frank_hall", + ... "base_classifier": "SVC", + ... "parameters": { + ... "C": [0.1, 1.0], + ... "gamma": [0.01, 0.1], + ... "probability": ["True"] + ... } + ... } + >>> prepared_params = prepare_param_grid(estimator, param_grid, random_state=0) + >>> len(prepared_params["parameters"]) + 4 + + """ + if not isinstance(param_grid, dict): + raise ValueError("param_grid must be a dictionary") + + if random_state is None: + random_state = np.random.get_state()[1][0] + + param_grid = deepcopy(param_grid) + param_grid = _add_random_state(estimator, param_grid, random_state) + + if is_ensemble(param_grid): + param_grid["parameters"] = _normalize_param_grid(param_grid["parameters"]) + param_grid = _prepare_parameters_for_ensemble(param_grid, random_state) + + if is_searchcv(param_grid): + for p_name, p_value in param_grid.items(): + if not isinstance(p_value, list) and not isinstance(p_value, dict): + param_grid[p_name] = [p_value] + else: + for p_name, p_value in param_grid.items(): + if isinstance(p_value, list): + param_grid[p_name] = p_value[0] + + return param_grid + + +def _add_random_state(estimator, param_grid, random_state): + """Add random_state to param_grid if the estimator accepts it and it's + not already present. + + Parameters + ---------- + estimator : object + The estimator to add random_state to. + + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + random_state : int + Seed for reproducible randomization in model training and probability + estimation. + + Returns + ------- + dict + Prepared parameters dictionary. + + """ + if check_for_random_state(estimator) and "random_state" not in param_grid: + param_grid["random_state"] = random_state + return param_grid + + +def _normalize_param_grid(param_grid): + """Ensure all values in param_grid are lists (for grid search compatibility). + + Parameters + ---------- + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + Returns + ------- + dict + Dictionary with all values as lists. + + """ + normalized = {} + for k, v in param_grid.items(): + if isinstance(v, list): + normalized[k] = v + else: + normalized[k] = [v] + return normalized + + +def _prepare_parameters_for_ensemble(param_grid, random_state=None): + """Process the parameters for ensemble methods. + + Parameters + ---------- + param_grid : dict + Dictionary defining hyperparameter search space for model optimization. + + random_state : int, RandomState instance or None + Seed for reproducible randomization in model training and probability + estimation. + + Returns + ------- + dict + Processed parameters dictionary. + + Raises + ------ + TypeError + If all parameters for base_classifier are not lists. + + """ + if not is_ensemble(param_grid): + return param_grid + + from orca_python.model_selection.loaders import get_classifier_by_name + + base_estimator = get_classifier_by_name(param_grid["base_classifier"]) + base_params = param_grid.get("parameters", {}) + + if check_for_random_state(base_estimator): + base_params["random_state"] = [random_state] + + # Creating a list for each parameter. + # Elements represented as 'parameterName;parameterValue'. + param_combinations = [ + [f"{k};{v}" for v in values] + for k, values in _normalize_param_grid(base_params).items() + ] + # Creates a list of dictionaries, containing all + # combinations of given parameters + combinations = [ + dict(item.split(";") for item in combo) + for combo in product(*param_combinations) + ] + + # Returns non-string values back to their normal self + for combination in combinations: + for k, v in combination.items(): + try: + combination[k] = literal_eval(v) + except ValueError: + pass + + param_grid["parameters"] = combinations + return param_grid