Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 86db3ba

Browse files
committed
Move BoolCPT into the BayesNode class.
1 parent 1cae59f commit 86db3ba

File tree

1 file changed

+83
-94
lines changed

1 file changed

+83
-94
lines changed

probability.py

Lines changed: 83 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ def values(self, var):
107107
def __repr__(self):
108108
return "P(%s)" % self.variables
109109

110+
def event_values(event, vars):
111+
"""Return a tuple of the values of variables vars in event.
112+
>>> event_values ({'A': 10, 'B': 9, 'C': 8}, ['C', 'A'])
113+
(8, 10)
114+
>>> event_values ((1, 2), ['C', 'A'])
115+
(1, 2)
116+
"""
117+
if isinstance(event, tuple) and len(event) == len(vars):
118+
return event
119+
else:
120+
return tuple([event[var] for var in vars])
121+
110122
#______________________________________________________________________________
111123

112124
def enumerate_joint_ask(X, e, P):
@@ -135,86 +147,6 @@ def enumerate_joint(vars, e, P):
135147

136148
#______________________________________________________________________________
137149

138-
139-
class BoolCPT:
140-
"""A conditional probability table for a boolean (True/False)
141-
random variable conditioned on its parents."""
142-
143-
def __init__(self, table):
144-
"""table must take one of these forms:
145-
146-
* A number, the unconditional probability P(X=true). You can
147-
use this form when the variable has no parents.
148-
149-
* A dict {v: p, ...}, the conditional probability distribution
150-
P(X=true | parent=v) = p.
151-
152-
* A dict {(v1, v2, ...): p, ...}, the conditional probability
153-
distribution P(X=true | parent1=v1, parent2=v2, ...) = p.
154-
You can use this form always; the first two are just
155-
conveniences.
156-
157-
In all cases the probability of X being false is left implicit,
158-
since it follows from P(X=true).
159-
160-
>>> cpt = BoolCPT(0.2)
161-
>>> cpt = BoolCPT({T: 0.2, F: 0.7})
162-
>>> cpt = BoolCPT({(T, T): 0.2, (T, F): 0.3, (F, T): 0.5, (F, F): 0.7})
163-
"""
164-
# We store the table always in the third form above.
165-
if isinstance(table, (float, int)): # no parents, 0-tuple
166-
self.table = {(): table}
167-
elif isinstance(table, dict):
168-
if table: key = table.keys()[0]
169-
else: key = None
170-
if isinstance(key, bool): # one parent, 1-tuple
171-
self.table = dict(((k,), v) for k, v in table.items())
172-
elif isinstance(key, tuple): # normal case, n-tuple
173-
self.table = table
174-
else:
175-
raise Exception("wrong key type: %s" % table)
176-
else:
177-
raise Exception("wrong table type: %s" % table)
178-
179-
def p(self, value, parent_vars, event):
180-
"""Return the conditional probability
181-
P(X=value | parent_vars = parent_values), where parent_values
182-
are the values of parent_vars in event.
183-
184-
Preconditions:
185-
1. each variable in parent_vars is bound to a value in event.
186-
2. the variables are listed in parent_vars in the same order
187-
in which they are listed in the CPT.
188-
189-
>>> event = {'Burglary': False, 'Earthquake': True}
190-
>>> BoolCPT({T: 0.2, F: 0.625}).p(False, ['Burglary'], event)
191-
0.375"""
192-
assert isinstance(value, bool)
193-
ptrue = self.table[event_values(event, parent_vars)]
194-
return if_(value, ptrue, 1 - ptrue)
195-
196-
def sample(self, parent_vars, event):
197-
"""Sample from the distribution for this variable conditioned
198-
on event's values for parent_vars. That is, return True/False
199-
at random according with the conditional probability given
200-
event."""
201-
return random() <= self.p(True, parent_vars, event)
202-
203-
def event_values(event, vars):
204-
"""Return a tuple of the values of variables vars in event.
205-
>>> event_values ({'A': 10, 'B': 9, 'C': 8}, ['C', 'A'])
206-
(8, 10)
207-
>>> event_values ((1, 2), ['C', 'A'])
208-
(1, 2)
209-
"""
210-
if isinstance(event, tuple) and len(event) == len(vars):
211-
return event
212-
else:
213-
return tuple([event[var] for var in vars])
214-
215-
216-
#______________________________________________________________________________
217-
218150
class BayesNet:
219151
"Bayesian network containing only boolean-variable nodes."
220152

@@ -253,10 +185,72 @@ def variable_values(self, var):
253185

254186

