1
1
from mdp import *
2
2
3
# Three extra grid-world MDPs used by the tests below.
# NOTE(review): rows appear to be listed top row first and terminals given as
# (x, y) coordinates, per GridMDP's convention — confirm against mdp.py.

# 4x3 grid with a -0.1 living cost (steeper than the classic -0.04), so the
# agent is pushed harder toward the +1 / -1 exits.
sequential_decision_environment_1 = GridMDP([[-0.1, -0.1, -0.1, +1],
                                             [-0.1, None, -0.1, -1],
                                             [-0.1, -0.1, -0.1, -0.1]],
                                            terminals=[(3, 2), (3, 1)])

# Same 4x3 layout with a severe -2 living cost: exiting quickly, even through
# the -1 terminal, beats wandering.
sequential_decision_environment_2 = GridMDP([[-2, -2, -2, +1],
                                             [-2, None, -2, -1],
                                             [-2, -2, -2, -2]],
                                            terminals=[(3, 2), (3, 1)])

# Larger 6x5 grid with internal walls (None cells) and four terminals of
# different values (+3, +1, -1, -1).
sequential_decision_environment_3 = GridMDP([[-1.0, -0.1, -0.1, -0.1, -0.1, 0.5],
                                             [-0.1, None, None, -0.5, -0.1, -0.1],
                                             [-0.1, None, 1.0, 3.0, None, -0.1],
                                             [-0.1, -0.1, -0.1, None, None, -0.1],
                                             [0.5, -0.1, -0.1, -0.1, -0.1, -1.0]],
                                            terminals=[(2, 2), (3, 2), (0, 4), (5, 0)])
