diff --git a/search.py b/search.py index 873c03752..8bf742489 100644 --- a/search.py +++ b/search.py @@ -7,7 +7,7 @@ from utils import ( is_in, argmin, argmax, argmax_random_tie, probability, weighted_sampler, memoize, print_table, open_data, Stack, FIFOQueue, PriorityQueue, name, - distance + distance, vector_add ) from collections import defaultdict @@ -526,39 +526,37 @@ def and_search(states, problem, path): # body of and or search return or_search(problem.initial, problem, []) +# Pre-defined actions for PeakFindingProblem +directions4 = { 'W':(-1, 0), 'N':(0, 1), 'E':(1, 0), 'S':(0, -1) } +directions8 = dict(directions4) +directions8.update({'NW':(-1, 1), 'NE':(1, 1), 'SE':(1, -1), 'SW':(-1, -1) }) class PeakFindingProblem(Problem): """Problem of finding the highest peak in a limited grid""" - def __init__(self, initial, grid): + def __init__(self, initial, grid, defined_actions=directions4): """The grid is a 2 dimensional array/list whose state is specified by tuple of indices""" Problem.__init__(self, initial) self.grid = grid + self.defined_actions = defined_actions self.n = len(grid) assert self.n > 0 self.m = len(grid[0]) assert self.m > 0 def actions(self, state): - """Allows movement in only 4 directions""" - # TODO: Add flag to allow diagonal motion + """Returns the list of actions which are allowed to be taken from the given state""" allowed_actions = [] - if state[0] > 0: - allowed_actions.append('N') - if state[0] < self.n - 1: - allowed_actions.append('S') - if state[1] > 0: - allowed_actions.append('W') - if state[1] < self.m - 1: - allowed_actions.append('E') + for action in self.defined_actions: + next_state = vector_add(state, self.defined_actions[action]) + if next_state[0] >= 0 and next_state[1] >= 0 and next_state[0] <= self.n - 1 and next_state[1] <= self.m - 1: + allowed_actions.append(action) + return allowed_actions def result(self, state, action): """Moves in the direction specified by action""" - x, y = state - x = x + (1 if action == 'S' else (-1 if action == 'N' else 0)) - y = y + (1 if action == 'E' else (-1 if action == 'W' else 0)) - return (x, y) + return vector_add(state, self.defined_actions[action]) def value(self, state): """Value of a state is the value it is the index to""" @@ -1347,3 +1345,4 @@ def compare_graph_searchers(): GraphProblem('Q', 'WA', australia_map)], header=['Searcher', 'romania_map(Arad, Bucharest)', 'romania_map(Oradea, Neamt)', 'australia_map']) + diff --git a/tests/test_search.py b/tests/test_search.py index f22ca6f89..04cb2db35 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -88,12 +88,12 @@ def test_hill_climbing(): def test_simulated_annealing(): random.seed("aima-python") prob = PeakFindingProblem((0, 0), [[0, 5, 10, 20], - [-3, 7, 11, 5]]) + [-3, 7, 11, 5]], directions4) sols = {prob.value(simulated_annealing(prob)) for i in range(100)} assert max(sols) == 20 prob = PeakFindingProblem((0, 0), [[0, 5, 10, 8], [-3, 7, 9, 999], - [1, 2, 5, 11]]) + [1, 2, 5, 11]], directions8) sols = {prob.value(simulated_annealing(prob)) for i in range(100)} assert max(sols) == 999