Thanks to visit codestin.com
Credit goes to github.com

Skip to content

ENH Adds class_names to tree.export_text #25387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,14 @@ Changelog
The `sample_interval_` attribute is deprecated and will be removed in 1.5.
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

:mod:`sklearn.tree`
...................

- |Enhancement| Adds a `class_names` parameter to
:func:`tree.export_text`. This allows specifying the parameter `class_names`
for each target class in ascending numerical order.
:pr:`25387` by :user:`William M <Akbeeh>` and :user:`crispinlogan <crispinlogan>`.

:mod:`sklearn.utils`
....................

Expand Down
22 changes: 21 additions & 1 deletion sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ def export_text(
decision_tree,
*,
feature_names=None,
class_names=None,
max_depth=10,
spacing=3,
decimals=2,
Expand All @@ -943,6 +944,17 @@ def export_text(
A list of length n_features containing the feature names.
If None generic names will be used ("feature_0", "feature_1", ...).

class_names : list or None, default=None
Names of each of the target classes in ascending numerical order.
Only relevant for classification and not supported for multi-output.

- if `None`, the class names are delegated to `decision_tree.classes_`;
- if a list, then `class_names` will be used as class names instead
of `decision_tree.classes_`. The length of `class_names` must match
the length of `decision_tree.classes_`.

.. versionadded:: 1.3

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a .. versionadded:: 1.3 directive to indicate that the parameter was added in 1.3.

max_depth : int, default=10
Only the first max_depth levels of the tree are exported.
Truncated branches will be marked with "...".
Expand Down Expand Up @@ -986,7 +998,15 @@ def export_text(
check_is_fitted(decision_tree)
tree_ = decision_tree.tree_
if is_classifier(decision_tree):
class_names = decision_tree.classes_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uhm it seems that I misunderstood something when reading the documentation at first. We already uses decision_tree.classes_. So we don't need "numeric" and any deprecation (which is a good news).

Sorry to have brought this way. We will need to modify (remove) the code :)

if class_names is None:
class_names = decision_tree.classes_
elif len(class_names) != len(decision_tree.classes_):
raise ValueError(
"When `class_names` is a list, it should contain as"
" many items as `decision_tree.classes_`. Got"
f" {len(class_names)} while the tree was fitted with"
f" {len(decision_tree.classes_)} classes."
)
right_child_fmt = "{} {} <= {}\n"
left_child_fmt = "{} {} > {}\n"
truncation_fmt = "{} {}\n"
Expand Down
17 changes: 17 additions & 0 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ def test_export_text_errors():
err_msg = "feature_names must contain 2 elements, got 1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=["a"])
err_msg = (
"When `class_names` is a list, it should contain as"
" many items as `decision_tree.classes_`. Got 1 while"
" the tree was fitted with 2 classes."
)
with pytest.raises(ValueError, match=err_msg):
export_text(clf, class_names=["a"])
err_msg = "decimals must be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, decimals=-1)
Expand Down Expand Up @@ -394,6 +401,16 @@ def test_export_text():
).lstrip()
assert export_text(clf, feature_names=["a", "b"]) == expected_report

expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: cat
|--- feature_1 > 0.00
| |--- class: dog
"""
).lstrip()
assert export_text(clf, class_names=["cat", "dog"]) == expected_report

expected_report = dedent(
"""
|--- feature_1 <= 0.00
Expand Down