From 9511f61efd9c6c49cefb82d523109330acf98bb7 Mon Sep 17 00:00:00 2001 From: Antonis Maronikolakis Date: Sat, 18 Mar 2017 16:18:32 +0200 Subject: [PATCH] Update test_learning.py Add DecisionTreeLearner, NeuralNetLearner and PerceptronLearner tests --- tests/test_learning.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/test_learning.py b/tests/test_learning.py index 46ac8dd26..f216ad168 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -1,11 +1,13 @@ from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \ - PluralityLearner, NaiveBayesLearner, NearestNeighborLearner + PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \ + NeuralNetLearner, PerceptronLearner, DecisionTreeLearner from utils import DataFile + def test_parse_csv(): Iris = DataFile('iris.csv').read() - assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2, 'setosa'] + assert parse_csv(Iris)[0] == [5.1,3.5,1.4,0.2,'setosa'] def test_weighted_mode(): @@ -20,18 +22,46 @@ def test_plurality_learner(): zoo = DataSet(name="zoo") pL = PluralityLearner(zoo) - assert pL([]) == "mammal" + assert pL([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 4, 1, 0, 1]) == "mammal" def test_naive_bayes(): iris = DataSet(name="iris") nB = NaiveBayesLearner(iris) - assert nB([5, 3, 1, 0.1]) == "setosa" + assert nB([5,3,1,0.1]) == "setosa" def test_k_nearest_neighbors(): iris = DataSet(name="iris") - kNN = NearestNeighborLearner(iris, k=3) - assert kNN([5, 3, 1, 0.1]) == "setosa" + kNN = NearestNeighborLearner(iris,k=3) + assert kNN([5,3,1,0.1]) == "setosa" + +def test_decision_tree_learner(): + iris = DataSet(name="iris") + + dTL = DecisionTreeLearner(iris) + assert dTL([5,3,1,0.1]) == "setosa" + + +def test_neural_network_learner(): + iris = DataSet(name="iris") + classes = ["setosa","versicolor","virginica"] + + iris.classes_to_numbers() + + nNL = NeuralNetLearner(iris) + # NeuralNetLearner might be wrong. Just check if prediction is in range + assert nNL([5,3,1,0.1]) in range(len(classes)) + + +def test_perceptron(): + iris = DataSet(name="iris") + classes = ["setosa","versicolor","virginica"] + + iris.classes_to_numbers() + + perceptron = PerceptronLearner(iris) + # PerceptronLearner might be wrong. Just check if prediction is in range + assert perceptron([5,3,1,0.1]) in range(len(classes))