diff --git a/learning.py b/learning.py index df5d6fce3..d6986af24 100644 --- a/learning.py +++ b/learning.py @@ -830,13 +830,20 @@ def cross_validation_wrapper(learner, dataset, k=10, trials=1): err_val = [] err_train = [] size = 1 + while True: errT, errV = cross_validation(learner, size, dataset, k) # Check for convergence provided err_val is not empty - if (err_val and isclose(err_val[-1], errV, rel_tol=1e-6)): - best_size = size - return learner(dataset, best_size) - + if (err_train and isclose(err_train[-1], errT, rel_tol=1e-6)): + best_size = 0 + min_val = math.inf + + i = 0 + while i