1
1
"""Learn to estimate functions from examples. (Chapters 18-20)"""
2
2
3
3
from utils import *
4
- import agents , random , operator
4
+ import random
5
5
6
6
#______________________________________________________________________________
7
7
@@ -10,14 +10,14 @@ class DataSet:
10
10
11
11
d.examples A list of examples. Each one is a list of attribute values.
12
12
d.attrs A list of integers to index into an example, so example[attr]
13
- gives a value. Normally the same as range(len(d.examples)).
13
+ gives a value. Normally the same as range(len(d.examples[0] )).
14
14
d.attrnames Optional list of mnemonic names for corresponding attrs.
15
15
d.target The attribute that a learning algorithm will try to predict.
16
16
By default the final attribute.
17
17
d.inputs The list of attrs without the target.
18
18
d.values A list of lists: each sublist is the set of possible
19
- values for the corresponding attribute. If None, it
20
- is computed from the known examples by self.setproblem.
19
+ values for the corresponding attribute. If initially None,
20
+ it is computed from the known examples by self.setproblem.
21
21
If not None, an erroneous value raises ValueError.
22
22
d.name Name of the data set (for output display only).
23
23
d.source URL or other source where the data came from.
@@ -67,23 +67,31 @@ def setproblem(self, target, inputs=None, exclude=()):
67
67
if a is not self .target and a not in exclude ]
68
68
if not self .values :
69
69
self .values = map (unique , zip (* self .examples ))
70
+ self .check_me ()
71
+
72
def check_me(self):
    """Check that my fields make sense.

    Asserts that there is one attribute name per attribute, that the
    target is one of the attributes and is not among the inputs, and that
    every input is one of the attributes.  Then hands every example to
    self.check_example, so whatever that method raises propagates.
    """
    assert len(self.attrnames) == len(self.attrs)
    assert self.target in self.attrs
    assert self.target not in self.inputs
    assert set(self.inputs).issubset(set(self.attrs))
    # Explicit loop rather than map(): map() is lazy on Python 3, so the
    # per-example checks would silently never run there.
    for example in self.examples:
        self.check_example(example)
71
79
72
80
def add_example(self, example):
    """Validate `example` with check_example, then append it to self.examples.

    If check_example raises, the example is not stored.
    """
    self.check_example(example)
    stored = self.examples
    stored.append(example)
76
84
77
85
def check_example(self, example):
    """Raise ValueError if example has any invalid values.

    Nothing is checked when self.values is empty or None.  Otherwise,
    for each attribute in self.attrs, the example's value must appear in
    the list of allowed values for that attribute.
    """
    if not self.values:
        return
    for attr in self.attrs:
        value = example[attr]
        if value in self.values[attr]:
            continue
        raise ValueError('Bad value %s for attribute %s in %s' %
                         (value, self.attrnames[attr], example))
84
92
85
93
def attrnum (self , attr ):
86
- "Returns the number used for attr, which can be a name, or -n .. n."
94
+ "Returns the number used for attr, which can be a name, or -n .. n-1 ."
87
95
if attr < 0 :
88
96
return len (self .attrs ) + attr
89
97
elif isinstance (attr , str ):
@@ -166,7 +174,7 @@ def train(self, dataset):
166
174
## Initialize to 0
167
175
for gv in self .dataset .values [self .dataset .target ]:
168
176
N [gv ] = {}
169
- for attr in self .dataset .attrs :
177
+ for attr in self .dataset .inputs :
170
178
N [gv ][attr ] = {}
171
179
assert None not in self .dataset .values [attr ]
172
180
for val in self .dataset .values [attr ]:
@@ -175,7 +183,7 @@ def train(self, dataset):
175
183
## Go thru examples
176
184
for example in self .dataset .examples :
177
185
Ngv = N [example [self .dataset .target ]]
178
- for attr in self .dataset .attrs :
186
+ for attr in self .dataset .inputs :
179
187
Ngv [attr ][example [attr ]] += 1
180
188
Ngv [attr ][None ] += 1
181
189
self ._N = N
@@ -309,8 +317,8 @@ def plurality_value(self, examples):
309
317
"""Return the most popular target value for this set of examples.
310
318
(If target is binary, this is the majority; otherwise plurality.)"""
311
319
g = self .dataset .target
312
- return argmax (self .dataset .values [g ],
313
- lambda v : self .count (g , v , examples ))
320
+ return argmax_random_tie (self .dataset .values [g ],
321
+ lambda v : self .count (g , v , examples ))
314
322
315
323
def count(self, attr, val, examples):
    """Count the examples whose value for attribute `attr` equals `val`."""
    def has_val(example):
        return example[attr] == val
    return count_if(has_val, examples)
