From 6a88e33c6f6366432d0c030c5dad4dce10ac056e Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Thu, 12 Jan 2023 22:57:41 +0100
Subject: [PATCH 01/10] Update _export.py

---
 sklearn/tree/_export.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 701a12e1c0174..e209bf056a20e 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -923,6 +923,7 @@ def export_text(
     decision_tree,
     *,
     feature_names=None,
+    class_names=None,
     max_depth=10,
     spacing=3,
     decimals=2,
@@ -943,6 +944,10 @@ def export_text(
         A list of length n_features containing the feature names.
         If None generic names will be used ("feature_0", "feature_1", ...).
 
+    class_names : list of arguments, default=None
+        Names of each of the target classes in ascending numerical order.
+        Only relevant for classification and not supported for multi-output.
+
     max_depth : int, default=10
         Only the first max_depth levels of the tree are exported.
         Truncated branches will be marked with "...".
@@ -986,7 +991,10 @@ def export_text(
     check_is_fitted(decision_tree)
     tree_ = decision_tree.tree_
     if is_classifier(decision_tree):
-        class_names = decision_tree.classes_
+        if class_names is not None and len(class_names) == len(decision_tree.classes_):
+            class_names = class_names
+        else:
+            class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"
     left_child_fmt = "{} {} > {}\n"
     truncation_fmt = "{} {}\n"

From fd30a477e20210b45bceef14f732221902412b9b Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Fri, 13 Jan 2023 16:42:13 +0100
Subject: [PATCH 02/10] Update code & documentation after reviews

---
 sklearn/tree/_export.py | 41 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index e209bf056a20e..f534246925719 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -15,6 +15,7 @@
 from numbers import Integral
 
 import numpy as np
+import warnings
 
 from ..utils.validation import check_is_fitted
 from ..base import is_classifier
@@ -944,10 +945,23 @@ def export_text(
         A list of length n_features containing the feature names.
         If None generic names will be used ("feature_0", "feature_1", ...).
 
-    class_names : list of arguments, default=None
+    class_names : "numeric", list or None, default="numeric"
         Names of each of the target classes in ascending numerical order.
         Only relevant for classification and not supported for multi-output.
 
+        - if `None`, the class names are delegated to `decision_tree.classes_`;
+        - if `"numeric"`, the class names are generic names representing numerical
+          numbers (e.g. `["0", "1", ...]`);
+        - if a list, the number of items should be the same as in
+          `decition_tree.classes_` and will be used.
+
+        .. versionadded:: 1.3
+            `class_names` was added in version 1.3.
+
+        .. deprecated:: 1.3
+            The `"numeric"` option is deprecated and will be replaced by `None`. Thus,
+            `decision_tree.classes_` will be used by default.
+
     max_depth : int, default=10
         Only the first max_depth levels of the tree are exported.
         Truncated branches will be marked with "...".
@@ -991,8 +1005,21 @@ def export_text(
     check_is_fitted(decision_tree)
     tree_ = decision_tree.tree_
     if is_classifier(decision_tree):
-        if class_names is not None and len(class_names) == len(decision_tree.classes_):
-            class_names = class_names
+        if class_names == "numeric":
+            warnings.warn(
+                "The option `class_names='numeric'` is deprecated in 1.3 and will be"
+                " removed in 1.5. Set `class_names=None`, the classes as seen by"
+                " `decision_tree` during `fit` will be used instead.",
+                FutureWarning,
+            )
+        elif class_names is not None and len(class_names) != len(
+            decision_tree.classes_
+        ):
+            raise ValueError(
+                "When `class_names` is not None, it should be a list containing as"
+                f" many items as `decision_tree.classes_`. Got {len(class_names)} while"
+                f" the tree was fitted with {len(decision_tree.classes_)} classes."
+            )
         else:
             class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"
@@ -1050,7 +1077,13 @@ def print_tree_recurse(node, depth):
                 value = tree_.value[node][0]
             else:
                 value = tree_.value[node].T[0]
-            class_name = np.argmax(value)
+
+            if class_names == "numeric":
+                class_name = np.argmax(value)
+            elif class_names is None:
+                class_name = decision_tree.classes_[np.argmax(value)]
+            else:
+                class_name = class_names[np.argmax(value)]
 
             if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
                 class_name = class_names[class_name]

From 38b15ffef60a2c5b756ebe17904293d44a39f41c Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Fri, 13 Jan 2023 19:36:53 +0100
Subject: [PATCH 03/10] Update default class_names=numeric

---
 sklearn/tree/_export.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index f534246925719..486f7cbff5dbc 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -924,7 +924,7 @@ def export_text(
     decision_tree,
     *,
     feature_names=None,
-    class_names=None,
+    class_names="numeric",
     max_depth=10,
     spacing=3,
     decimals=2,

From 013fc0e4392cd0079fda34f1e50cb7ba0a973ac8 Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Fri, 13 Jan 2023 23:43:36 +0100
Subject: [PATCH 04/10] Add tests - Need check, raises errors

---
 sklearn/tree/tests/test_export.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index 657860a435c05..b8b98445724e9 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -357,6 +357,13 @@ def test_export_text_errors():
     err_msg = "feature_names must contain 2 elements, got 1"
     with pytest.raises(ValueError, match=err_msg):
         export_text(clf, feature_names=["a"])
+    err_msg = (
+        "When `class_names` is not None, it should be a list containing as"
+        " many items as `decision_tree.classes_`. Got 1 while"
+        " the tree was fitted with 2 classes."
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        export_text(clf, class_names=["a"])
     err_msg = "decimals must be >= 0, given -1"
     with pytest.raises(ValueError, match=err_msg):
         export_text(clf, decimals=-1)
@@ -394,6 +401,26 @@ def test_export_text():
     ).lstrip()
     assert export_text(clf, feature_names=["a", "b"]) == expected_report
 
+    expected_report = dedent(
+        """
+        |--- feature_1 <= 0.00
+        |   |--- class: a
+        |--- feature_1 > 0.00
+        |   |--- class: b
+        """
+    ).lstrip()
+    assert export_text(clf, class_names=["a", "b"]) == expected_report
+
+    expected_report = dedent(
+        """
+        |--- feature_1 <= 0.00
+        |   |--- class: -1
+        |--- feature_1 > 0.00
+        |   |--- class: 1
+        """
+    ).lstrip()
+    assert export_text(clf, class_names=None) == expected_report
+
     expected_report = dedent(
         """
         |--- feature_1 <= 0.00

From 39cf9e214db6b9e5dc2a6d8dafe5cb5b2f4f086a Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Fri, 20 Jan 2023 17:29:45 +0100
Subject: [PATCH 05/10] Add tests / Apply changes

---
 sklearn/tree/_export.py           | 26 +++++++-------
 sklearn/tree/tests/test_export.py | 58 +++++++++++++++++++++++++------
 2 files changed, 59 insertions(+), 25 deletions(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 486f7cbff5dbc..9d74502222fc8 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -1012,14 +1012,17 @@ def export_text(
                 " `decision_tree` during `fit` will be used instead.",
                 FutureWarning,
             )
-        elif class_names is not None and len(class_names) != len(
-            decision_tree.classes_
-        ):
-            raise ValueError(
-                "When `class_names` is not None, it should be a list containing as"
-                f" many items as `decision_tree.classes_`. Got {len(class_names)} while"
-                f" the tree was fitted with {len(decision_tree.classes_)} classes."
-            )
+            class_names = range(decision_tree.n_classes_)
+        elif class_names is not None:
+            if len(class_names) != len(decision_tree.classes_):
+                raise ValueError(
+                    "When `class_names` is not None, it should be a list containing as"
+                    " many items as `decision_tree.classes_`. Got"
+                    f" {len(class_names)} while the tree was fitted with"
+                    f" {len(decision_tree.classes_)} classes."
+                )
+            else:
+                class_names = class_names
         else:
             class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"
@@ -1078,12 +1081,7 @@ def print_tree_recurse(node, depth):
             else:
                 value = tree_.value[node].T[0]
 
-            if class_names == "numeric":
-                class_name = np.argmax(value)
-            elif class_names is None:
-                class_name = decision_tree.classes_[np.argmax(value)]
-            else:
-                class_name = class_names[np.argmax(value)]
+            class_name = np.argmax(value)
 
             if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
                 class_name = class_names[class_name]
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index b8b98445724e9..4455f32fe4795 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -347,6 +347,18 @@ def test_precision():
     assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
 
 
+def test_export_text_warnings():
+    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
+    clf.fit(X, y)
+    warn_msg = (
+        "The option `class_names='numeric'` is deprecated in 1.3 and will be"
+        " removed in 1.5. Set `class_names=None`, the classes as seen by"
+        " `decision_tree` during `fit` will be used instead."
+    )
+    with pytest.warns(FutureWarning, match=warn_msg):
+        export_text(clf, class_names="numeric")
+
+
 def test_export_text_errors():
     clf = DecisionTreeClassifier(max_depth=2, random_state=0)
     clf.fit(X, y)
@@ -385,11 +397,21 @@ def test_export_text():
         """
     ).lstrip()
 
-    assert export_text(clf) == expected_report
+    assert export_text(clf, class_names=None) == expected_report
     # testing that leaves at level 1 are not truncated
-    assert export_text(clf, max_depth=0) == expected_report
+    assert export_text(clf, class_names=None, max_depth=0) == expected_report
     # testing that the rest of the tree is truncated
-    assert export_text(clf, max_depth=10) == expected_report
+    assert export_text(clf, class_names=None, max_depth=10) == expected_report
+
+    expected_report = dedent(
+        """
+        |--- feature_1 <= 0.00
+        |   |--- class: 0
+        |--- feature_1 > 0.00
+        |   |--- class: 1
+        """
+    ).lstrip()
+    assert export_text(clf, class_names="numeric") == expected_report
 
     expected_report = dedent(
         """
@@ -399,7 +421,9 @@ def test_export_text():
         |   |--- class: 1
         """
     ).lstrip()
-    assert export_text(clf, feature_names=["a", "b"]) == expected_report
+    assert (
+        export_text(clf, feature_names=["a", "b"], class_names=None) == expected_report
+    )
 
     expected_report = dedent(
         """
@@ -429,7 +453,7 @@ def test_export_text():
         |   |--- weights: [0.00, 3.00] class: 1
         """
     ).lstrip()
-    assert export_text(clf, show_weights=True) == expected_report
+    assert export_text(clf, class_names=None, show_weights=True) == expected_report
 
     expected_report = dedent(
         """
@@ -439,7 +463,7 @@ def test_export_text():
         |   |- class: 1
         """
     ).lstrip()
-    assert export_text(clf, spacing=1) == expected_report
+    assert export_text(clf, class_names=None, spacing=1) == expected_report
 
     X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
     y_l = [-1, -1, -1, 1, 1, 1, 2]
@@ -453,7 +477,7 @@ def test_export_text():
         |   |--- truncated branch of depth 2
         """
     ).lstrip()
-    assert export_text(clf, max_depth=0) == expected_report
+    assert export_text(clf, class_names=None, max_depth=0) == expected_report
 
     X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
     y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
@@ -469,8 +493,11 @@ def test_export_text():
         |   |--- value: [1.0, 1.0]
         """
     ).lstrip()
-    assert export_text(reg, decimals=1) == expected_report
-    assert export_text(reg, decimals=1, show_weights=True) == expected_report
+    assert export_text(reg, class_names=None, decimals=1) == expected_report
+    assert (
+        export_text(reg, class_names=None, decimals=1, show_weights=True)
+        == expected_report
+    )
 
     X_single = [[-2], [-1], [-1], [1], [1], [2]]
     reg = DecisionTreeRegressor(max_depth=2, random_state=0)
@@ -484,9 +511,18 @@ def test_export_text():
         |   |--- value: [1.0, 1.0]
         """
     ).lstrip()
-    assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
     assert (
-        export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
+        export_text(reg, decimals=1, feature_names=["first"], class_names=None)
+        == expected_report
+    )
+    assert (
+        export_text(
+            reg,
+            decimals=1,
+            show_weights=True,
+            feature_names=["first"],
+            class_names=None,
+        )
         == expected_report
     )

From a720c9386cc09daf73c1d8181a166b1f001322eb Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Mon, 23 Jan 2023 11:49:19 +0100
Subject: [PATCH 06/10] Changes - Add description in RST file

---
 doc/whats_new/v1.3.rst            |  9 ++++
 sklearn/tree/_export.py           | 25 ++----------
 sklearn/tree/tests/test_export.py | 68 +++++--------------------------
 3 files changed, 23 insertions(+), 79 deletions(-)
diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 3b24ac2a35c39..1a9ca10cd6bb0 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -141,6 +141,15 @@ Changelog
 :mod:`sklearn.preprocessing`
 ............................
 
+- |Enhancement| Adds a `class_names` parameter to
+  :func:`tree.export_text`. This allows specifying the parameter `class_names`
+  for each target class in ascending numerical order.
+  :pr:`25387` by :user:`William M `, :user:`Guillaume Lemaitre `, and
+  :user:`crispinlogan `.
+
+:mod:`sklearn.tree`
+...................
+
 - |Enhancement| Adds a `feature_name_combiner` parameter to
   :class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
   feature names to be returned by :meth:`get_feature_names_out`.
diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 9d74502222fc8..8bad6aa0e1351 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -15,7 +15,6 @@
 from numbers import Integral
 
 import numpy as np
-import warnings
 
 from ..utils.validation import check_is_fitted
 from ..base import is_classifier
@@ -924,7 +923,7 @@ def export_text(
     decision_tree,
     *,
     feature_names=None,
-    class_names="numeric",
+    class_names=None,
     max_depth=10,
     spacing=3,
     decimals=2,
@@ -945,23 +944,14 @@ def export_text(
         A list of length n_features containing the feature names.
         If None generic names will be used ("feature_0", "feature_1", ...).
 
-    class_names : "numeric", list or None, default="numeric"
+    class_names : list or None, default=None
         Names of each of the target classes in ascending numerical order.
         Only relevant for classification and not supported for multi-output.
 
         - if `None`, the class names are delegated to `decision_tree.classes_`;
-        - if `"numeric"`, the class names are generic names representing numerical
-          numbers (e.g. `["0", "1", ...]`);
         - if a list, the number of items should be the same as in
           `decition_tree.classes_` and will be used.
 
-        .. versionadded:: 1.3
-            `class_names` was added in version 1.3.
-
-        .. deprecated:: 1.3
-            The `"numeric"` option is deprecated and will be replaced by `None`. Thus,
-            `decision_tree.classes_` will be used by default.
-
     max_depth : int, default=10
         Only the first max_depth levels of the tree are exported.
         Truncated branches will be marked with "...".
@@ -995,15 +986,7 @@ def export_text(
     check_is_fitted(decision_tree)
     tree_ = decision_tree.tree_
     if is_classifier(decision_tree):
-        if class_names == "numeric":
-            warnings.warn(
-                "The option `class_names='numeric'` is deprecated in 1.3 and will be"
-                " removed in 1.5. Set `class_names=None`, the classes as seen by"
-                " `decision_tree` during `fit` will be used instead.",
-                FutureWarning,
-            )
-            class_names = range(decision_tree.n_classes_)
-        elif class_names is not None:
+        if class_names is not None:
             if len(class_names) != len(decision_tree.classes_):
                 raise ValueError(
                     "When `class_names` is not None, it should be a list containing as"
@@ -1080,7 +1062,6 @@ def print_tree_recurse(node, depth):
                 value = tree_.value[node][0]
             else:
                 value = tree_.value[node].T[0]
-
             class_name = np.argmax(value)
 
             if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
                 class_name = class_names[class_name]
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index 4455f32fe4795..c5ce08a480d6b 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -347,18 +347,6 @@ def test_precision():
     assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
 
 
-def test_export_text_warnings():
-    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
-    clf.fit(X, y)
-    warn_msg = (
-        "The option `class_names='numeric'` is deprecated in 1.3 and will be"
-        " removed in 1.5. Set `class_names=None`, the classes as seen by"
-        " `decision_tree` during `fit` will be used instead."
-    )
-    with pytest.warns(FutureWarning, match=warn_msg):
-        export_text(clf, class_names="numeric")
-
-
 def test_export_text_errors():
     clf = DecisionTreeClassifier(max_depth=2, random_state=0)
     clf.fit(X, y)
@@ -397,21 +385,11 @@ def test_export_text():
         """
     ).lstrip()
 
-    assert export_text(clf, class_names=None) == expected_report
+    assert export_text(clf) == expected_report
     # testing that leaves at level 1 are not truncated
-    assert export_text(clf, class_names=None, max_depth=0) == expected_report
+    assert export_text(clf, max_depth=0) == expected_report
     # testing that the rest of the tree is truncated
-    assert export_text(clf, class_names=None, max_depth=10) == expected_report
-
-    expected_report = dedent(
-        """
-        |--- feature_1 <= 0.00
-        |   |--- class: 0
-        |--- feature_1 > 0.00
-        |   |--- class: 1
-        """
-    ).lstrip()
-    assert export_text(clf, class_names="numeric") == expected_report
+    assert export_text(clf, max_depth=10) == expected_report
 
     expected_report = dedent(
         """
@@ -421,9 +399,7 @@ def test_export_text():
         |   |--- class: 1
         """
     ).lstrip()
-    assert (
-        export_text(clf, feature_names=["a", "b"], class_names=None) == expected_report
-    )
+    assert export_text(clf, feature_names=["a", "b"]) == expected_report
 
     expected_report = dedent(
         """
@@ -453,7 +429,7 @@ def test_export_text():
         |   |--- weights: [0.00, 3.00] class: 1
         """
     ).lstrip()
-    assert export_text(clf, class_names=None, show_weights=True) == expected_report
+    assert export_text(clf, show_weights=True) == expected_report
 
     expected_report = dedent(
         """
@@ -463,7 +439,7 @@ def test_export_text():
         |   |- class: 1
         """
     ).lstrip()
-    assert export_text(clf, class_names=None, spacing=1) == expected_report
+    assert export_text(clf, spacing=1) == expected_report
 
     X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
     y_l = [-1, -1, -1, 1, 1, 1, 2]
@@ -477,7 +453,7 @@ def test_export_text():
         |   |--- truncated branch of depth 2
         """
     ).lstrip()
-    assert export_text(clf, class_names=None, max_depth=0) == expected_report
+    assert export_text(clf, max_depth=0) == expected_report
 
     X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
     y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
@@ -493,11 +469,8 @@ def test_export_text():
         |   |--- value: [1.0, 1.0]
         """
     ).lstrip()
-    assert export_text(reg, class_names=None, decimals=1) == expected_report
-    assert (
-        export_text(reg, class_names=None, decimals=1, show_weights=True)
-        == expected_report
-    )
+    assert export_text(reg, decimals=1) == expected_report
+    assert export_text(reg, decimals=1, show_weights=True) == expected_report
 
     X_single = [[-2], [-1], [-1], [1], [1], [2]]
     reg = DecisionTreeRegressor(max_depth=2, random_state=0)
@@ -511,18 +484,9 @@ def test_export_text():
         |   |--- value: [1.0, 1.0]
         """
     ).lstrip()
+    assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
     assert (
-        export_text(reg, decimals=1, feature_names=["first"], class_names=None)
-        == expected_report
-    )
-    assert (
-        export_text(
-            reg,
-            decimals=1,
-            show_weights=True,
-            feature_names=["first"],
-            class_names=None,
-        )
+        export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
         == expected_report
     )

From 1803d3cf63f1ff98e922f640efd705cc4f78f6aa Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Mon, 23 Jan 2023 12:56:24 +0100
Subject: [PATCH 07/10] Small change, remove else statement after raise ValueError

---
 sklearn/tree/_export.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 8bad6aa0e1351..19f7fcb36312e 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -1003,8 +1003,7 @@ def export_text(
                     f" {len(class_names)} while the tree was fitted with"
                     f" {len(decision_tree.classes_)} classes."
                 )
-            else:
-                class_names = class_names
+            class_names = class_names
         else:
             class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"

From 7f14483da415d6368b3257c81f55b7b529a8f67c Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Tue, 24 Jan 2023 12:33:24 +0100
Subject: [PATCH 08/10] Small changes

---
 doc/whats_new/v1.3.rst  | 17 ++++++++---------
 sklearn/tree/_export.py | 18 ++++++++----------
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 1a3c8e07d6d67..a50a8d95dd46e 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -145,15 +145,6 @@ Changelog
 :mod:`sklearn.preprocessing`
 ............................
 
-- |Enhancement| Adds a `class_names` parameter to
-  :func:`tree.export_text`. This allows specifying the parameter `class_names`
-  for each target class in ascending numerical order.
-  :pr:`25387` by :user:`William M `, :user:`Guillaume Lemaitre `, and
-  :user:`crispinlogan `.
-
-:mod:`sklearn.tree`
-...................
-
 - |Enhancement| Adds a `feature_name_combiner` parameter to
   :class:`preprocessing.OneHotEncoder`. This specifies a custom callable to create
   feature names to be returned by :meth:`get_feature_names_out`.
@@ -171,6 +162,14 @@ Changelog
   The `sample_interval_` attribute is deprecated and will be removed in 1.5.
   :pr:`25190` by :user:`Vincent Maladière `.
 
+:mod:`sklearn.tree`
+...................
+
+- |Enhancement| Adds a `class_names` parameter to
+  :func:`tree.export_text`. This allows specifying the parameter `class_names`
+  for each target class in ascending numerical order.
+  :pr:`25387` by :user:`William M ` and :user:`crispinlogan `.
+
 :mod:`sklearn.utils`
 ....................
 
diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 19f7fcb36312e..781c6a4001376 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -995,16 +995,14 @@ def export_text(
     check_is_fitted(decision_tree)
     tree_ = decision_tree.tree_
     if is_classifier(decision_tree):
-        if class_names is not None:
-            if len(class_names) != len(decision_tree.classes_):
-                raise ValueError(
-                    "When `class_names` is not None, it should be a list containing as"
-                    " many items as `decision_tree.classes_`. Got"
-                    f" {len(class_names)} while the tree was fitted with"
-                    f" {len(decision_tree.classes_)} classes."
-                )
-            class_names = class_names
-        else:
+        if class_names is not None and len(class_names) != len(decision_tree.classes_):
+            raise ValueError(
+                "When `class_names` is not None, it should be a list containing as"
+                " many items as `decision_tree.classes_`. Got"
+                f" {len(class_names)} while the tree was fitted with"
+                f" {len(decision_tree.classes_)} classes."
+            )
+        elif class_names is None:
             class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"
     left_child_fmt = "{} {} > {}\n"

From 7e0115042fcc0e5fc4e09f46191bc955fb1a924b Mon Sep 17 00:00:00 2001
From: Akbeeh
Date: Wed, 8 Feb 2023 18:56:16 +0100
Subject: [PATCH 09/10] Changes made

---
 sklearn/tree/_export.py           | 10 ++++++----
 sklearn/tree/tests/test_export.py |  8 ++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index 781c6a4001376..d78bf1e23e41c 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -952,6 +952,8 @@ def export_text(
         - if a list, the number of items should be the same as in
           `decition_tree.classes_` and will be used.
 
+        .. versionadded:: 1.3
+
     max_depth : int, default=10
         Only the first max_depth levels of the tree are exported.
         Truncated branches will be marked with "...".
@@ -997,15 +997,15 @@ def export_text(
     check_is_fitted(decision_tree)
     tree_ = decision_tree.tree_
     if is_classifier(decision_tree):
-        if class_names is not None and len(class_names) != len(decision_tree.classes_):
+        if class_names is None:
+            class_names = decision_tree.classes_
+        elif len(class_names) != len(decision_tree.classes_):
             raise ValueError(
-                "When `class_names` is not None, it should be a list containing as"
+                "When `class_names` is a list, it should contain as"
                 " many items as `decision_tree.classes_`. Got"
                 f" {len(class_names)} while the tree was fitted with"
                 f" {len(decision_tree.classes_)} classes."
             )
-        elif class_names is None:
-            class_names = decision_tree.classes_
     right_child_fmt = "{} {} <= {}\n"
     left_child_fmt = "{} {} > {}\n"
     truncation_fmt = "{} {}\n"
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index c5ce08a480d6b..8865cb724a02a 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -358,7 +358,7 @@ def test_export_text_errors():
     with pytest.raises(ValueError, match=err_msg):
         export_text(clf, feature_names=["a"])
     err_msg = (
-        "When `class_names` is not None, it should be a list containing as"
+        "When `class_names` is a list, it should contain as"
         " many items as `decision_tree.classes_`. Got 1 while"
         " the tree was fitted with 2 classes."
     )
     with pytest.raises(ValueError, match=err_msg):
         export_text(clf, class_names=["a"])
@@ -404,12 +404,12 @@ def test_export_text():
     expected_report = dedent(
         """
         |--- feature_1 <= 0.00
-        |   |--- class: a
+        |   |--- class: cat
         |--- feature_1 > 0.00
-        |   |--- class: b
+        |   |--- class: dog
         """
     ).lstrip()
-    assert export_text(clf, class_names=["a", "b"]) == expected_report
+    assert export_text(clf, class_names=["cat", "dog"]) == expected_report
 
     expected_report = dedent(
         """
         |--- feature_1 <= 0.00

From 862684438a5c6fc44522860efeece12369f7d157 Mon Sep 17 00:00:00 2001
From: William M <64324808+Akbeeh@users.noreply.github.com>
Date: Wed, 8 Feb 2023 21:00:54 +0100
Subject: [PATCH 10/10] Update: rewording

Co-authored-by: Thomas J. Fan
---
 sklearn/tree/_export.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py
index d78bf1e23e41c..3e65c4a2b0dc5 100644
--- a/sklearn/tree/_export.py
+++ b/sklearn/tree/_export.py
@@ -949,8 +949,9 @@ def export_text(
         Only relevant for classification and not supported for multi-output.
 
         - if `None`, the class names are delegated to `decision_tree.classes_`;
-        - if a list, the number of items should be the same as in
-          `decition_tree.classes_` and will be used.
+        - if a list, then `class_names` will be used as class names instead
+          of `decision_tree.classes_`. The length of `class_names` must match
+          the length of `decision_tree.classes_`.
 
         .. versionadded:: 1.3
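
The series above converges on the behaviour that ships with this change: `export_text` keeps `class_names=None` as the default and falls back to `decision_tree.classes_`, while a user-supplied list must contain one name per entry of `classes_`, in ascending numerical order, or a ValueError is raised. A minimal usage sketch of that final behaviour follows; it assumes a scikit-learn build that already contains these patches (1.3 or later), and the toy dataset is illustrative rather than taken from the test suite.

# Usage sketch for the `class_names` parameter added by this patch series.
# Assumes scikit-learn >= 1.3; the toy data below is illustrative only.
from sklearn.tree import DecisionTreeClassifier, export_text

X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# Default (class_names=None): labels come from clf.classes_, i.e. -1 and 1.
print(export_text(clf, feature_names=["a", "b"]))

# Custom labels: one name per entry of clf.classes_, in ascending order,
# so "cat" maps to class -1 and "dog" to class 1.
print(export_text(clf, feature_names=["a", "b"], class_names=["cat", "dog"]))

# A list whose length does not match clf.classes_ raises a ValueError.

The `None`-or-list contract, chosen here over the intermediate `"numeric"` option that the earlier patches tried and then removed, is also consistent with the `class_names` arguments that `export_graphviz` and `plot_tree` already accept.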