@@ -288,7 +288,7 @@ def enumerate_all(vars, e, bn):
288
288
#______________________________________________________________________________
289
289
290
290
def elimination_ask (X , e , bn ):
291
- """[Fig. 14.11]
291
+ """Compute bn's P(X|e) by variable elimination. [Fig. 14.11]
292
292
>>> elimination_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
293
293
... ).show_approx()
294
294
'False: 0.716, True: 0.284'"""
@@ -300,9 +300,13 @@ def elimination_ask(X, e, bn):
300
300
return pointwise_product (factors , bn ).normalize ()
301
301
302
302
def is_hidden (var , X , e ):
303
+ "Is var a hidden variable when querying P(X|e)?"
303
304
return var != X and var not in e
304
305
305
306
def make_factor (var , e , bn ):
307
+ """Return the factor for var in bn's joint distribution given e.
308
+ That is, bn's full joint distribution, projected to accord with e,
309
+ is the pointwise product of these factors for bn's variables."""
306
310
node = bn .variable_node (var )
307
311
vars = [X for X in [var ] + node .parents if X not in e ]
308
312
cpt = dict ((event_values (e1 , vars ), node .p (e1 [var ], e1 ))
@@ -313,24 +317,28 @@ def pointwise_product(factors, bn):
313
317
return reduce (lambda f , g : f .pointwise_product (g , bn ), factors )
314
318
315
319
def sum_out (var , factors , bn ):
320
+ "Eliminate var from all factors by summing over its values."
316
321
result , var_factors = [], []
317
322
for f in factors :
318
323
(var_factors if var in f .vars else result ).append (f )
319
324
result .append (pointwise_product (var_factors , bn ).sum_out (var , bn ))
320
325
return result
321
326
322
327
class Factor :
328
+ "A factor in a joint distribution."
323
329
324
330
def __init__ (self , vars , cpt ):
325
331
update (self , vars = vars , cpt = cpt )
326
332
327
333
def pointwise_product (self , other , bn ):
334
+ "Multiply two factors, combining their variables."
328
335
vars = list (set (self .vars ) | set (other .vars ))
329
336
cpt = dict ((event_values (e , vars ), self .p (e ) * other .p (e ))
330
337
for e in all_events (vars , bn , {}))
331
338
return Factor (vars , cpt )
332
339
333
340
def sum_out (self , var , bn ):
341
+ "Make a factor eliminating var by summing over its values."
334
342
vars = [X for X in self .vars if X != var ]
335
343
cpt = dict ((event_values (e , vars ),
336
344
sum (self .p (extend (e , var , val ))
@@ -339,21 +347,24 @@ def sum_out(self, var, bn):
339
347
return Factor (vars , cpt )
340
348
341
349
def normalize (self ):
350
+ "Return my probabilities; must be down to one variable."
342
351
assert len (self .vars ) == 1
343
352
return ProbDist (self .vars [0 ],
344
353
dict ((k , v ) for ((k ,), v ) in self .cpt .items ()))
345
354
346
355
def p (self , e ):
356
+ "Look up my value tabulated for e."
347
357
return self .cpt [event_values (e , self .vars )]
348
358
349
- def all_events (vars , bn , e1 ):
359
+ def all_events (vars , bn , e ):
360
+ "Yield every way of extending e with values for all vars."
350
361
if not vars :
351
- yield e1
362
+ yield e
352
363
else :
353
364
X , rest = vars [0 ], vars [1 :]
354
- for e in all_events (rest , bn , e1 ):
365
+ for e1 in all_events (rest , bn , e ):
355
366
for x in bn .variable_values (X ):
356
- yield extend (e , X , x )
367
+ yield extend (e1 , X , x )
357
368
358
369
#______________________________________________________________________________
359
370
0 commit comments