@@ -338,7 +346,7 @@ def I(examples):
338
346
339
347
def split_by(self, attr, examples=None):
    """Return a list of (val, examples) pairs for each val of attr.

    When `examples` is None, the data set's own examples are partitioned.
    Pairs follow the order of self.dataset.values[attr]; a value that no
    example has still gets a pair, with an empty list.
    """
    if examples is None:
        examples = self.dataset.examples
    pairs = []
    for value in self.dataset.values[attr]:
        matching = [ex for ex in examples if ex[attr] == value]
        pairs.append((value, matching))
    return pairs
@@ -426,7 +434,7 @@ def predict(self, example):
426
434
def test (learner , dataset , examples = None , verbose = 0 ):
427
435
"""Return the proportion of the examples that are correctly predicted.
428
436
Assumes the learner has already been trained."""
429
- if examples == None : examples = dataset .examples
437
+ if examples is None : examples = dataset .examples
430
438
if len (examples ) == 0 : return 0.0
431
439
right = 0.0
432
440
for example in examples :
@@ -447,6 +455,7 @@ def train_and_test(learner, dataset, start, end):
447
455
examples = dataset .examples
448
456
try :
449
457
dataset .examples = examples [:start ] + examples [end :]
458
+ dataset .check_me ()
450
459
learner .train (dataset )
451
460
return test (learner , dataset , examples [start :end ])
452
461
finally :
@@ -456,7 +465,7 @@ def cross_validation(learner, dataset, k=10, trials=1):
456
465
"""Do k-fold cross_validate and return their mean.
457
466
That is, keep out 1/k of the examples for testing on each of k runs.
458
467
Shuffle the examples first; If trials>1, average over several shuffles."""
459
- if k == None :
468
+ if k is None :
460
469
k = len (dataset .examples )
461
470
if trials > 1 :
462
471
return mean ([cross_validation (learner , dataset , k , trials = 1 )
@@ -472,7 +481,7 @@ def leave1out(learner, dataset):
472
481
return cross_validation (learner , dataset , k = len (dataset .examples ))
473
482
474
483
def learningcurve (learner , dataset , trials = 10 , sizes = None ):
475
- if sizes == None :
484
+ if sizes is None :
476
485
sizes = range (2 , len (dataset .examples )- 10 , 2 )
477
486
def score (learner , size ):
478
487
random .shuffle (dataset .examples )
@@ -531,7 +540,7 @@ def T(attrname, branches):
531
540
[Fig. 18.6]
532
541
>>> restaurant_learner = DecisionTreeLearner()
533
542
>>> restaurant_learner.train(restaurant)
534
- >>> restaurant_learner.dt.display()
543
+ >>> restaurant_learner.dt.display() #doctest:+ELLIPSIS
535
544
Test Patrons
536
545
Patrons = None ==> RESULT = No
537
546
Patrons = Full ==> Test Hungry
@@ -540,7 +549,7 @@ def T(attrname, branches):
540
549
Type = Thai ==> Test Fri/Sat
541
550
Fri/Sat = Yes ==> RESULT = Yes
542
551
Fri/Sat = No ==> RESULT = No
543
- Type = French ==> RESULT = Yes
552
+ Type = French ==> RESULT = ...
544
553
Type = Italian ==> RESULT = No
545
554
Hungry = No ==> RESULT = No
546
555
Patrons = Some ==> RESULT = Yes
0 commit comments