@@ -287,26 +287,73 @@ def enumerate_all(vars, e, bn):
287
287
288
288
#______________________________________________________________________________
289
289
290
- def elimination_ask (X , e , bn , order = reversed ):
291
- "[Fig. 14.11]"
290
+ def elimination_ask (X , e , bn ):
291
+ """[Fig. 14.11]
292
+ >>> elimination_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
293
+ ... ).show_approx()
294
+ 'False: 0.716, True: 0.284'"""
292
295
factors = []
293
- for var in order (bn .vars ):
294
- factors .append (Factor (var , e ))
296
+ for var in reversed (bn .vars ):
297
+ factors .append (make_factor (var , e , bn ))
295
298
if is_hidden (var , X , e ):
296
- factors = sum_out (var , factors )
297
- return pointwise_product (factors ).normalize ()
299
+ factors = sum_out (var , factors , bn )
300
+ return pointwise_product (factors , bn ).normalize ()
298
301
299
302
def is_hidden (var , X , e ):
300
303
return var != X and var not in e
301
304
302
- def Factor (var , e ):
303
- unimplemented ()
305
+ def make_factor (var , e , bn ):
306
+ node = bn .variable_node (var )
307
+ vars = [X for X in [var ] + node .parents if X not in e ]
308
+ cpt = dict ((event_values (e1 , vars ), node .p (e1 [var ], e1 ))
309
+ for e1 in all_events (vars , bn , e ))
310
+ return Factor (vars , cpt )
311
+
312
+ def pointwise_product (factors , bn ):
313
+ return reduce (lambda f , g : f .pointwise_product (g , bn ), factors )
314
+
315
+ def sum_out (var , factors , bn ):
316
+ result , var_factors = [], []
317
+ for f in factors :
318
+ (var_factors if var in f .vars else result ).append (f )
319
+ result .append (pointwise_product (var_factors , bn ).sum_out (var , bn ))
320
+ return result
321
+
322
+ class Factor :
323
+
324
+ def __init__ (self , vars , cpt ):
325
+ update (self , vars = vars , cpt = cpt )
326
+
327
+ def pointwise_product (self , other , bn ):
328
+ vars = list (set (self .vars ) | set (other .vars ))
329
+ cpt = dict ((event_values (e , vars ), self .p (e ) * other .p (e ))
330
+ for e in all_events (vars , bn , {}))
331
+ return Factor (vars , cpt )
332
+
333
+ def sum_out (self , var , bn ):
334
+ vars = [X for X in self .vars if X != var ]
335
+ cpt = dict ((event_values (e , vars ),
336
+ sum (self .p (extend (e , var , val ))
337
+ for val in bn .variable_values (var )))
338
+ for e in all_events (vars , bn , {}))
339
+ return Factor (vars , cpt )
304
340
305
- def pointwise_product (factors ):
306
- unimplemented ()
341
+ def normalize (self ):
342
+ assert len (self .vars ) == 1
343
+ return ProbDist (self .vars [0 ],
344
+ dict ((k , v ) for ((k ,), v ) in self .cpt .items ()))
307
345
308
- def sum_out (var , factors ):
309
- unimplemented ()
346
+ def p (self , e ):
347
+ return self .cpt [event_values (e , self .vars )]
348
+
349
+ def all_events (vars , bn , e1 ):
350
+ if not vars :
351
+ yield e1
352
+ else :
353
+ X , rest = vars [0 ], vars [1 :]
354
+ for e in all_events (rest , bn , e1 ):
355
+ for x in bn .variable_values (X ):
356
+ yield extend (e , X , x )
310
357
311
358
#______________________________________________________________________________
312
359
0 commit comments