|
7 | 7 | from utils import (
|
8 | 8 | is_in, argmin, argmax, argmax_random_tie, probability, weighted_sampler,
|
9 | 9 | memoize, print_table, open_data, Stack, FIFOQueue, PriorityQueue, name,
|
10 |
| - distance |
| 10 | + distance, vector_add |
11 | 11 | )
|
12 | 12 |
|
13 | 13 | from collections import defaultdict
|
@@ -526,39 +526,37 @@ def and_search(states, problem, path):
|
526 | 526 | # body of and or search
|
527 | 527 | return or_search(problem.initial, problem, [])
|
528 | 528 |
|
| 529 | +# Pre-defined actions for PeakFindingProblem |
| 530 | +directions4 = { 'W':(-1, 0), 'N':(0, 1), 'E':(1, 0), 'S':(0, -1) } |
| 531 | +directions8 = dict(directions4) |
| 532 | +directions8.update({'NW':(-1, 1), 'NE':(1, 1), 'SE':(1, -1), 'SW':(-1, -1) }) |
529 | 533 |
|
530 | 534 | class PeakFindingProblem(Problem):
|
531 | 535 | """Problem of finding the highest peak in a limited grid"""
|
532 | 536 |
|
533 |
| - def __init__(self, initial, grid): |
| 537 | + def __init__(self, initial, grid, defined_actions=directions4): |
534 | 538 | """The grid is a 2 dimensional array/list whose state is specified by tuple of indices"""
|
535 | 539 | Problem.__init__(self, initial)
|
536 | 540 | self.grid = grid
|
| 541 | + self.defined_actions = defined_actions |
537 | 542 | self.n = len(grid)
|
538 | 543 | assert self.n > 0
|
539 | 544 | self.m = len(grid[0])
|
540 | 545 | assert self.m > 0
|
541 | 546 |
|
542 | 547 | def actions(self, state):
|
543 |
| - """Allows movement in only 4 directions""" |
544 |
| - # TODO: Add flag to allow diagonal motion |
| 548 | + """Returns the list of actions which are allowed to be taken from the given state""" |
545 | 549 | allowed_actions = []
|
546 |
| - if state[0] > 0: |
547 |
| - allowed_actions.append('N') |
548 |
| - if state[0] < self.n - 1: |
549 |
| - allowed_actions.append('S') |
550 |
| - if state[1] > 0: |
551 |
| - allowed_actions.append('W') |
552 |
| - if state[1] < self.m - 1: |
553 |
| - allowed_actions.append('E') |
| 550 | + for action in self.defined_actions: |
| 551 | + next_state = vector_add(state, self.defined_actions[action]) |
| 552 | + if next_state[0] >= 0 and next_state[1] >= 0 and next_state[0] <= self.n - 1 and next_state[1] <= self.m - 1: |
| 553 | + allowed_actions.append(action) |
| 554 | + |
554 | 555 | return allowed_actions
|
555 | 556 |
|
556 | 557 | def result(self, state, action):
|
557 | 558 | """Moves in the direction specified by action"""
|
558 |
| - x, y = state |
559 |
| - x = x + (1 if action == 'S' else (-1 if action == 'N' else 0)) |
560 |
| - y = y + (1 if action == 'E' else (-1 if action == 'W' else 0)) |
561 |
| - return (x, y) |
| 559 | + return vector_add(state, self.defined_actions[action]) |
562 | 560 |
|
563 | 561 | def value(self, state):
|
564 | 562 | """Value of a state is the value it is the index to"""
|
@@ -1347,3 +1345,4 @@ def compare_graph_searchers():
|
1347 | 1345 | GraphProblem('Q', 'WA', australia_map)],
|
1348 | 1346 | header=['Searcher', 'romania_map(Arad, Bucharest)',
|
1349 | 1347 | 'romania_map(Oradea, Neamt)', 'australia_map'])
|
| 1348 | + |
0 commit comments