Description
When running the code snippet from the README with numpy <=1.26.4, tree.print_tree() returns:
([], {1: 5.0, 2: 5.0}, (a, p=0.0015654022580025482, score=10.0, groups=[[1], [2]]), dof=1))
|-- ([1], {1: 5.0, 2: 0}, - the minimum parent node size threshold has been reached)
+-- ([2], {1: 0, 2: 5.0}, - the minimum parent node size threshold has been reached)
However, with numpy >=2.0.0, tree.print_tree() returns:
([], {np.int64(1): np.float64(5.0), np.int64(2): np.float64(5.0)}, (a, p=0.0015654022580025482, score=10.0, groups=[[np.int64(1)], [np.int64(2)]]), dof=1))
|-- ([np.int64(1)], {np.int64(1): np.float64(5.0), np.int64(2): 0}, - the minimum parent node size threshold has been reached)
+-- ([np.int64(2)], {np.int64(1): 0, np.int64(2): np.float64(5.0)}, - the minimum parent node size threshold has been reached)
Can this be solved by managing the data type, or does CHAID only support numpy <=1.26.4?
EDIT: A workaround is to clean the output like below. However, I expect this could be avoided by specifying the data type in numpy when the tree is computed.
d_clean = {int(k): float(v) for k, v in d.items()}
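For reference, the difference appears to come from the scalar repr change in NumPy 2.0 (the values themselves look identical in both outputs). The one-liner above only handles a flat dict with integer keys and float values. A slightly more general sketch is given here; `to_native` is a hypothetical helper name, not part of CHAID or NumPy, and it assumes the values are NumPy scalars, which expose `.item()` and have `ndim == 0`:

```python
def to_native(obj):
    """Recursively convert NumPy-style scalars to built-in Python types.

    Hypothetical helper, not part of CHAID or NumPy. It relies on duck
    typing: NumPy scalars have ndim == 0 and an .item() method that
    returns the equivalent built-in Python value.
    """
    if isinstance(obj, dict):
        # Convert both keys and values, e.g. {np.int64(1): np.float64(5.0)}
        return {to_native(k): to_native(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Preserve the container type (list stays list, tuple stays tuple)
        return type(obj)(to_native(x) for x in obj)
    if getattr(obj, "ndim", None) == 0 and hasattr(obj, "item"):
        # Zero-dimensional scalar: unwrap to int, float, bool, ...
        return obj.item()
    # Plain Python values (and anything unrecognised) pass through unchanged
    return obj
```

With NumPy installed, `to_native({np.int64(1): np.float64(5.0)})` gives `{1: 5.0}`, and nested groups such as `[[np.int64(1)], [np.int64(2)]]` become `[[1], [2]]`. This only cleans values after the fact; it does not change what tree.print_tree() prints.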