Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d4520ca

Browse files
ad71 authored and norvig committed
Added more tests for mdp.py (aimacode#722)
1 parent 74dff56 commit d4520ca

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

tests/test_mdp.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
from mdp import *
22

3+
sequential_decision_environment_1 = GridMDP([[-0.1, -0.1, -0.1, +1],
4+
[-0.1, None, -0.1, -1],
5+
[-0.1, -0.1, -0.1, -0.1]],
6+
terminals=[(3, 2), (3, 1)])
7+
8+
sequential_decision_environment_2 = GridMDP([[-2, -2, -2, +1],
9+
[-2, None, -2, -1],
10+
[-2, -2, -2, -2]],
11+
terminals=[(3, 2), (3, 1)])
12+
13+
sequential_decision_environment_3 = GridMDP([[-1.0, -0.1, -0.1, -0.1, -0.1, 0.5],
14+
[-0.1, None, None, -0.5, -0.1, -0.1],
15+
[-0.1, None, 1.0, 3.0, None, -0.1],
16+
[-0.1, -0.1, -0.1, None, None, -0.1],
17+
[0.5, -0.1, -0.1, -0.1, -0.1, -1.0]],
18+
terminals=[(2, 2), (3, 2), (0, 4), (5, 0)])
319

420
def test_value_iteration():
521
assert value_iteration(sequential_decision_environment, .01) == {
@@ -10,6 +26,30 @@ def test_value_iteration():
1026
(2, 0): 0.34461306281476806, (2, 1): 0.48643676237737926,
1127
(2, 2): 0.79536093684710951}
1228

29+
assert value_iteration(sequential_decision_environment_1, .01) == {
30+
(3, 2): 1.0, (3, 1): -1.0,
31+
(3, 0): -0.0897388258468311, (0, 1): 0.146419707398967840,
32+
(0, 2): 0.30596200514385086, (1, 0): 0.010092796415625799,
33+
(0, 0): 0.00633408092008296, (1, 2): 0.507390193380827400,
34+
(2, 0): 0.15072242145212010, (2, 1): 0.358309043654212570,
35+
(2, 2): 0.71675493618997840}
36+
37+
assert value_iteration(sequential_decision_environment_2, .01) == {
38+
(3, 2): 1.0, (3, 1): -1.0,
39+
(3, 0): -3.5141584808407855, (0, 1): -7.8000009574737180,
40+
(0, 2): -6.1064293596058830, (1, 0): -7.1012549580376760,
41+
(0, 0): -8.5872244532783200, (1, 2): -3.9653547121245810,
42+
(2, 0): -5.3099468802901630, (2, 1): -3.3543366255753995,
43+
(2, 2): -1.7383376462930498}
44+
45+
assert value_iteration(sequential_decision_environment_3, .01) == {
46+
(0, 0): 4.350592130345558, (0, 1): 3.640700980321895, (0, 2): 3.0734806370346943, (0, 3): 2.5754335063434937, (0, 4): -1.0,
47+
(1, 0): 3.640700980321895, (1, 1): 3.129579352304856, (1, 4): 2.0787517066719916,
48+
(2, 0): 3.0259220379893352, (2, 1): 2.5926103577982897, (2, 2): 1.0, (2, 4): 2.507774181360808,
49+
(3, 0): 2.5336747364500076, (3, 2): 3.0, (3, 3): 2.292172805400873, (3, 4): 2.996383110867515,
50+
(4, 0): 2.1014575936349886, (4, 3): 3.1297590518608907, (4, 4): 3.6408806798779287,
51+
(5, 0): -1.0, (5, 1): 2.5756132058995282, (5, 2): 3.0736603365907276, (5, 3): 3.6408806798779287, (5, 4): 4.350771829901593}
52+
1353

1454
def test_policy_iteration():
1555
assert policy_iteration(sequential_decision_environment) == {
@@ -18,6 +58,26 @@ def test_policy_iteration():
1858
(2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
1959
(3, 1): None, (3, 2): None}
2060

61+
assert policy_iteration(sequential_decision_environment_1) == {
62+
(0, 0): (0, 1), (0, 1): (0, 1), (0, 2): (1, 0),
63+
(1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (0, 1),
64+
(2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
65+
(3, 1): None, (3, 2): None}
66+
67+
assert policy_iteration(sequential_decision_environment_2) == {
68+
(0, 0): (1, 0), (0, 1): (0, 1), (0, 2): (1, 0),
69+
(1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (1, 0),
70+
(2, 1): (1, 0), (2, 2): (1, 0), (3, 0): (0, 1),
71+
(3, 1): None, (3, 2): None}
72+
73+
assert policy_iteration(sequential_decision_environment_3) == {
74+
(0, 0): (-1, 0), (0, 1): (0, -1), (0, 2): (0, -1), (0, 3): (0, -1), (0, 4): None,
75+
(1, 0): (-1, 0), (1, 1): (-1, 0), (1, 4): (1, 0),
76+
(2, 0): (-1, 0), (2, 1): (0, -1), (2, 2): None, (2, 4): (1, 0),
77+
(3, 0): (-1, 0), (3, 2): None, (3, 3): (1, 0), (3, 4): (1, 0),
78+
(4, 0): (-1, 0), (4, 3): (1, 0), (4, 4): (1, 0),
79+
(5, 0): None, (5, 1): (0, 1), (5, 2): (0, 1), (5, 3): (0, 1), (5, 4): (1, 0)}
80+
2181

2282
def test_best_policy():
2383
pi = best_policy(sequential_decision_environment,
@@ -26,6 +86,26 @@ def test_best_policy():
2686
['^', None, '^', '.'],
2787
['^', '>', '^', '<']]
2888

89+
pi_1 = best_policy(sequential_decision_environment_1,
90+
value_iteration(sequential_decision_environment_1, .01))
91+
assert sequential_decision_environment_1.to_arrows(pi_1) == [['>', '>', '>', '.'],
92+
['^', None, '^', '.'],
93+
['^', '>', '^', '<']]
94+
95+
pi_2 = best_policy(sequential_decision_environment_2,
96+
value_iteration(sequential_decision_environment_2, .01))
97+
assert sequential_decision_environment_2.to_arrows(pi_2) == [['>', '>', '>', '.'],
98+
['^', None, '>', '.'],
99+
['>', '>', '>', '^']]
100+
101+
pi_3 = best_policy(sequential_decision_environment_3,
102+
value_iteration(sequential_decision_environment_3, .01))
103+
assert sequential_decision_environment_3.to_arrows(pi_3) == [['.', '>', '>', '>', '>', '>'],
104+
['v', None, None, '>', '>', '^'],
105+
['v', None, '.', '.', None, '^'],
106+
['v', '<', 'v', None, None, '^'],
107+
['<', '<', '<', '<', '<', '.']]
108+
29109

30110
def test_transition_model():
31111
transition_model = {

0 commit comments

Comments (0)