3
19
4
20
def test_value_iteration ():
5
21
assert value_iteration (sequential_decision_environment , .01 ) == {
@@ -10,6 +26,30 @@ def test_value_iteration():
10
26
(2 , 0 ): 0.34461306281476806 , (2 , 1 ): 0.48643676237737926 ,
11
27
(2 , 2 ): 0.79536093684710951 }
12
28
29
+ assert value_iteration (sequential_decision_environment_1 , .01 ) == {
30
+ (3 , 2 ): 1.0 , (3 , 1 ): - 1.0 ,
31
+ (3 , 0 ): - 0.0897388258468311 , (0 , 1 ): 0.146419707398967840 ,
32
+ (0 , 2 ): 0.30596200514385086 , (1 , 0 ): 0.010092796415625799 ,
33
+ (0 , 0 ): 0.00633408092008296 , (1 , 2 ): 0.507390193380827400 ,
34
+ (2 , 0 ): 0.15072242145212010 , (2 , 1 ): 0.358309043654212570 ,
35
+ (2 , 2 ): 0.71675493618997840 }
36
+
37
+ assert value_iteration (sequential_decision_environment_2 , .01 ) == {
38
+ (3 , 2 ): 1.0 , (3 , 1 ): - 1.0 ,
39
+ (3 , 0 ): - 3.5141584808407855 , (0 , 1 ): - 7.8000009574737180 ,
40
+ (0 , 2 ): - 6.1064293596058830 , (1 , 0 ): - 7.1012549580376760 ,
41
+ (0 , 0 ): - 8.5872244532783200 , (1 , 2 ): - 3.9653547121245810 ,
42
+ (2 , 0 ): - 5.3099468802901630 , (2 , 1 ): - 3.3543366255753995 ,
43
+ (2 , 2 ): - 1.7383376462930498 }
44
+
45
+ assert value_iteration (sequential_decision_environment_3 , .01 ) == {
46
+ (0 , 0 ): 4.350592130345558 , (0 , 1 ): 3.640700980321895 , (0 , 2 ): 3.0734806370346943 , (0 , 3 ): 2.5754335063434937 , (0 , 4 ): - 1.0 ,
47
+ (1 , 0 ): 3.640700980321895 , (1 , 1 ): 3.129579352304856 , (1 , 4 ): 2.0787517066719916 ,
48
+ (2 , 0 ): 3.0259220379893352 , (2 , 1 ): 2.5926103577982897 , (2 , 2 ): 1.0 , (2 , 4 ): 2.507774181360808 ,
49
+ (3 , 0 ): 2.5336747364500076 , (3 , 2 ): 3.0 , (3 , 3 ): 2.292172805400873 , (3 , 4 ): 2.996383110867515 ,
50
+ (4 , 0 ): 2.1014575936349886 , (4 , 3 ): 3.1297590518608907 , (4 , 4 ): 3.6408806798779287 ,
51
+ (5 , 0 ): - 1.0 , (5 , 1 ): 2.5756132058995282 , (5 , 2 ): 3.0736603365907276 , (5 , 3 ): 3.6408806798779287 , (5 , 4 ): 4.350771829901593 }
52
+
13
53
14
54
def test_policy_iteration ():
15
55
assert policy_iteration (sequential_decision_environment ) == {
@@ -18,6 +58,26 @@ def test_policy_iteration():
18
58
(2 , 1 ): (0 , 1 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (- 1 , 0 ),
19
59
(3 , 1 ): None , (3 , 2 ): None }
20
60
61
+ assert policy_iteration (sequential_decision_environment_1 ) == {
62
+ (0 , 0 ): (0 , 1 ), (0 , 1 ): (0 , 1 ), (0 , 2 ): (1 , 0 ),
63
+ (1 , 0 ): (1 , 0 ), (1 , 2 ): (1 , 0 ), (2 , 0 ): (0 , 1 ),
64
+ (2 , 1 ): (0 , 1 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (- 1 , 0 ),
65
+ (3 , 1 ): None , (3 , 2 ): None }
66
+
67
+ assert policy_iteration (sequential_decision_environment_2 ) == {
68
+ (0 , 0 ): (1 , 0 ), (0 , 1 ): (0 , 1 ), (0 , 2 ): (1 , 0 ),
69
+ (1 , 0 ): (1 , 0 ), (1 , 2 ): (1 , 0 ), (2 , 0 ): (1 , 0 ),
70
+ (2 , 1 ): (1 , 0 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (0 , 1 ),
71
+ (3 , 1 ): None , (3 , 2 ): None }
72
+
73
+ assert policy_iteration (sequential_decision_environment_3 ) == {
74
+ (0 , 0 ): (- 1 , 0 ), (0 , 1 ): (0 , - 1 ), (0 , 2 ): (0 , - 1 ), (0 , 3 ): (0 , - 1 ), (0 , 4 ): None ,
75
+ (1 , 0 ): (- 1 , 0 ), (1 , 1 ): (- 1 , 0 ), (1 , 4 ): (1 , 0 ),
76
+ (2 , 0 ): (- 1 , 0 ), (2 , 1 ): (0 , - 1 ), (2 , 2 ): None , (2 , 4 ): (1 , 0 ),
77
+ (3 , 0 ): (- 1 , 0 ), (3 , 2 ): None , (3 , 3 ): (1 , 0 ), (3 , 4 ): (1 , 0 ),
78
+ (4 , 0 ): (- 1 , 0 ), (4 , 3 ): (1 , 0 ), (4 , 4 ): (1 , 0 ),
79
+ (5 , 0 ): None , (5 , 1 ): (0 , 1 ), (5 , 2 ): (0 , 1 ), (5 , 3 ): (0 , 1 ), (5 , 4 ): (1 , 0 )}
80
+
21
81
22
82
def test_best_policy ():
23
83
pi = best_policy (sequential_decision_environment ,
@@ -26,6 +86,26 @@ def test_best_policy():
26
86
['^' , None , '^' , '.' ],
27
87
['^' , '>' , '^' , '<' ]]
28
88
89
+ pi_1 = best_policy (sequential_decision_environment_1 ,
90
+ value_iteration (sequential_decision_environment_1 , .01 ))
91
+ assert sequential_decision_environment_1 .to_arrows (pi_1 ) == [['>' , '>' , '>' , '.' ],
92
+ ['^' , None , '^' , '.' ],
93
+ ['^' , '>' , '^' , '<' ]]
94
+
95
+ pi_2 = best_policy (sequential_decision_environment_2 ,
96
+ value_iteration (sequential_decision_environment_2 , .01 ))
97
+ assert sequential_decision_environment_2 .to_arrows (pi_2 ) == [['>' , '>' , '>' , '.' ],
98
+ ['^' , None , '>' , '.' ],
99
+ ['>' , '>' , '>' , '^' ]]
100
+
101
+ pi_3 = best_policy (sequential_decision_environment_3 ,
102
+ value_iteration (sequential_decision_environment_3 , .01 ))
103
+ assert sequential_decision_environment_3 .to_arrows (pi_3 ) == [['.' , '>' , '>' , '>' , '>' , '>' ],
104
+ ['v' , None , None , '>' , '>' , '^' ],
105
+ ['v' , None , '.' , '.' , None , '^' ],
106
+ ['v' , '<' , 'v' , None , None , '^' ],
107
+ ['<' , '<' , '<' , '<' , '<' , '.' ]]
108
+
29
109
30
110
def test_transition_model ():
31
111
transition_model = {
0 commit comments