diff --git a/tests/test_learning.py b/tests/test_learning.py index 4f618f7c1..60ea61366 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -41,6 +41,8 @@ def test_k_nearest_neighbors(): kNN = NearestNeighborLearner(iris,k=3) assert kNN([5,3,1,0.1]) == "setosa" + assert kNN([6,5,3,1.5]) == "versicolor" + assert kNN([7.5,4,6,2]) == "virginica" def test_decision_tree_learner(): @@ -48,6 +50,8 @@ def test_decision_tree_learner(): dTL = DecisionTreeLearner(iris) assert dTL([5,3,1,0.1]) == "setosa" + assert dTL([6,5,3,1.5]) == "versicolor" + assert dTL([7.5,4,6,2]) == "virginica" def test_neural_network_learner():