255187
class BayesNode:
256-
def __init__(self, variable, parents, cpt):
188+
"""A conditional probability distribution for a boolean variable,
189+
P(X | parents). Part of a BayesNet."""
190+
191+
def __init__(self, X, parents, cpt):
192+
"""X is a variable name, and parents a sequence of variable
193+
names or a space-separated string. cpt, the conditional
194+
probability table, takes one of these forms:
195+
196+
* A number, the unconditional probability P(X=true). You can
197+
use this form when there are no parents.
198+
199+
* A dict {v: p, ...}, the conditional probability distribution
200+
P(X=true | parent=v) = p. When there's just one parent.
201+
202+
* A dict {(v1, v2, ...): p, ...}, the distribution P(X=true |
203+
parent1=v1, parent2=v2, ...) = p. You can use this form
204+
always; the first two are just conveniences.
205+
206+
In all cases the probability of X being false is left implicit,
207+
since it follows from P(X=true).
208+
209+
>>> X = BayesNode('X', '', 0.2)
210+
>>> Y = BayesNode('Y', 'P', {T: 0.2, F: 0.7})
211+
>>> Z = BayesNode('Z', 'P Q',
212+
... {(T, T): 0.2, (T, F): 0.3, (F, T): 0.5, (F, F): 0.7})"""
257213
if isinstance(parents, str): parents = parents.split()
258-
if not isinstance(cpt, BoolCPT): cpt = BoolCPT(cpt)
259-
update(self, variable=variable, parents=parents, cpt=cpt)
214+
215+
# We store the table always in the third form above.
216+
if isinstance(cpt, (float, int)): # no parents, 0-tuple
217+
cpt = {(): cpt}
218+
elif isinstance(cpt, dict):
219+
if cpt: key = cpt.keys()[0]
220+
else: key = None
221+
if isinstance(key, bool): # one parent, 1-tuple
222+
cpt = dict(((k,), v) for k, v in cpt.items())
223+
elif isinstance(key, tuple): # normal case, n-tuple
224+
pass
225+
else:
226+
raise Exception("wrong key type: %s" % cpt)
227+
else:
228+
raise Exception("wrong table type: %s" % cpt)
229+
230+
update(self, variable=X, parents=parents, cpt=cpt)
231+
232+
def p(self, value, event):
233+
"""Return the conditional probability
234+
P(X=value | parents = parent_values), where parent_values
235+
are the values of parents in event.
236+
237+
Preconditions:
238+
1. each variable in parents is bound to a value in event.
239+
in which they are listed in the CPT.
240+
XXX fix doctest
241+
>> event = {'Burglary': False, 'Earthquake': True}
242+
>> BoolCPT({T: 0.2, F: 0.625}).p(False, ['Burglary'], event)
243+
0.375"""
244+
assert isinstance(value, bool)
245+
ptrue = self.cpt[event_values(event, self.parents)]
246+
return if_(value, ptrue, 1 - ptrue)
247+
248+
def sample(self, event):
249+
"""Sample from the distribution for this variable conditioned
250+
on event's values for parent_vars. That is, return True/False
251+
at random according with the conditional probability given
252+
event."""
253+
return random() <= self.p(True, event)
260254

261255
node = BayesNode
262256

@@ -293,17 +287,12 @@ def enumerate_all(vars, e, bn):
293287
(the ones other than vars). Parents must precede children in vars."""
294288
if not vars:
295289
return 1.0
296-
297290
Y, rest = vars[0], vars[1:]
298291
Ynode = bn.variable_node(Y)
299-
parents, cpt = Ynode.parents, Ynode.cpt
300-
301292
if Y in e:
302-
y = e[Y]
303-
return cpt.p(y, parents, e) * enumerate_all(rest, e, bn)
293+
return Ynode.p(e[Y], e) * enumerate_all(rest, e, bn)
304294
else:
305-
return sum(cpt.p(y, parents, e)
306-
* enumerate_all(rest, extend(e, Y, y), bn)
295+
return sum(Ynode.p(y, e) * enumerate_all(rest, extend(e, Y, y), bn)
307296
for y in bn.variable_values(Y))
308297

309298
#______________________________________________________________________________
@@ -342,7 +331,7 @@ def prior_sample(bn):
342331
is a {variable: value} dict. [Fig. 14.13]"""
343332
event = {}
344333
for node in bn.nodes:
345-
event[node.variable] = node.cpt.sample(node.parents, event)
334+
event[node.variable] = node.sample(event)
346335
return event
347336

348337
#_______________________________________________________________________________
@@ -394,11 +383,11 @@ def weighted_sample(bn, e):
394383
w = 1
395384
event = dict(e) # boldface x in Fig. 14.15
396385
for node in bn.nodes:
397-
Xi, parents, cpt = node.variable, node.parents, node.cpt
386+
Xi = node.variable
398387
if Xi in e:
399-
w *= cpt.p(e[Xi], parents, event)
388+
w *= node.p(e[Xi], event)
400389
else:
401-
event[Xi] = cpt.sample(parents, event)
390+
event[Xi] = node.sample(event)
402391
return event, w
403392

404393

0 commit comments

Comments
 (0)