@@ -87,22 +87,22 @@ class JointProbDist(ProbDist):
87
87
>>> P = JointProbDist(['X', 'Y']); P[1, 1] = 0.25
88
88
>>> P[1, 1]
89
89
0.25
90
+ >>> P[dict(X=0, Y=1)] = 0.5
91
+ >>> P[dict(X=0, Y=1)]
92
+ 0.5
90
93
"""
91
94
def __init__ (self , variables ):
92
95
update (self , prob = {}, variables = variables , vals = DefaultDict ([]))
93
96
94
97
def __getitem__ (self , values ):
95
98
"Given a tuple or dict of values, return P(values)."
96
- if isinstance (values , dict ):
97
- values = tuple ([values [var ] for var in self .variables ])
98
- return self .prob [values ]
99
+ return self .prob [event_values (values , self .variables )]
99
100
100
101
def __setitem__ (self , values , p ):
101
102
"""Set P(values) = p. Values can be a tuple or a dict; it must
102
103
have a value for each of the variables in the joint. Also keep track
103
104
of the values we have seen so far for each variable."""
104
- if isinstance (values , dict ):
105
- values = [values [var ] for var in self .variables ]
105
+ values = event_values (values , self .variables )
106
106
self .prob [values ] = p
107
107
for var , val in zip (self .variables , values ):
108
108
if val not in self .vals [var ]:
@@ -247,13 +247,16 @@ def rand(self, parents, event):
247
247
248
248
return (random () <= self .p (True , parents , event ))
249
249
250
- def event_values (event , vars ):
250
+ def event_values (event , vars ):
251
251
"""Return a tuple of the values of variables vars in event.
252
252
253
253
>>> event_values ({'A': 10, 'B': 9, 'C': 8}, ['C', 'A'])
254
254
(8, 10)
255
+ >>> event_values ((1, 2), ['C', 'A'])
256
+ (1, 2)
255
257
"""
256
-
258
+ if isinstance (event , tuple ) and len (event ) == len (vars ):
259
+ return event
257
260
return tuple ([event [parent ] for parent in vars ])
258
261
259
262
0 commit comments