Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions orca_python/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Model selection and estimator loading utilities."""

from .loaders import get_classifier_by_name, load_classifier
from .validation import (
check_for_random_state,
is_ensemble,
is_searchcv,
prepare_param_grid,
)

__all__ = [
"get_classifier_by_name",
"load_classifier",
"check_for_random_state",
"is_ensemble",
"is_searchcv",
"prepare_param_grid",
]
158 changes: 158 additions & 0 deletions orca_python/model_selection/loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Model selection and estimator loading utilities."""

from importlib import import_module

from sklearn.model_selection import GridSearchCV, StratifiedKFold

from orca_python.metrics.utils import load_metric_as_scorer
from orca_python.model_selection.validation import (
is_searchcv,
prepare_param_grid,
)

_ORCA_CLASSIFIERS = {
"NNOP": "orca_python.classifiers.NNOP",
"NNPOM": "orca_python.classifiers.NNPOM",
"OrdinalDecomposition": "orca_python.classifiers.OrdinalDecomposition",
"REDSVM": "orca_python.classifiers.REDSVM",
"SVOREX": "orca_python.classifiers.SVOREX",
}

_SKLEARN_CLASSIFIERS = {
"SVC": "sklearn.svm.SVC",
"LogisticRegression": "sklearn.linear_model.LogisticRegression",
"RandomForestClassifier": "sklearn.ensemble.RandomForestClassifier",
}

_CLASSIFIERS = {**_ORCA_CLASSIFIERS, **_SKLEARN_CLASSIFIERS}


def get_classifier_by_name(classifier_name):
"""Return a classifier not instantiated matching a given input name.

Parameters
----------
classifier_name : str
Name of the classification algorithm being employed.

Returns
-------
classifier : object
Returns a classifier, either from a scikit-learn module, or from a
module of this framework.

Raises
------
ValueError
If an unknown classifier name is provided.

Examples
--------
>>> get_classifier_by_name("SVOREX")
<class 'orca_python.classifiers.SVOREX.SVOREX'>
>>> get_classifier_by_name("REDSVM")
<class 'orca_python.classifiers.REDSVM.REDSVM'>
>>> get_classifier_by_name("SVC")
<class 'sklearn.svm._classes.SVC'>

"""
if classifier_name not in _CLASSIFIERS:
raise ValueError(f"Unknown classifier '{classifier_name}'.")

module_path, class_name = _CLASSIFIERS[classifier_name].rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)


def load_classifier(
classifier_name,
random_state=None,
n_jobs=1,
cv_n_folds=3,
cv_metric="mae",
param_grid=None,
):
"""Return a fully configured classifier, optionally with cross-validation.

This function loads a classifier, configures its parameters, and optionally
sets up cross-validation if multiple parameter values are provided.

Parameters
----------
classifier_name : str
Name of the classification algorithm being employed.

random_state : int, RandomState instance or None, optional (default=None)
Seed for reproducible randomization in model training and probability
estimation.

n_jobs : int, optional (default=1)
Number of parallel processing cores for computational tasks.

cv_n_folds : int, optional (default=3)
Number of folds for cross-validation (only used if applicable).

cv_metric : str or callable, optional (default="mae")
Evaluation metric for cross-validation performance assessment.

param_grid : dict or None, optional (default=None)
Hyperparameter grid. If multiple values are given, cross-validation will be applied.

Returns
-------
classifier : object
The initialized classifier object, optionally wrapped in GridSearchCV.

Raises
------
ValueError
If an unknown classifier name is provided or if an invalid parameter
is specified for the classifier.

Examples
--------
>>> from orca_python.model_selection import load_classifier
>>> clf = load_classifier("SVC", random_state=0)
>>> clf
SVC()
>>> clf_cv = load_classifier("SVC", random_state=0, param_grid={"C": [0.1, 1.0]})
>>> clf_cv.__class__.__name__
'GridSearchCV'

"""
classifier_cls = get_classifier_by_name(classifier_name)

if param_grid is None:
return classifier_cls()

param_grid = prepare_param_grid(classifier_cls, param_grid, random_state)

if is_searchcv(param_grid):
scorer = (
load_metric_as_scorer(cv_metric)
if isinstance(cv_metric, str)
else cv_metric
)
cv = StratifiedKFold(
n_splits=cv_n_folds, shuffle=True, random_state=random_state
)

return GridSearchCV(
estimator=classifier_cls(),
param_grid=param_grid,
scoring=scorer,
n_jobs=n_jobs,
cv=cv,
error_score="raise",
)

