|
1 | 1 | """Learn to estimate functions from examples. (Chapters 18-20)"""
|
2 | 2 |
|
3 | 3 | from utils import *
|
4 |
| -import heapq, random |
| 4 | +import heapq, math, random |
5 | 5 |
|
6 | 6 | #______________________________________________________________________________
|
7 | 7 |
|
@@ -318,9 +318,7 @@ def split_by(attr, examples):
|
def information_content(values):
    """Number of bits to represent the probability distribution in values.

    Zero entries are removed first (they would make log2 undefined), and the
    remaining values are rescaled with normalize so they can be treated as a
    probability distribution. Returns 0 when no nonzero value is left.
    """
    nonzero = removeall(0, values)
    if not nonzero:
        # Nothing to normalize. Return the empty sum directly instead of
        # handing an empty sequence to normalize, which presumably divides
        # by the total of its argument (helper lives in utils -- confirm).
        return 0
    return sum([-p * log2(p) for p in normalize(nonzero)])
|
325 | 323 |
|
326 | 324 | #______________________________________________________________________________
|
@@ -394,6 +392,34 @@ def predict(example):
|
394 | 392 | return predict
|
395 | 393 | return train
|
396 | 394 |
|
| 395 | +#______________________________________________________________________________ |
| 396 | + |
def AdaBoost(L, K):
    """[Fig. 18.34] Return a trainer that boosts the learner L for up to K rounds.

    L is called as L(examples, weights) and must return a hypothesis: a
    function mapping an example to a predicted value for the target
    attribute. The returned train(dataset) reads dataset.examples and
    dataset.target and returns WeightedMajority(h, z), where h holds the
    hypotheses learned so far and z holds one vote weight per hypothesis.

    The weighted error of each round is clamped to [1/(2N), 1 - 1/(2N)], so
    a hypothesis that is right (or wrong) on every example still receives a
    finite vote weight instead of causing a division by zero or log(0).
    """
    def train(dataset):
        examples, target = dataset.examples, dataset.target
        N = len(examples)
        # Bounds for the clamped error rate; see the docstring.
        epsilon = 1. / (2 * N)
        w = [1. / N] * N
        h, z = [], []
        for k in range(K):
            h_k = L(examples, w)
            # Evaluate the hypothesis once per example and reuse the outcome
            # for both the error estimate and the weight update.
            correct = [example[target] == h_k(example) for example in examples]
            error = sum(w_j for w_j, ok in zip(w, correct) if not ok)
            error = min(max(error, epsilon), 1. - epsilon)
            # Scale the weight of every example h_k got right by
            # error / (1 - error); after renormalizing, the misclassified
            # examples carry relatively more weight in the next round
            # whenever error < 0.5.
            factor = error / (1. - error)
            w = normalize([w_j * factor if ok else w_j
                           for w_j, ok in zip(w, correct)])
            # Always record a vote weight together with the hypothesis so
            # that h and z stay the same length.
            h.append(h_k)
            z.append(math.log((1. - error) / error))
            if all(correct):
                # A perfect hypothesis leaves the normalized weights
                # unchanged, so further rounds would see the same input.
                break
        return WeightedMajority(h, z)
    return train
| 419 | + |
def WeightedMajority(h, z):
    """Return a predictor that takes a weighted vote among the hypotheses in h.

    h is a sequence of hypotheses (functions from an example to a predicted
    value) and z the matching sequence of vote weights. The returned
    predict(example) sums, for every distinct predicted value, the weights
    of the hypotheses that predicted it, and returns the value with the
    largest total.

    Raises ValueError if h and z differ in length or are empty.
    """
    # Snapshot the inputs so later changes by the caller cannot affect the
    # predictor, and so that len() works for any iterable.
    h, z = list(h), list(z)
    if len(h) != len(z):
        raise ValueError("WeightedMajority needs one weight per hypothesis "
                         "(got %d hypotheses, %d weights)" % (len(h), len(z)))
    if not h:
        raise ValueError("WeightedMajority needs at least one hypothesis")

    def predict(example):
        totals = {}
        for h_k, z_k in zip(h, z):
            label = h_k(example)
            totals[label] = totals.get(label, 0) + z_k
        return max(totals, key=totals.get)
    return predict
| 422 | + |
397 | 423 | #_____________________________________________________________________________
|
398 | 424 | # Functions for testing learners on examples
|
399 | 425 |
|
|
0 commit comments