Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 444ac26

Browse files
antmarakisnorvig
authored andcommitted
Bug Fixing in DataSet + Test Updates (#410)
* Bugfixing * Test for "exclude" * Update test_learning.py * update_values
1 parent f624415 commit 444ac26

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

learning.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def hamming_distance(predictions, targets):
4343

4444

4545
class DataSet:
46-
"""A data set for a machine learning problem. It has the following fields:
46+
"""A data set for a machine learning problem. It has the following fields:
4747
48-
d.examples A list of examples. Each one is a list of attribute values.
48+
d.examples A list of examples. Each one is a list of attribute values.
4949
d.attrs A list of integers to index into an example, so example[attr]
5050
gives a value. Normally the same as range(len(d.examples[0])).
5151
d.attrnames Optional list of mnemonic names for corresponding attrs.
@@ -61,14 +61,16 @@ class DataSet:
6161
since that can handle any field types.
6262
d.name Name of the data set (for output display only).
6363
d.source URL or other source where the data came from.
64+
d.exclude A list of attribute indexes to exclude from d.inputs. Elements
65+
of this list can either be integers (attrs) or attrnames.
6466
6567
Normally, you call the constructor and you're done; then you just
6668
access fields like d.examples and d.target and d.inputs."""
6769

6870
def __init__(self, examples=None, attrs=None, attrnames=None, target=-1,
6971
inputs=None, values=None, distance=mean_boolean_error,
7072
name='', source='', exclude=()):
71-
"""Accepts any of DataSet's fields. Examples can also be a
73+
"""Accepts any of DataSet's fields. Examples can also be a
7274
string or file from which to parse examples using parse_csv.
7375
Optional parameter: exclude, as documented in .setproblem().
7476
>>> DataSet(examples='1, 2, 3')
@@ -108,14 +110,14 @@ def setproblem(self, target, inputs=None, exclude=()):
108110
to not use in inputs. Attributes can be -n .. n, or an attrname.
109111
Also computes the list of possible values, if that wasn't done yet."""
110112
self.target = self.attrnum(target)
111-
exclude = map(self.attrnum, exclude)
113+
exclude = list(map(self.attrnum, exclude))
112114
if inputs:
113115
self.inputs = removeall(self.target, inputs)
114116
else:
115117
self.inputs = [a for a in self.attrs
116118
if a != self.target and a not in exclude]
117119
if not self.values:
118-
self.values = list(map(unique, zip(*self.examples)))
120+
self.update_values()
119121
self.check_me()
120122

121123
def check_me(self):
@@ -150,6 +152,9 @@ def attrnum(self, attr):
150152
else:
151153
return attr
152154

155+
def update_values(self):
156+
self.values = list(map(unique, zip(*self.examples)))
157+
153158
def sanitize(self, example):
154159
"""Return a copy of example, with non-input attributes replaced by None."""
155160
return [attr_i if i in self.inputs else None
@@ -166,6 +171,7 @@ def classes_to_numbers(self,classes=None):
166171
def remove_examples(self,value=""):
167172
"""Remove examples that contain given value."""
168173
self.examples = [x for x in self.examples if value not in x]
174+
self.update_values()
169175

170176
def __repr__(self):
171177
return '<DataSet({}): {:d} examples, {:d} attributes>'.format(

tests/test_learning.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from utils import DataFile
55

66

7+
def test_exclude():
8+
iris = DataSet(name='iris', exclude=[3])
9+
assert iris.inputs == [0, 1, 2]
10+
711

812
def test_parse_csv():
913
Iris = DataFile('iris.csv').read()
@@ -38,6 +42,7 @@ def test_k_nearest_neighbors():
3842
kNN = NearestNeighborLearner(iris,k=3)
3943
assert kNN([5,3,1,0.1]) == "setosa"
4044

45+
4146
def test_decision_tree_learner():
4247
iris = DataSet(name="iris")
4348

@@ -47,21 +52,23 @@ def test_decision_tree_learner():
4752

4853
def test_neural_network_learner():
4954
iris = DataSet(name="iris")
55+
iris.remove_examples("virginica")
56+
5057
classes = ["setosa","versicolor","virginica"]
51-
5258
iris.classes_to_numbers()
5359

5460
nNL = NeuralNetLearner(iris)
55-
# NeuralNetLearner might be wrong. Just check if prediction is in range
61+
# NeuralNetLearner might be wrong. Just check if prediction is in range.
5662
assert nNL([5,3,1,0.1]) in range(len(classes))
5763

5864

5965
def test_perceptron():
6066
iris = DataSet(name="iris")
67+
iris.remove_examples("virginica")
68+
6169
classes = ["setosa","versicolor","virginica"]
62-
6370
iris.classes_to_numbers()
6471

6572
perceptron = PerceptronLearner(iris)
66-
# PerceptronLearner might be wrong. Just check if prediction is in range
73+
# PerceptronLearner might be wrong. Just check if prediction is in range.
6774
assert perceptron([5,3,1,0.1]) in range(len(classes))

0 commit comments

Comments
 (0)