1
1
from learning import parse_csv , weighted_mode , weighted_replicate , DataSet , \
2
- PluralityLearner , NaiveBayesLearner , NearestNeighborLearner
2
+ PluralityLearner , NaiveBayesLearner , NearestNeighborLearner , \
3
+ NeuralNetLearner , PerceptronLearner , DecisionTreeLearner
3
4
from utils import DataFile
4
5
5
6
7
+
6
8
def test_parse_csv ():
7
9
Iris = DataFile ('iris.csv' ).read ()
8
- assert parse_csv (Iris )[0 ] == [5.1 , 3.5 , 1.4 , 0.2 , 'setosa' ]
10
+ assert parse_csv (Iris )[0 ] == [5.1 ,3.5 ,1.4 ,0.2 ,'setosa' ]
9
11
10
12
11
13
def test_weighted_mode ():
@@ -20,18 +22,46 @@ def test_plurality_learner():
20
22
zoo = DataSet (name = "zoo" )
21
23
22
24
pL = PluralityLearner (zoo )
23
- assert pL ([]) == "mammal"
25
+ assert pL ([1 , 0 , 0 , 1 , 0 , 0 , 0 , 1 , 1 , 1 , 0 , 0 , 4 , 1 , 0 , 1 ]) == "mammal"
24
26
25
27
26
28
def test_naive_bayes ():
27
29
iris = DataSet (name = "iris" )
28
30
29
31
nB = NaiveBayesLearner (iris )
30
- assert nB ([5 , 3 , 1 , 0.1 ]) == "setosa"
32
+ assert nB ([5 ,3 , 1 , 0.1 ]) == "setosa"
31
33
32
34
33
35
def test_k_nearest_neighbors ():
34
36
iris = DataSet (name = "iris" )
35
37
36
- kNN = NearestNeighborLearner (iris , k = 3 )
37
- assert kNN ([5 , 3 , 1 , 0.1 ]) == "setosa"
38
+ kNN = NearestNeighborLearner (iris ,k = 3 )
39
+ assert kNN ([5 ,3 ,1 ,0.1 ]) == "setosa"
40
+
41
+ def test_decision_tree_learner ():
42
+ iris = DataSet (name = "iris" )
43
+
44
+ dTL = DecisionTreeLearner (iris )
45
+ assert dTL ([5 ,3 ,1 ,0.1 ]) == "setosa"
46
+
47
+
48
+ def test_neural_network_learner ():
49
+ iris = DataSet (name = "iris" )
50
+ classes = ["setosa" ,"versicolor" ,"virginica" ]
51
+
52
+ iris .classes_to_numbers ()
53
+
54
+ nNL = NeuralNetLearner (iris )
55
+ # NeuralNetLearner might be wrong. Just check if prediction is in range
56
+ assert nNL ([5 ,3 ,1 ,0.1 ]) in range (len (classes ))
57
+
58
+
59
+ def test_perceptron ():
60
+ iris = DataSet (name = "iris" )
61
+ classes = ["setosa" ,"versicolor" ,"virginica" ]
62
+
63
+ iris .classes_to_numbers ()
64
+
65
+ perceptron = PerceptronLearner (iris )
66
+ # PerceptronLearner might be wrong. Just check if prediction is in range
67
+ assert perceptron ([5 ,3 ,1 ,0.1 ]) in range (len (classes ))
0 commit comments