|
1 | 1 | """Learn to estimate functions from examples. (Chapters 18-20)"""
|
2 | 2 |
|
3 | 3 | from utils import *
|
4 |
| -import heapq, math, random |
| 4 | +import copy, heapq, math, random |
5 | 5 | from collections import defaultdict
|
6 | 6 |
|
7 | 7 | #______________________________________________________________________________
|
@@ -433,6 +433,39 @@ def weighted_mode(values, weights):
|
433 | 433 | totals[v] += w
|
434 | 434 | return max(totals.keys(), key=totals.get)
|
435 | 435 |
|
| 436 | +#_____________________________________________________________________________ |
| 437 | +# Adapting an unweighted learner for AdaBoost |
| 438 | + |
| 439 | +def WeightedLearner(unweighted_learner): |
| 440 | + """Given a learner that takes just an unweighted dataset, return |
| 441 | + one that takes also a weight for each example. [p. 749 footnote 14]""" |
| 442 | + def train(dataset, weights): |
| 443 | + return unweighted_learner(replicated_dataset(dataset, weights)) |
| 444 | + return train |
| 445 | + |
| 446 | +def replicated_dataset(dataset, weights, n=None): |
| 447 | + """Copy dataset, replicating each example in proportion to the |
| 448 | + corresponding weight.""" |
| 449 | + n = n or len(dataset.examples) |
| 450 | + result = copy.copy(dataset) |
| 451 | + result.examples = weighted_replicate(dataset.examples, weights, n) |
| 452 | + return result |
| 453 | + |
| 454 | +def weighted_replicate(seq, weights, n): |
| 455 | + """Return n selections from seq, with the count of each element of |
| 456 | + seq proportional to the corresponding weight (filling in fractions |
| 457 | + randomly). |
| 458 | + >>> weighted_replicate('ABC', [1,2,1], 4) |
| 459 | + ['A', 'B', 'B', 'C']""" |
| 460 | + assert len(seq) == len(weights) |
| 461 | + weights = normalize(weights) |
| 462 | + wholes = [int(w*n) for w in weights] |
| 463 | + fractions = [(w*n) % 1 for w in weights] |
| 464 | + return (flatten([x] * nx for x, nx in zip(seq, wholes)) |
| 465 | + + weighted_sample_with_replacement(seq, fractions, n - sum(wholes))) |
| 466 | + |
| 467 | +def flatten(seqs): return sum(seqs, []) |
| 468 | + |
436 | 469 | #_____________________________________________________________________________
|
437 | 470 | # Functions for testing learners on examples
|
438 | 471 |
|
|
0 commit comments