Commit be60670

Fixed DataSet doc comments. Added rep-invariant check. DecisionTreeLearner now breaks ties on leaves randomly. NaiveBayesLearner now collects stats only on input fields. Minor cleanup.
1 parent 6810589 commit be60670

File tree

1 file changed: +26 -17 lines changed


learning.py

Lines changed: 26 additions & 17 deletions
@@ -1,7 +1,7 @@
 """Learn to estimate functions from examples. (Chapters 18-20)"""
 
 from utils import *
-import agents, random, operator
+import random
 
 #______________________________________________________________________________
 
@@ -10,14 +10,14 @@ class DataSet:
 
     d.examples   A list of examples. Each one is a list of attribute values.
     d.attrs      A list of integers to index into an example, so example[attr]
-                 gives a value. Normally the same as range(len(d.examples)).
+                 gives a value. Normally the same as range(len(d.examples[0])).
     d.attrnames  Optional list of mnemonic names for corresponding attrs.
     d.target     The attribute that a learning algorithm will try to predict.
                  By default the final attribute.
     d.inputs     The list of attrs without the target.
     d.values     A list of lists: each sublist is the set of possible
-                 values for the corresponding attribute. If None, it
-                 is computed from the known examples by self.setproblem.
+                 values for the corresponding attribute. If initially None,
+                 it is computed from the known examples by self.setproblem.
                  If not None, an erroneous value raises ValueError.
     d.name       Name of the data set (for output display only).
     d.source     URL or other source where the data came from.
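Note on the docstring fix: d.attrs indexes the columns of a single example (one entry per attribute), not the rows of d.examples, hence range(len(d.examples[0])). A toy illustration on made-up data (not part of the commit):

    examples = [['sunny', 'hot',  'no'],      # 3 attribute values per example
                ['rain',  'mild', 'yes'],
                ['rain',  'cool', 'yes'],
                ['sunny', 'mild', 'no']]      # 4 examples

    attrs = list(range(len(examples[0])))     # -> [0, 1, 2], one index per column
    # range(len(examples)) would give [0, 1, 2, 3], which only coincides with the
    # attribute indices when the number of examples equals the number of attributes.
    print([examples[1][a] for a in attrs])    # -> ['rain', 'mild', 'yes']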
@@ -67,23 +67,31 @@ def setproblem(self, target, inputs=None, exclude=()):
                        if a is not self.target and a not in exclude]
         if not self.values:
             self.values = map(unique, zip(*self.examples))
+        self.check_me()
+
+    def check_me(self):
+        "Check that my fields make sense."
+        assert len(self.attrnames) == len(self.attrs)
+        assert self.target in self.attrs
+        assert self.target not in self.inputs
+        assert set(self.inputs).issubset(set(self.attrs))
         map(self.check_example, self.examples)
 
     def add_example(self, example):
-        """Add an example to the list of examples, checking it first."""
+        "Add an example to the list of examples, checking it first."
         self.check_example(example)
         self.examples.append(example)
 
     def check_example(self, example):
-        """Raise ValueError if example has any invalid values."""
+        "Raise ValueError if example has any invalid values."
         if self.values:
             for a in self.attrs:
                 if example[a] not in self.values[a]:
                     raise ValueError('Bad value %s for attribute %s in %s' %
                                      (example[a], self.attrnames[a], example))
 
     def attrnum(self, attr):
-        "Returns the number used for attr, which can be a name, or -n .. n."
+        "Returns the number used for attr, which can be a name, or -n .. n-1."
         if attr < 0:
             return len(self.attrs) + attr
         elif isinstance(attr, str):
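The new check_me() is the rep-invariant check from the commit message: it asserts that a DataSet's fields stay mutually consistent after setproblem() and before training. A minimal standalone sketch of the same pattern (illustrative class and field names, not the repository's DataSet):

    class TinyDataSet:
        "Toy stand-in with the same fields the asserts talk about."
        def __init__(self, examples, target):
            self.examples = examples
            self.attrs = list(range(len(examples[0])))          # column indices
            self.attrnames = ['attr%d' % a for a in self.attrs]
            self.target = target                                # predicted column
            self.inputs = [a for a in self.attrs if a != target]
            self.check_me()                                     # fail fast

        def check_me(self):
            "Raise AssertionError if the fields are mutually inconsistent."
            assert len(self.attrnames) == len(self.attrs)
            assert self.target in self.attrs
            assert self.target not in self.inputs
            assert set(self.inputs).issubset(set(self.attrs))

    TinyDataSet([[1, 0, 'yes'], [0, 1, 'no']], target=2)   # passes silently
    # TinyDataSet([[1, 0, 'yes']], target=5)               # would trip the asserts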
@@ -166,7 +174,7 @@ def train(self, dataset):
         ## Initialize to 0
         for gv in self.dataset.values[self.dataset.target]:
             N[gv] = {}
-            for attr in self.dataset.attrs:
+            for attr in self.dataset.inputs:
                 N[gv][attr] = {}
                 assert None not in self.dataset.values[attr]
                 for val in self.dataset.values[attr]:
@@ -175,7 +183,7 @@ def train(self, dataset):
         ## Go thru examples
         for example in self.dataset.examples:
             Ngv = N[example[self.dataset.target]]
-            for attr in self.dataset.attrs:
+            for attr in self.dataset.inputs:
                 Ngv[attr][example[attr]] += 1
                 Ngv[attr][None] += 1
         self._N = N
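Both train() hunks switch the counting loops from dataset.attrs to dataset.inputs, so the naive Bayes frequency table never tabulates the target column against itself. A rough sketch of the table being built, on made-up data (not the repository's train()):

    from collections import defaultdict

    examples = [['sunny', 'hot',  'no'],
                ['rain',  'mild', 'yes'],
                ['rain',  'hot',  'yes']]
    target = 2                       # class column
    inputs = [0, 1]                  # every attribute except the target

    # N[class][attr][val] counts how often val appears for attr within a class
    N = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    for e in examples:
        for attr in inputs:          # inputs only, as in the patched loops
            N[e[target]][attr][e[attr]] += 1

    print(N['yes'][0]['rain'])       # -> 2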
@@ -309,8 +317,8 @@ def plurality_value(self, examples):
         """Return the most popular target value for this set of examples.
         (If target is binary, this is the majority; otherwise plurality.)"""
         g = self.dataset.target
-        return argmax(self.dataset.values[g],
-                      lambda v: self.count(g, v, examples))
+        return argmax_random_tie(self.dataset.values[g],
+                                 lambda v: self.count(g, v, examples))
 
     def count(self, attr, val, examples):
         return count_if(lambda e: e[attr] == val, examples)
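plurality_value() now calls argmax_random_tie from the project's utils module, so equally common target values no longer always resolve to whichever comes first in d.values. A behaviourally similar stand-in (a sketch, not the utils implementation):

    import random

    def argmax_random_tie_sketch(seq, fn):
        "Return an element of seq maximizing fn, choosing uniformly among ties."
        best = max(fn(x) for x in seq)
        return random.choice([x for x in seq if fn(x) == best])

    # With equal counts for 'Yes' and 'No', either value may come back:
    print(argmax_random_tie_sketch(['Yes', 'No'], lambda v: 1))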
@@ -338,7 +346,7 @@ def I(examples):
 
     def split_by(self, attr, examples=None):
         "Return a list of (val, examples) pairs for each val of attr."
-        if examples == None:
+        if examples is None:
             examples = self.dataset.examples
         return [(v, [e for e in examples if e[attr] == v])
                 for v in self.dataset.values[attr]]
@@ -426,7 +434,7 @@ def predict(self, example):
 def test(learner, dataset, examples=None, verbose=0):
     """Return the proportion of the examples that are correctly predicted.
     Assumes the learner has already been trained."""
-    if examples == None: examples = dataset.examples
+    if examples is None: examples = dataset.examples
     if len(examples) == 0: return 0.0
     right = 0.0
     for example in examples:
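The == None comparisons are replaced with is None throughout; identity is the right test because == can be overridden by the operand. A contrived example of the difference:

    class Weird:
        def __eq__(self, other):
            return True            # claims equality with everything, even None

    w = Weird()
    print(w == None)               # True  -- the old test would misfire here
    print(w is None)               # False -- identity is unaffected by __eq__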
@@ -447,6 +455,7 @@ def train_and_test(learner, dataset, start, end):
     examples = dataset.examples
     try:
         dataset.examples = examples[:start] + examples[end:]
+        dataset.check_me()
         learner.train(dataset)
         return test(learner, dataset, examples[start:end])
     finally:
@@ -456,7 +465,7 @@ def cross_validation(learner, dataset, k=10, trials=1):
     """Do k-fold cross_validate and return their mean.
     That is, keep out 1/k of the examples for testing on each of k runs.
     Shuffle the examples first; If trials>1, average over several shuffles."""
-    if k == None:
+    if k is None:
         k = len(dataset.examples)
     if trials > 1:
         return mean([cross_validation(learner, dataset, k, trials=1)
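For context on cross_validation(): per its docstring it averages held-out test scores over k slices, each slice handed to a call like train_and_test(learner, dataset, start, end). The index arithmetic behind such folds might look like this (an illustrative helper, not code from the repository):

    def fold_bounds(n, k):
        "Yield (start, end) pairs covering range(n) in k roughly equal folds."
        for i in range(k):
            yield i * n // k, (i + 1) * n // k

    print(list(fold_bounds(10, 5)))   # -> [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]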
@@ -472,7 +481,7 @@ def leave1out(learner, dataset):
     return cross_validation(learner, dataset, k=len(dataset.examples))
 
 def learningcurve(learner, dataset, trials=10, sizes=None):
-    if sizes == None:
+    if sizes is None:
         sizes = range(2, len(dataset.examples)-10, 2)
     def score(learner, size):
         random.shuffle(dataset.examples)
@@ -531,7 +540,7 @@ def T(attrname, branches):
 [Fig. 18.6]
 >>> restaurant_learner = DecisionTreeLearner()
 >>> restaurant_learner.train(restaurant)
->>> restaurant_learner.dt.display()
+>>> restaurant_learner.dt.display() #doctest:+ELLIPSIS
 Test Patrons
 Patrons = None ==> RESULT = No
 Patrons = Full ==> Test Hungry
@@ -540,7 +549,7 @@ def T(attrname, branches):
 Type = Thai ==> Test Fri/Sat
 Fri/Sat = Yes ==> RESULT = Yes
 Fri/Sat = No ==> RESULT = No
-Type = French ==> RESULT = Yes
+Type = French ==> RESULT = ...
 Type = Italian ==> RESULT = No
 Hungry = No ==> RESULT = No
 Patrons = Some ==> RESULT = Yes
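The #doctest:+ELLIPSIS directive and the "..." expected value go together: with leaf ties now broken randomly, the Type = French leaf is no longer deterministic, and ELLIPSIS lets the doctest accept whichever result comes out. A tiny self-contained doctest showing the same mechanism (made-up function, not from the repository):

    def coin():
        """
        >>> print('RESULT = %s' % coin())   #doctest: +ELLIPSIS
        RESULT = ...
        """
        import random
        return random.choice(['Yes', 'No'])

    if __name__ == '__main__':
        import doctest
        doctest.testmod()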
