diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index a083d8600fe4b..08ebf4abc92c3 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -206,6 +206,14 @@ Changelog The `sample_interval_` attribute is deprecated and will be removed in 1.5. :pr:`25190` by :user:`Vincent Maladière `. +:mod:`sklearn.tree` +................... + +- |Enhancement| Adds a `class_names` parameter to + :func:`tree.export_text`. This allows specifying a name + for each target class, in ascending numerical order. + :pr:`25387` by :user:`William M ` and :user:`crispinlogan `. + :mod:`sklearn.utils` .................... diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 701a12e1c0174..3e65c4a2b0dc5 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -923,6 +923,7 @@ def export_text( decision_tree, *, feature_names=None, + class_names=None, max_depth=10, spacing=3, decimals=2, @@ -943,6 +944,17 @@ def export_text( A list of length n_features containing the feature names. If None generic names will be used ("feature_0", "feature_1", ...). + class_names : list or None, default=None + Names of each of the target classes in ascending numerical order. + Only relevant for classification and not supported for multi-output. + + - if `None`, the class names are delegated to `decision_tree.classes_`; + - if a list, then `class_names` will be used as class names instead + of `decision_tree.classes_`. The length of `class_names` must match + the length of `decision_tree.classes_`. + + .. versionadded:: 1.3 + max_depth : int, default=10 Only the first max_depth levels of the tree are exported. Truncated branches will be marked with "...". 
@@ -986,7 +998,15 @@ def export_text( check_is_fitted(decision_tree) tree_ = decision_tree.tree_ if is_classifier(decision_tree): - class_names = decision_tree.classes_ + if class_names is None: + class_names = decision_tree.classes_ + elif len(class_names) != len(decision_tree.classes_): + raise ValueError( + "When `class_names` is a list, it should contain as" + " many items as `decision_tree.classes_`. Got" + f" {len(class_names)} while the tree was fitted with" + f" {len(decision_tree.classes_)} classes." + ) right_child_fmt = "{} {} <= {}\n" left_child_fmt = "{} {} > {}\n" truncation_fmt = "{} {}\n" diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 657860a435c05..8865cb724a02a 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -357,6 +357,13 @@ def test_export_text_errors(): err_msg = "feature_names must contain 2 elements, got 1" with pytest.raises(ValueError, match=err_msg): export_text(clf, feature_names=["a"]) + err_msg = ( + "When `class_names` is a list, it should contain as" + " many items as `decision_tree.classes_`. Got 1 while" + " the tree was fitted with 2 classes." + ) + with pytest.raises(ValueError, match=err_msg): + export_text(clf, class_names=["a"]) err_msg = "decimals must be >= 0, given -1" with pytest.raises(ValueError, match=err_msg): export_text(clf, decimals=-1) @@ -394,6 +401,16 @@ def test_export_text(): ).lstrip() assert export_text(clf, feature_names=["a", "b"]) == expected_report + expected_report = dedent( + """ + |--- feature_1 <= 0.00 + | |--- class: cat + |--- feature_1 > 0.00 + | |--- class: dog + """ + ).lstrip() + assert export_text(clf, class_names=["cat", "dog"]) == expected_report + expected_report = dedent( """ |--- feature_1 <= 0.00