try:
classifier = classifier_cls(**param_grid)
classifier.assigned_params_ = param_grid
return classifier
except TypeError as e:
invalid_param = str(e).split("'")[1]
raise ValueError(
f"Invalid parameter '{invalid_param}' for classifier"
f" '{classifier_name}'."
)
3 changes: 3 additions & 0 deletions orca_python/model_selection/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Tests for model selection module."""

__all__ = []
123 changes: 123 additions & 0 deletions orca_python/model_selection/tests/test_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"Tests for model selection and estimator loading utilities."

import pytest
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

from orca_python.classifiers import NNOP, NNPOM, REDSVM, SVOREX, OrdinalDecomposition
from orca_python.metrics import load_metric_as_scorer
from orca_python.model_selection import get_classifier_by_name, load_classifier
from orca_python.testing import TEST_RANDOM_STATE


def test_get_classifier_by_name_correct():
"""Test that get_classifier_by_name returns the correct classifier."""
# ORCA classifiers
assert get_classifier_by_name("NNOP") == NNOP
assert get_classifier_by_name("NNPOM") == NNPOM
assert get_classifier_by_name("OrdinalDecomposition") == OrdinalDecomposition
assert get_classifier_by_name("REDSVM") == REDSVM
assert get_classifier_by_name("SVOREX") == SVOREX

# Scikit-learn classifiers
assert get_classifier_by_name("SVC") == SVC
assert get_classifier_by_name("LogisticRegression") == LogisticRegression


def test_get_classifier_by_name_incorrect():
"""Test that get_classifier_by_name raises ValueError for unknown classifiers."""
with pytest.raises(ValueError, match="Unknown classifier 'RandomForest'"):
get_classifier_by_name("RandomForest")

with pytest.raises(ValueError, match="Unknown classifier 'SVR'"):
get_classifier_by_name("SVR")


def test_load_classifier_without_parameters():
"""Test that load_classifier correctly instantiates classifiers without
parameters."""
assert isinstance(load_classifier("NNOP"), NNOP)
assert isinstance(load_classifier("NNPOM"), NNPOM)
assert isinstance(load_classifier("OrdinalDecomposition"), OrdinalDecomposition)
assert isinstance(load_classifier("REDSVM"), REDSVM)
assert isinstance(load_classifier("SVOREX"), SVOREX)
assert isinstance(load_classifier("SVC"), SVC)
assert isinstance(load_classifier("LogisticRegression"), LogisticRegression)


def test_load_classifier_with_parameters():
"""Test that load_classifier correctly instantiates classifiers with
parameters."""
param_grid = {
"epsilon_init": 0.5,
"n_hidden": 10,
"max_iter": 500,
"lambda_value": 0.01,
}
classifier = load_classifier("NNPOM", param_grid=param_grid)
assert isinstance(classifier, NNPOM)
assert classifier.epsilon_init == 0.5
assert classifier.n_hidden == 10
assert classifier.max_iter == 500
assert classifier.lambda_value == 0.01


def test_load_classifier_with_searchcv():
"""Test that load_classifier correctly returns a GridSearchCV when param_grid has multiple values."""
param_grid = {"C": [0.1, 1.0], "probability": "True"}

classifier = load_classifier(
"SVC",
param_grid=param_grid,
random_state=TEST_RANDOM_STATE,
cv_n_folds=5,
cv_metric="mae",
n_jobs=1,
)

expected_param_grid = {
"C": [0.1, 1.0],
"probability": ["True"],
"random_state": [TEST_RANDOM_STATE],
}

assert isinstance(classifier, GridSearchCV)
assert classifier.cv.n_splits == 5
assert classifier.param_grid == expected_param_grid


def test_load_classifier_with_ensemble_method():
"""Test that load_classifier correctly handles ensemble methods."""
param_grid = {
"dtype": "ordered_partitions",
"decision_method": "frank_hall",
"base_classifier": "SVC",
"parameters": {
"C": [0.01, 0.1, 1, 10],
"gamma": [0.01, 0.1, 1, 10],
"probability": ["True"],
},
}
classifier = load_classifier(
classifier_name="OrdinalDecomposition",
param_grid=param_grid,
n_jobs=10,
cv_n_folds=3,
cv_metric=load_metric_as_scorer("mae"),
random_state=TEST_RANDOM_STATE,
)
assert isinstance(classifier, GridSearchCV)
assert classifier.param_grid["decision_method"] == [param_grid["decision_method"]]
assert classifier.param_grid["base_classifier"] == [param_grid["base_classifier"]]
for params in classifier.param_grid["parameters"]:
assert params["random_state"] == TEST_RANDOM_STATE
assert classifier.cv.n_splits == 3


def test_load_classifier_with_invalid_param():
"""Test that load_classifier raises error with invalid parameter key."""
error_msg = "Invalid parameter 'T' for classifier 'SVC'."

with pytest.raises(ValueError, match=error_msg):
load_classifier(classifier_name="SVC", param_grid={"T": 0.1})
Loading