Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c6f81c5

Browse files
committed
Fill out elimination_ask() (uncommented).
1 parent bb8b235 commit c6f81c5

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

probability.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -287,26 +287,73 @@ def enumerate_all(vars, e, bn):
287287

288288
#______________________________________________________________________________
289289

290-
def elimination_ask(X, e, bn, order=reversed):
291-
"[Fig. 14.11]"
290+
def elimination_ask(X, e, bn):
291+
"""[Fig. 14.11]
292+
>>> elimination_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
293+
... ).show_approx()
294+
'False: 0.716, True: 0.284'"""
292295
factors = []
293-
for var in order(bn.vars):
294-
factors.append(Factor(var, e))
296+
for var in reversed(bn.vars):
297+
factors.append(make_factor(var, e, bn))
295298
if is_hidden(var, X, e):
296-
factors = sum_out(var, factors)
297-
return pointwise_product(factors).normalize()
299+
factors = sum_out(var, factors, bn)
300+
return pointwise_product(factors, bn).normalize()
298301

299302
def is_hidden(var, X, e):
300303
return var != X and var not in e
301304

302-
def Factor(var, e):
303-
unimplemented()
305+
def make_factor(var, e, bn):
306+
node = bn.variable_node(var)
307+
vars = [X for X in [var] + node.parents if X not in e]
308+
cpt = dict((event_values(e1, vars), node.p(e1[var], e1))
309+
for e1 in all_events(vars, bn, e))
310+
return Factor(vars, cpt)
311+
312+
def pointwise_product(factors, bn):
313+
return reduce(lambda f, g: f.pointwise_product(g, bn), factors)
314+
315+
def sum_out(var, factors, bn):
316+
result, var_factors = [], []
317+
for f in factors:
318+
(var_factors if var in f.vars else result).append(f)
319+
result.append(pointwise_product(var_factors, bn).sum_out(var, bn))
320+
return result
321+
322+
class Factor:
323+
324+
def __init__(self, vars, cpt):
325+
update(self, vars=vars, cpt=cpt)
326+
327+
def pointwise_product(self, other, bn):
328+
vars = list(set(self.vars) | set(other.vars))
329+
cpt = dict((event_values(e, vars), self.p(e) * other.p(e))
330+
for e in all_events(vars, bn, {}))
331+
return Factor(vars, cpt)
332+
333+
def sum_out(self, var, bn):
334+
vars = [X for X in self.vars if X != var]
335+
cpt = dict((event_values(e, vars),
336+
sum(self.p(extend(e, var, val))
337+
for val in bn.variable_values(var)))
338+
for e in all_events(vars, bn, {}))
339+
return Factor(vars, cpt)
304340

305-
def pointwise_product(factors):
306-
unimplemented()
341+
def normalize(self):
342+
assert len(self.vars) == 1
343+
return ProbDist(self.vars[0],
344+
dict((k, v) for ((k,), v) in self.cpt.items()))
307345

308-
def sum_out(var, factors):
309-
unimplemented()
346+
def p(self, e):
347+
return self.cpt[event_values(e, self.vars)]
348+
349+
def all_events(vars, bn, e1):
350+
if not vars:
351+
yield e1
352+
else:
353+
X, rest = vars[0], vars[1:]
354+
for e in all_events(rest, bn, e1):
355+
for x in bn.variable_values(X):
356+
yield extend(e, X, x)
310357

311358
#______________________________________________________________________________
312359

0 commit comments

Comments
 (0)