Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2d711ad

Browse files
committed
Fixed type error in ProbDist.__setitem__ and factored out common logic.
1 parent 58df28f commit 2d711ad

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

probability.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,22 @@ class JointProbDist(ProbDist):
8787
>>> P = JointProbDist(['X', 'Y']); P[1, 1] = 0.25
8888
>>> P[1, 1]
8989
0.25
90+
>>> P[dict(X=0, Y=1)] = 0.5
91+
>>> P[dict(X=0, Y=1)]
92+
0.5
9093
"""
9194
def __init__(self, variables):
9295
update(self, prob={}, variables=variables, vals=DefaultDict([]))
9396

9497
def __getitem__(self, values):
9598
"Given a tuple or dict of values, return P(values)."
96-
if isinstance(values, dict):
97-
values = tuple([values[var] for var in self.variables])
98-
return self.prob[values]
99+
return self.prob[event_values(values, self.variables)]
99100

100101
def __setitem__(self, values, p):
101102
"""Set P(values) = p. Values can be a tuple or a dict; it must
102103
have a value for each of the variables in the joint. Also keep track
103104
of the values we have seen so far for each variable."""
104-
if isinstance(values, dict):
105-
values = [values[var] for var in self.variables]
105+
values = event_values(values, self.variables)
106106
self.prob[values] = p
107107
for var, val in zip(self.variables, values):
108108
if val not in self.vals[var]:
@@ -247,13 +247,16 @@ def rand(self, parents, event):
247247

248248
return (random() <= self.p(True, parents, event))
249249

250-
def event_values (event, vars):
250+
def event_values(event, vars):
251251
"""Return a tuple of the values of variables vars in event.
252252
253253
>>> event_values ({'A': 10, 'B': 9, 'C': 8}, ['C', 'A'])
254254
(8, 10)
255+
>>> event_values ((1, 2), ['C', 'A'])
256+
(1, 2)
255257
"""
256-
258+
if isinstance(event, tuple) and len(event) == len(vars):
259+
return event
257260
return tuple([event[parent] for parent in vars])
258261

259262

0 commit comments

Comments
 (0)