@@ -32,5 +32,61 @@ def tests():
32
32
p = likelihood_weighting ('Earthquake' , {}, burglary , 1000 )
33
33
assert p [True ], p [False ] == (0.002 , 0.998 )
34
34
35
+ def test_probdist_basic ():
36
+ P = ProbDist ('Flip' )
37
+ P ['H' ], P ['T' ] = 0.25 , 0.75 ;
38
+ assert P ['H' ] == 0.25
39
+
40
+ def test_probdist_frequency ():
41
+ P = ProbDist ('X' , {'lo' : 125 , 'med' : 375 , 'hi' : 500 })
42
+ assert (P ['lo' ], P ['med' ], P ['hi' ]) == (0.125 , 0.375 , 0.5 )
43
+
44
+ def test_probdist_normalize ():
45
+ P = ProbDist ('Flip' )
46
+ P ['H' ], P ['T' ] = 35 , 65
47
+ P = P .normalize ()
48
+ assert (P .prob ['H' ], P .prob ['T' ]) == (0.350 , 0.650 )
49
+
50
+ def test_jointprob ():
51
+ P = JointProbDist (['X' , 'Y' ])
52
+ P [1 , 1 ] = 0.25
53
+ assert P [1 , 1 ] == 0.25
54
+ P [dict (X = 0 , Y = 1 )] = 0.5
55
+ assert P [dict (X = 0 , Y = 1 )] == 0.5
56
+
57
+ def test_event_values ():
58
+ assert event_values ({'A' : 10 , 'B' : 9 , 'C' : 8 }, ['C' , 'A' ]) == (8 , 10 )
59
+ assert event_values ((1 , 2 ), ['C' , 'A' ]) == (1 , 2 )
60
+
61
+ def test_enumerate_joint_ask ():
62
+ P = JointProbDist (['X' , 'Y' ])
63
+ P [0 ,0 ] = 0.25
64
+ P [0 ,1 ] = 0.5
65
+ P [1 ,1 ] = P [2 ,1 ] = 0.125
66
+ assert enumerate_joint_ask ('X' , dict (Y = 1 ),
67
+ P ).show_approx () == '0: 0.667, 1: 0.167, 2: 0.167'
68
+
69
+ def test_bayesnode_p ():
70
+ bn = BayesNode ('X' , 'Burglary' , {T : 0.2 , F : 0.625 })
71
+ assert bn .p (False , {'Burglary' : False , 'Earthquake' : True }) == 0.375
72
+
73
+ def test_enumeration_ask ():
74
+ assert enumeration_ask ('Burglary' ,
75
+ dict (JohnCalls = T , MaryCalls = T ), burglary ).show_approx () == 'False: 0.716, True: 0.284'
76
+
77
+ def test_elemination_ask ():
78
+ elimination_ask ('Burglary' , dict (JohnCalls = T , MaryCalls = T ),
79
+ burglary ).show_approx () == 'False: 0.716, True: 0.284'
80
+
81
+ def test_rejection_sampling ():
82
+ random .seed (47 )
83
+ rejection_sampling ('Burglary' , dict (JohnCalls = T , MaryCalls = T ),
84
+ burglary , 10000 ).show_approx () == 'False: 0.7, True: 0.3'
85
+
86
+ def test_likelihood_weighting ():
87
+ random .seed (1017 )
88
+ assert likelihood_weighting ('Burglary' , dict (JohnCalls = T , MaryCalls = T ),
89
+ burglary , 10000 ).show_approx () == 'False: 0.702, True: 0.298'
90
+
35
91
if __name__ == '__main__' :
36
92
pytest .main ()
0 commit comments