|
1 | 1 | import pytest
|
2 |
| -from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \ |
3 |
| - PluralityLearner, NaiveBayesLearner, NearestNeighborLearner |
| 2 | +import math |
4 | 3 | from utils import DataFile
|
| 4 | +from learning import ( |
| 5 | + parse_csv, weighted_mode, weighted_replicate, DataSet, |
| 6 | + PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, |
| 7 | + rms_error, manhattan_distance, mean_boolean_error, mean_error |
| 8 | +) |
5 | 9 |
|
6 | 10 |
|
7 | 11 | def test_parse_csv():
|
@@ -33,3 +37,31 @@ def test_k_nearest_neighbors():
|
33 | 37 |
|
34 | 38 | kNN = NearestNeighborLearner(iris,k=3)
|
35 | 39 | assert kNN([5,3,1,0.1]) == "setosa"
|
| 40 | + |
| 41 | +def test_rms_error(): |
| 42 | + assert rms_error([2,2], [2,2]) == 0 |
| 43 | + assert rms_error((0,0), (0,1)) == math.sqrt(0.5) |
| 44 | + assert rms_error((1,0), (0,1)) == 1 |
| 45 | + assert rms_error((0,0), (0,-1)) == math.sqrt(0.5) |
| 46 | + assert rms_error((0,0.5), (0,-0.5)) == math.sqrt(0.5) |
| 47 | + |
| 48 | +def test_manhattan_distance(): |
| 49 | + assert manhattan_distance([2,2], [2,2]) == 0 |
| 50 | + assert manhattan_distance([0,0], [0,1]) == 1 |
| 51 | + assert manhattan_distance([1,0], [0,1]) == 2 |
| 52 | + assert manhattan_distance([0,0], [0,-1]) == 1 |
| 53 | + assert manhattan_distance([0,0.5], [0,-0.5]) == 1 |
| 54 | + |
| 55 | +def test_mean_boolean_error(): |
| 56 | + assert mean_boolean_error([1,1], [0,0]) == 1 |
| 57 | + assert mean_boolean_error([0,1], [1,0]) == 1 |
| 58 | + assert mean_boolean_error([1,1], [0,1]) == 0.5 |
| 59 | + assert mean_boolean_error([0,0], [0,0]) == 0 |
| 60 | + assert mean_boolean_error([1,1], [1,1]) == 0 |
| 61 | + |
| 62 | +def test_mean_error(): |
| 63 | + assert mean_error([2,2], [2,2]) == 0 |
| 64 | + assert mean_error([0,0], [0,1]) == 0.5 |
| 65 | + assert mean_error([1,0], [0,1]) == 1 |
| 66 | + assert mean_error([0,0], [0,-1]) == 0.5 |
| 67 | + assert mean_error([0,0.5], [0,-0.5]) == 0.5 |
0 commit comments