From 5e949f855b83f41201d794401dbce32767a41b7b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 27 May 2022 13:20:48 -0400 Subject: [PATCH 1/2] ENH Uses more didactic x[i] to represent features --- doc/whats_new/v1.2.rst | 3 +++ sklearn/tree/_export.py | 6 +++--- sklearn/tree/tests/test_export.py | 16 ++++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index b4804f3c9c8b9..610094a8d7c8b 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -111,6 +111,9 @@ Changelog :mod:`sklearn.tree` ................... +- |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses + a lower case `x[i]` to represent feature `i`. :pr:`xxxxx` by `Thomas Fan`_. + - |Fix| Fixed invalid memory access bug during fit in :class:`tree.DecisionTreeRegressor` and :class:`tree.DecisionTreeClassifier`. :pr:`23273` by `Thomas Fan`_. diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 4e2e8b58cc370..e3be3795e16e8 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -115,7 +115,7 @@ def plot_tree( feature_names : list of strings, default=None Names of each of the features. - If None, generic names will be used ("X[0]", "X[1]", ...). + If None, generic names will be used ("x[0]", "x[1]", ...). class_names : list of str or bool, default=None Names of each of the target classes in ascending numerical order. @@ -291,7 +291,7 @@ def node_to_str(self, tree, node_id, criterion): if self.feature_names is not None: feature = self.feature_names[tree.feature[node_id]] else: - feature = "X%s%s%s" % ( + feature = "x%s%s%s" % ( characters[1], tree.feature[node_id], characters[2], @@ -789,7 +789,7 @@ def export_graphviz( feature_names : list of str, default=None Names of each of the features. - If None generic names will be used ("feature_0", "feature_1", ...). + If None, generic names will be used ("x[0]", "x[1]", ...). class_names : list of str or bool, default=None Names of each of the target classes in ascending numerical order. diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index d3b082a927048..657860a435c05 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -35,7 +35,7 @@ def test_graphviz_toy(): "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' 'edge [fontname="helvetica"] ;\n' - '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' 'value = [3, 3]"] ;\n' '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' "0 -> 1 [labeldistance=2.5, labelangle=45, " @@ -75,7 +75,7 @@ def test_graphviz_toy(): "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' 'edge [fontname="helvetica"] ;\n' - '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' 'value = [3, 3]\\nclass = yes"] ;\n' '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' 'class = yes"] ;\n' @@ -106,7 +106,7 @@ def test_graphviz_toy(): 'node [shape=box, style="filled, rounded", color="black", ' 'fontname="sans"] ;\n' 'edge [fontname="sans"] ;\n' - "0 [label=0 ≤ 0.0
samples = 100.0%
" + "0 [label=0 ≤ 0.0
samples = 100.0%
" 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' "1 [label=value = [1.0, 0.0]>, " 'fillcolor="#e58139"] ;\n' @@ -127,7 +127,7 @@ def test_graphviz_toy(): "digraph Tree {\n" 'node [shape=box, fontname="helvetica"] ;\n' 'edge [fontname="helvetica"] ;\n' - '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' + '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' 'value = [3, 3]\\nclass = y[0]"] ;\n' '1 [label="(...)"] ;\n' "0 -> 1 ;\n" @@ -147,7 +147,7 @@ def test_graphviz_toy(): 'node [shape=box, style="filled", color="black", ' 'fontname="helvetica"] ;\n' 'edge [fontname="helvetica"] ;\n' - '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' + '0 [label="node #0\\nx[0] <= 0.0\\ngini = 0.5\\n' 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' "0 -> 1 ;\n" @@ -170,14 +170,14 @@ def test_graphviz_toy(): 'node [shape=box, style="filled", color="black", ' 'fontname="helvetica"] ;\n' 'edge [fontname="helvetica"] ;\n' - '0 [label="X[0] <= 0.0\\nsamples = 6\\n' + '0 [label="x[0] <= 0.0\\nsamples = 6\\n' "value = [[3.0, 1.5, 0.0]\\n" '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' '[3, 0, 0]]", fillcolor="#e58139"] ;\n' "0 -> 1 [labeldistance=2.5, labelangle=45, " 'headlabel="True"] ;\n' - '2 [label="X[0] <= 1.5\\nsamples = 3\\n' + '2 [label="x[0] <= 1.5\\nsamples = 3\\n' "value = [[0.0, 1.5, 0.0]\\n" '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' "0 -> 2 [labeldistance=2.5, labelangle=-45, " @@ -215,7 +215,7 @@ def test_graphviz_toy(): "graph [ranksep=equally, splines=polyline] ;\n" 'edge [fontname="sans"] ;\n' "rankdir=LR ;\n" - '0 [label="X[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n' + '0 [label="x[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n' 'value = 0.0", fillcolor="#f2c09c"] ;\n' '1 [label="squared_error = 0.0\\nsamples = 3\\' 'nvalue = -1.0", ' From a843ec28db4478d24402863fcad3acd1b187c24f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 27 May 2022 13:35:23 -0400 Subject: [PATCH 2/2] DOC Adds PR number --- doc/whats_new/v1.2.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 610094a8d7c8b..8778e74a011a3 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -112,7 +112,7 @@ Changelog ................... - |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses - a lower case `x[i]` to represent feature `i`. :pr:`xxxxx` by `Thomas Fan`_. + a lower case `x[i]` to represent feature `i`. :pr:`23480` by `Thomas Fan`_. - |Fix| Fixed invalid memory access bug during fit in :class:`tree.DecisionTreeRegressor` and :class:`tree.DecisionTreeClassifier`.