diff --git a/.travis.yml b/.travis.yml index 25750bac9..294287f9b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,6 +21,7 @@ install: - pip install numpy - pip install tensorflow - pip install opencv-python + - pip install sortedcontainers script: diff --git a/csp.py b/csp.py index e1ee53a89..8d0c754cb 100644 --- a/csp.py +++ b/csp.py @@ -1,9 +1,13 @@ """CSP (Constraint Satisfaction Problems) problems and solvers. (Chapter 6).""" +import string +from operator import eq, neg -from utils import argmin_random_tie, count, first +from sortedcontainers import SortedSet + +from utils import argmin_random_tie, count, first, extend import search -from collections import defaultdict +from collections import defaultdict, Counter from functools import reduce import itertools @@ -51,7 +55,6 @@ class CSP(search.Problem): def __init__(self, variables, domains, neighbors, constraints): """Construct a CSP problem. If variables is empty, it becomes domains.keys().""" variables = variables or list(domains.keys()) - self.variables = variables self.domains = domains self.neighbors = neighbors @@ -160,11 +163,20 @@ def conflicted_vars(self, current): # Constraint Propagation with AC-3 -def AC3(csp, queue=None, removals=None): +def no_arc_heuristic(csp, queue): + return queue + + +def dom_j_up(csp, queue): + return SortedSet(queue, key=lambda t: neg(len(csp.curr_domains[t[1]]))) + + +def AC3(csp, queue=None, removals=None, arc_heuristic=dom_j_up): """[Figure 6.3]""" if queue is None: queue = {(Xi, Xk) for Xi in csp.variables for Xk in csp.neighbors[Xi]} csp.support_pruning() + queue = arc_heuristic(csp, queue) while queue: (Xi, Xj) = queue.pop() if revise(csp, Xi, Xj, removals): @@ -187,6 +199,130 @@ def revise(csp, Xi, Xj, removals): return revised +# Constraint Propagation with AC-3b: an improved version of AC-3 with +# double-support domain-heuristic + +def AC3b(csp, queue=None, removals=None, arc_heuristic=dom_j_up): + if queue is None: + queue = {(Xi, Xk) for Xi in csp.variables for Xk in csp.neighbors[Xi]} + csp.support_pruning() + queue = arc_heuristic(csp, queue) + while queue: + (Xi, Xj) = queue.pop() + # Si_p values are all known to be supported by Xj + # Sj_p values are all known to be supported by Xi + # Dj - Sj_p = Sj_u values are unknown, as yet, to be supported by Xi + Si_p, Sj_p, Sj_u = partition(csp, Xi, Xj) + if not Si_p: + return False + revised = False + for x in set(csp.curr_domains[Xi]) - Si_p: + csp.prune(Xi, x, removals) + revised = True + if revised: + for Xk in csp.neighbors[Xi]: + if Xk != Xj: + queue.add((Xk, Xi)) + if (Xj, Xi) in queue: + if isinstance(queue, set): + # or queue -= {(Xj, Xi)} or queue.remove((Xj, Xi)) + queue.difference_update({(Xj, Xi)}) + else: + queue.difference_update((Xj, Xi)) + # the elements in D_j which are supported by Xi are given by the union of Sj_p with the set of those + # elements of Sj_u which further processing will show to be supported by some vi_p in Si_p + for vj_p in Sj_u: + for vi_p in Si_p: + conflict = True + if csp.constraints(Xj, vj_p, Xi, vi_p): + conflict = False + Sj_p.add(vj_p) + if not conflict: + break + revised = False + for x in set(csp.curr_domains[Xj]) - Sj_p: + csp.prune(Xj, x, removals) + revised = True + if revised: + for Xk in csp.neighbors[Xj]: + if Xk != Xi: + queue.add((Xk, Xj)) + return True + + +def partition(csp, Xi, Xj): + Si_p = set() + Sj_p = set() + Sj_u = set(csp.curr_domains[Xj]) + for vi_u in csp.curr_domains[Xi]: + conflict = True + # now, in order to establish support for a value vi_u in Di it seems better to try to find a support among + # the values in Sj_u first, because for each vj_u in Sj_u the check (vi_u, vj_u) is a double-support check + # and it is just as likely that any vj_u in Sj_u supports vi_u than it is that any vj_p in Sj_p does... + for vj_u in Sj_u - Sj_p: + # double-support check + if csp.constraints(Xi, vi_u, Xj, vj_u): + conflict = False + Si_p.add(vi_u) + Sj_p.add(vj_u) + if not conflict: + break + # ... and only if no support can be found among the elements in Sj_u, should the elements vj_p in Sj_p be used + # for single-support checks (vi_u, vj_p) + if conflict: + for vj_p in Sj_p: + # single-support check + if csp.constraints(Xi, vi_u, Xj, vj_p): + conflict = False + Si_p.add(vi_u) + if not conflict: + break + return Si_p, Sj_p, Sj_u - Sj_p + + +# Constraint Propagation with AC-4 + +def AC4(csp, queue=None, removals=None, arc_heuristic=dom_j_up): + if queue is None: + queue = {(Xi, Xk) for Xi in csp.variables for Xk in csp.neighbors[Xi]} + csp.support_pruning() + queue = arc_heuristic(csp, queue) + support_counter = Counter() + variable_value_pairs_supported = defaultdict(set) + unsupported_variable_value_pairs = [] + # construction and initialization of support sets + while queue: + (Xi, Xj) = queue.pop() + revised = False + for x in csp.curr_domains[Xi][:]: + for y in csp.curr_domains[Xj]: + if csp.constraints(Xi, x, Xj, y): + support_counter[(Xi, x, Xj)] += 1 + variable_value_pairs_supported[(Xj, y)].add((Xi, x)) + if support_counter[(Xi, x, Xj)] == 0: + csp.prune(Xi, x, removals) + revised = True + unsupported_variable_value_pairs.append((Xi, x)) + if revised: + if not csp.curr_domains[Xi]: + return False + # propagation of removed values + while unsupported_variable_value_pairs: + Xj, y = unsupported_variable_value_pairs.pop() + for Xi, x in variable_value_pairs_supported[(Xj, y)]: + revised = False + if x in csp.curr_domains[Xi][:]: + support_counter[(Xi, x, Xj)] -= 1 + if support_counter[(Xi, x, Xj)] == 0: + csp.prune(Xi, x, removals) + revised = True + unsupported_variable_value_pairs.append((Xi, x)) + if revised: + if not csp.curr_domains[Xi]: + return False + return True + + # ______________________________________________________________________________ # CSP Backtracking Search @@ -247,9 +383,9 @@ def forward_checking(csp, var, value, assignment, removals): return True -def mac(csp, var, value, assignment, removals): +def mac(csp, var, value, assignment, removals, constraint_propagation=AC3b): """Maintain arc consistency.""" - return AC3(csp, {(X, var) for X in csp.neighbors[var]}, removals) + return constraint_propagation(csp, {(X, var) for X in csp.neighbors[var]}, removals) # The search, proper @@ -283,11 +419,11 @@ def backtrack(assignment): # ______________________________________________________________________________ -# Min-conflicts hillclimbing search for CSPs +# Min-conflicts Hill Climbing search for CSPs def min_conflicts(csp, max_steps=100000): - """Solve a CSP by stochastic hillclimbing on the number of conflicts.""" + """Solve a CSP by stochastic Hill Climbing on the number of conflicts.""" # Generate a complete assignment for all variables (probably with conflicts) csp.current = current = {} for var in csp.variables: @@ -744,3 +880,526 @@ def solve_zebra(algorithm=min_conflicts, **args): print(var, end=' ') print() return ans['Zebra'], ans['Water'], z.nassigns, ans + + +# ______________________________________________________________________________ +# n-ary Constraint Satisfaction Problem + +class NaryCSP: + """A nary-CSP consists of + * domains, a dictionary that maps each variable to its domain + * constraints, a list of constraints + * variables, a set of variables + * var_to_const, a variable to set of constraints dictionary + """ + + def __init__(self, domains, constraints): + """domains is a variable:domain dictionary + constraints is a list of constraints + """ + self.variables = set(domains) + self.domains = domains + self.constraints = constraints + self.var_to_const = {var: set() for var in self.variables} + for con in constraints: + for var in con.scope: + self.var_to_const[var].add(con) + + def __str__(self): + """string representation of CSP""" + return str(self.domains) + + def display(self, assignment=None): + """more detailed string representation of CSP""" + if assignment is None: + assignment = {} + print('CSP(' + str(self.domains) + ', ' + str([str(c) for c in self.constraints]) + ') with assignment: ' + + str(assignment)) + + def consistent(self, assignment): + """assignment is a variable:value dictionary + returns True if all of the constraints that can be evaluated + evaluate to True given assignment. + """ + return all(con.holds(assignment) + for con in self.constraints + if all(v in assignment for v in con.scope)) + + +class Constraint: + """A Constraint consists of + * scope: a tuple of variables + * condition: a function that can applied to a tuple of values + for the variables + """ + + def __init__(self, scope, condition): + self.scope = scope + self.condition = condition + + def __repr__(self): + return self.condition.__name__ + str(self.scope) + + def holds(self, assignment): + """Returns the value of Constraint con evaluated in assignment. + + precondition: all variables are assigned in assignment + """ + return self.condition(*tuple(assignment[v] for v in self.scope)) + + +def all_diff(*values): + """Returns True if all values are different, False otherwise""" + return len(values) is len(set(values)) + + +def is_word(words): + """Returns True if the letters concatenated form a word in words, False otherwise""" + + def isw(*letters): + return "".join(letters) in words + + return isw + + +def meet_at(p1, p2): + """Returns a function that is True when the words meet at the positions (p1, p2), False otherwise""" + + def meets(w1, w2): + return w1[p1] == w2[p2] + + meets.__name__ = "meet_at(" + str(p1) + ',' + str(p2) + ')' + return meets + + +def adjacent(x, y): + """Returns True if x and y are adjacent numbers, False otherwise""" + return abs(x - y) == 1 + + +def sum_(n): + """Returns a function that is True when the the sum of all values is n, False otherwise""" + + def sumv(*values): + return sum(values) is n + + sumv.__name__ = str(n) + "==sum" + return sumv + + +def is_(val): + """Returns a function that is True when x is equal to val, False otherwise""" + + def isv(x): + return val == x + + isv.__name__ = str(val) + "==" + return isv + + +def ne_(val): + """Returns a function that is True when x is not equal to val, False otherwise""" + + def nev(x): + return val != x + + nev.__name__ = str(val) + "!=" + return nev + + +def no_heuristic(to_do): + return to_do + + +def sat_up(to_do): + return SortedSet(to_do, key=lambda t: 1 / len([var for var in t[1].scope])) + + +class ACSolver: + """Solves a CSP with arc consistency and domain splitting""" + + def __init__(self, csp): + """a CSP solver that uses arc consistency + * csp is the CSP to be solved + """ + self.csp = csp + + def GAC(self, orig_domains=None, to_do=None, arc_heuristic=sat_up): + """Makes this CSP arc-consistent using Generalized Arc Consistency + orig_domains is the original domains + to_do is a set of (variable,constraint) pairs + returns the reduced domains (an arc-consistent variable:domain dictionary) + """ + if orig_domains is None: + orig_domains = self.csp.domains + if to_do is None: + to_do = {(var, const) for const in self.csp.constraints + for var in const.scope} + else: + to_do = to_do.copy() + domains = orig_domains.copy() + to_do = arc_heuristic(to_do) + while to_do: + var, const = to_do.pop() + other_vars = [ov for ov in const.scope if ov != var] + if len(other_vars) == 0: + new_domain = {val for val in domains[var] + if const.holds({var: val})} + elif len(other_vars) == 1: + other = other_vars[0] + new_domain = {val for val in domains[var] + if any(const.holds({var: val, other: other_val}) + for other_val in domains[other])} + else: + new_domain = {val for val in domains[var] + if self.any_holds(domains, const, {var: val}, other_vars)} + if new_domain != domains[var]: + domains[var] = new_domain + if not new_domain: + return False, domains + add_to_do = self.new_to_do(var, const).difference(to_do) + to_do |= add_to_do + return True, domains + + def new_to_do(self, var, const): + """returns new elements to be added to to_do after assigning + variable var in constraint const. + """ + return {(nvar, nconst) for nconst in self.csp.var_to_const[var] + if nconst != const + for nvar in nconst.scope + if nvar != var} + + def any_holds(self, domains, const, env, other_vars, ind=0): + """returns True if Constraint const holds for an assignment + that extends env with the variables in other_vars[ind:] + env is a dictionary + Warning: this has side effects and changes the elements of env + """ + if ind == len(other_vars): + return const.holds(env) + else: + var = other_vars[ind] + for val in domains[var]: + # env = dict_union(env,{var:val}) # no side effects! + env[var] = val + holds = self.any_holds(domains, const, env, other_vars, ind + 1) + if holds: + return True + return False + + def domain_splitting(self, domains=None, to_do=None, arc_heuristic=sat_up): + """return a solution to the current CSP or False if there are no solutions + to_do is the list of arcs to check + """ + if domains is None: + domains = self.csp.domains + consistency, new_domains = self.GAC(domains, to_do, arc_heuristic) + if not consistency: + return False + elif all(len(new_domains[var]) == 1 for var in domains): + return {var: first(new_domains[var]) for var in domains} + else: + var = first(x for x in self.csp.variables if len(new_domains[x]) > 1) + if var: + dom1, dom2 = partition_domain(new_domains[var]) + new_doms1 = extend(new_domains, var, dom1) + new_doms2 = extend(new_domains, var, dom2) + to_do = self.new_to_do(var, None) + return self.domain_splitting(new_doms1, to_do, arc_heuristic) or \ + self.domain_splitting(new_doms2, to_do, arc_heuristic) + + +def partition_domain(dom): + """partitions domain dom into two""" + split = len(dom) // 2 + dom1 = set(list(dom)[:split]) + dom2 = dom - dom1 + return dom1, dom2 + + +class ACSearchSolver(search.Problem): + """A search problem with arc consistency and domain splitting + A node is a CSP """ + + def __init__(self, csp, arc_heuristic=sat_up): + self.cons = ACSolver(csp) + consistency, self.domains = self.cons.GAC(arc_heuristic=arc_heuristic) + if not consistency: + raise Exception('CSP is inconsistent') + self.heuristic = arc_heuristic + super().__init__(self.domains) + + def goal_test(self, node): + """node is a goal if all domains have 1 element""" + return all(len(node[var]) == 1 for var in node) + + def actions(self, state): + var = first(x for x in state if len(state[x]) > 1) + neighs = [] + if var: + dom1, dom2 = partition_domain(state[var]) + to_do = self.cons.new_to_do(var, None) + for dom in [dom1, dom2]: + new_domains = extend(state, var, dom) + consistency, cons_doms = self.cons.GAC(new_domains, to_do, self.heuristic) + if consistency: + neighs.append(cons_doms) + return neighs + + def result(self, state, action): + return action + + +def ac_solver(csp, arc_heuristic=sat_up): + """arc consistency (domain splitting)""" + return ACSolver(csp).domain_splitting(arc_heuristic=arc_heuristic) + + +def ac_search_solver(csp, arc_heuristic=sat_up): + """arc consistency (search interface)""" + from search import depth_first_tree_search + solution = None + try: + solution = depth_first_tree_search(ACSearchSolver(csp, arc_heuristic=arc_heuristic)).state + except: + return solution + if solution: + return {var: first(solution[var]) for var in solution} + + +# ______________________________________________________________________________ +# Crossword Problem + + +csp_crossword = NaryCSP({'one_across': {'ant', 'big', 'bus', 'car', 'has'}, + 'one_down': {'book', 'buys', 'hold', 'lane', 'year'}, + 'two_down': {'ginger', 'search', 'symbol', 'syntax'}, + 'three_across': {'book', 'buys', 'hold', 'land', 'year'}, + 'four_across': {'ant', 'big', 'bus', 'car', 'has'}}, + [Constraint(('one_across', 'one_down'), meet_at(0, 0)), + Constraint(('one_across', 'two_down'), meet_at(2, 0)), + Constraint(('three_across', 'two_down'), meet_at(2, 2)), + Constraint(('three_across', 'one_down'), meet_at(0, 2)), + Constraint(('four_across', 'two_down'), meet_at(0, 4))]) + +crossword1 = [['_', '_', '_', '*', '*'], + ['_', '*', '_', '*', '*'], + ['_', '_', '_', '_', '*'], + ['_', '*', '_', '*', '*'], + ['*', '*', '_', '_', '_'], + ['*', '*', '_', '*', '*']] + +words1 = {'ant', 'big', 'bus', 'car', 'has', 'book', 'buys', 'hold', + 'lane', 'year', 'ginger', 'search', 'symbol', 'syntax'} + + +class Crossword(NaryCSP): + + def __init__(self, puzzle, words): + domains = {} + constraints = [] + for i, line in enumerate(puzzle): + scope = [] + for j, element in enumerate(line): + if element == '_': + var = "p" + str(j) + str(i) + domains[var] = list(string.ascii_lowercase) + scope.append(var) + else: + if len(scope) > 1: + constraints.append(Constraint(tuple(scope), is_word(words))) + scope.clear() + if len(scope) > 1: + constraints.append(Constraint(tuple(scope), is_word(words))) + puzzle_t = list(map(list, zip(*puzzle))) + for i, line in enumerate(puzzle_t): + scope = [] + for j, element in enumerate(line): + if element == '_': + scope.append("p" + str(i) + str(j)) + else: + if len(scope) > 1: + constraints.append(Constraint(tuple(scope), is_word(words))) + scope.clear() + if len(scope) > 1: + constraints.append(Constraint(tuple(scope), is_word(words))) + super().__init__(domains, constraints) + self.puzzle = puzzle + + def display(self, assignment=None): + for i, line in enumerate(self.puzzle): + puzzle = "" + for j, element in enumerate(line): + if element == '*': + puzzle += "[*] " + else: + var = "p" + str(j) + str(i) + if assignment is not None: + if isinstance(assignment[var], set) and len(assignment[var]) is 1: + puzzle += "[" + str(first(assignment[var])).upper() + "] " + elif isinstance(assignment[var], str): + puzzle += "[" + str(assignment[var]).upper() + "] " + else: + puzzle += "[_] " + else: + puzzle += "[_] " + print(puzzle) + + +# ______________________________________________________________________________ +# Karuko Problem + + +# difficulty 0 +karuko1 = [['*', '*', '*', [6, ''], [3, '']], + ['*', [4, ''], [3, 3], '_', '_'], + [['', 10], '_', '_', '_', '_'], + [['', 3], '_', '_', '*', '*']] + +# difficulty 0 +karuko2 = [ + ['*', [10, ''], [13, ''], '*'], + [['', 3], '_', '_', [13, '']], + [['', 12], '_', '_', '_'], + [['', 21], '_', '_', '_']] + +# difficulty 1 +karuko3 = [ + ['*', [17, ''], [28, ''], '*', [42, ''], [22, '']], + [['', 9], '_', '_', [31, 14], '_', '_'], + [['', 20], '_', '_', '_', '_', '_'], + ['*', ['', 30], '_', '_', '_', '_'], + ['*', [22, 24], '_', '_', '_', '*'], + [['', 25], '_', '_', '_', '_', [11, '']], + [['', 20], '_', '_', '_', '_', '_'], + [['', 14], '_', '_', ['', 17], '_', '_']] + +# difficulty 2 +karuko4 = [ + ['*', '*', '*', '*', '*', [4, ''], [24, ''], [11, ''], '*', '*', '*', [11, ''], [17, ''], '*', '*'], + ['*', '*', '*', [17, ''], [11, 12], '_', '_', '_', '*', '*', [24, 10], '_', '_', [11, ''], '*'], + ['*', [4, ''], [16, 26], '_', '_', '_', '_', '_', '*', ['', 20], '_', '_', '_', '_', [16, '']], + [['', 20], '_', '_', '_', '_', [24, 13], '_', '_', [16, ''], ['', 12], '_', '_', [23, 10], '_', '_'], + [['', 10], '_', '_', [24, 12], '_', '_', [16, 5], '_', '_', [16, 30], '_', '_', '_', '_', '_'], + ['*', '*', [3, 26], '_', '_', '_', '_', ['', 12], '_', '_', [4, ''], [16, 14], '_', '_', '*'], + ['*', ['', 8], '_', '_', ['', 15], '_', '_', [34, 26], '_', '_', '_', '_', '_', '*', '*'], + ['*', ['', 11], '_', '_', [3, ''], [17, ''], ['', 14], '_', '_', ['', 8], '_', '_', [7, ''], [17, ''], '*'], + ['*', '*', '*', [23, 10], '_', '_', [3, 9], '_', '_', [4, ''], [23, ''], ['', 13], '_', '_', '*'], + ['*', '*', [10, 26], '_', '_', '_', '_', '_', ['', 7], '_', '_', [30, 9], '_', '_', '*'], + ['*', [17, 11], '_', '_', [11, ''], [24, 8], '_', '_', [11, 21], '_', '_', '_', '_', [16, ''], [17, '']], + [['', 29], '_', '_', '_', '_', '_', ['', 7], '_', '_', [23, 14], '_', '_', [3, 17], '_', '_'], + [['', 10], '_', '_', [3, 10], '_', '_', '*', ['', 8], '_', '_', [4, 25], '_', '_', '_', '_'], + ['*', ['', 16], '_', '_', '_', '_', '*', ['', 23], '_', '_', '_', '_', '_', '*', '*'], + ['*', '*', ['', 6], '_', '_', '*', '*', ['', 15], '_', '_', '_', '*', '*', '*', '*']] + + +class Karuko(NaryCSP): + + def __init__(self, puzzle): + variables = [] + for i, line in enumerate(puzzle): + # print line + for j, element in enumerate(line): + if element == '_': + var1 = str(i) + if len(var1) == 1: + var1 = "0" + var1 + var2 = str(j) + if len(var2) == 1: + var2 = "0" + var2 + variables.append("X" + var1 + var2) + domains = {} + for var in variables: + domains[var] = set(range(1, 10)) + constraints = [] + for i, line in enumerate(puzzle): + for j, element in enumerate(line): + if element != '_' and element != '*': + # down - column + if element[0] != '': + x = [] + for k in range(i + 1, len(puzzle)): + if puzzle[k][j] != '_': + break + var1 = str(k) + if len(var1) == 1: + var1 = "0" + var1 + var2 = str(j) + if len(var2) == 1: + var2 = "0" + var2 + x.append("X" + var1 + var2) + constraints.append(Constraint(x, sum_(element[0]))) + constraints.append(Constraint(x, all_diff)) + # right - line + if element[1] != '': + x = [] + for k in range(j + 1, len(puzzle[i])): + if puzzle[i][k] != '_': + break + var1 = str(i) + if len(var1) == 1: + var1 = "0" + var1 + var2 = str(k) + if len(var2) == 1: + var2 = "0" + var2 + x.append("X" + var1 + var2) + constraints.append(Constraint(x, sum_(element[1]))) + constraints.append(Constraint(x, all_diff)) + super().__init__(domains, constraints) + self.puzzle = puzzle + + def display(self, assignment=None): + for i, line in enumerate(self.puzzle): + puzzle = "" + for j, element in enumerate(line): + if element == '*': + puzzle += "[*]\t" + elif element == '_': + var1 = str(i) + if len(var1) == 1: + var1 = "0" + var1 + var2 = str(j) + if len(var2) == 1: + var2 = "0" + var2 + var = "X" + var1 + var2 + if assignment is not None: + if isinstance(assignment[var], set) and len(assignment[var]) is 1: + puzzle += "[" + str(first(assignment[var])) + "]\t" + elif isinstance(assignment[var], int): + puzzle += "[" + str(assignment[var]) + "]\t" + else: + puzzle += "[_]\t" + else: + puzzle += "[_]\t" + else: + puzzle += str(element[0]) + "\\" + str(element[1]) + "\t" + print(puzzle) + + +# ______________________________________________________________________________ +# Cryptarithmetic Problem + +# [Figure 6.2] +# T W O + T W O = F O U R +two_two_four = NaryCSP({'T': set(range(1, 10)), 'F': set(range(1, 10)), + 'W': set(range(0, 10)), 'O': set(range(0, 10)), 'U': set(range(0, 10)), 'R': set(range(0, 10)), + 'C1': set(range(0, 2)), 'C2': set(range(0, 2)), 'C3': set(range(0, 2))}, + [Constraint(('T', 'F', 'W', 'O', 'U', 'R'), all_diff), + Constraint(('O', 'R', 'C1'), lambda o, r, c1: o + o == r + 10 * c1), + Constraint(('W', 'U', 'C1', 'C2'), lambda w, u, c1, c2: c1 + w + w == u + 10 * c2), + Constraint(('T', 'O', 'C2', 'C3'), lambda t, o, c2, c3: c2 + t + t == o + 10 * c3), + Constraint(('F', 'C3'), eq)]) + +# S E N D + M O R E = M O N E Y +send_more_money = NaryCSP({'S': set(range(1, 10)), 'M': set(range(1, 10)), + 'E': set(range(0, 10)), 'N': set(range(0, 10)), 'D': set(range(0, 10)), + 'O': set(range(0, 10)), 'R': set(range(0, 10)), 'Y': set(range(0, 10)), + 'C1': set(range(0, 2)), 'C2': set(range(0, 2)), 'C3': set(range(0, 2)), + 'C4': set(range(0, 2))}, + [Constraint(('S', 'E', 'N', 'D', 'M', 'O', 'R', 'Y'), all_diff), + Constraint(('D', 'E', 'Y', 'C1'), lambda d, e, y, c1: d + e == y + 10 * c1), + Constraint(('N', 'R', 'E', 'C1', 'C2'), lambda n, r, e, c1, c2: c1 + n + r == e + 10 * c2), + Constraint(('E', 'O', 'N', 'C2', 'C3'), lambda e, o, n, c2, c3: c2 + e + o == n + 10 * c3), + Constraint(('S', 'M', 'O', 'C3', 'C4'), lambda s, m, o, c3, c4: c3 + s + m == o + 10 * c4), + Constraint(('M', 'C4'), eq)]) diff --git a/logic.py b/logic.py index 744d6a092..62c23bf46 100644 --- a/logic.py +++ b/logic.py @@ -39,8 +39,8 @@ from search import astar_search, PlanRoute from utils import ( removeall, unique, first, argmax, probability, - isnumber, issequence, Expr, expr, subexpressions -) + isnumber, issequence, Expr, expr, subexpressions, + extend) # ______________________________________________________________________________ @@ -1389,16 +1389,6 @@ def occur_check(var, x, s): return False -def extend(s, var, val): - """Copy the substitution s and extend it by setting var to val; return copy. - >>> extend({x: 1}, y, 2) == {x: 1, y: 2} - True - """ - s2 = s.copy() - s2[var] = val - return s2 - - def subst(s, x): """Substitute the substitution s into the expression x. >>> subst({x: 42, y:0}, F(x) + y) diff --git a/planning.py b/planning.py index 23362b59f..f37c3d663 100644 --- a/planning.py +++ b/planning.py @@ -7,6 +7,7 @@ from functools import reduce as _reduce import search +from csp import sat_up, NaryCSP, Constraint, ac_search_solver, is_ from logic import FolKB, conjuncts, unify, associate, SAT_plan, dpll_satisfiable from search import Node from utils import Expr, expr, first @@ -19,10 +20,11 @@ class PlanningProblem: The conjunction of these logical statements completely defines a state. """ - def __init__(self, initial, goals, actions): - self.initial = self.convert(initial) + def __init__(self, initial, goals, actions, domain=None): + self.initial = self.convert(initial) if domain is None else self.convert(initial) + self.convert(domain) self.goals = self.convert(goals) self.actions = actions + self.domain = domain def convert(self, clauses): """Converts strings into exprs""" @@ -44,9 +46,50 @@ def convert(self, clauses): new_clauses.append(clause) return new_clauses + def expand_fluents(self, name=None): + + kb = None + if self.domain: + kb = FolKB(self.convert(self.domain)) + for action in self.actions: + if action.precond: + for fests in set(action.precond).union(action.effect).difference(self.convert(action.domain)): + if fests.op[:3] != 'Not': + kb.tell(expr(str(action.domain) + ' ==> ' + str(fests))) + + objects = set(arg for clause in set(self.initial + self.goals) for arg in clause.args) + fluent_list = [] + if name is not None: + for fluent in self.initial + self.goals: + if str(fluent) == name: + fluent_list.append(fluent) + break + else: + fluent_list = list(map(lambda fluent: Expr(fluent[0], *fluent[1]), + {fluent.op: fluent.args for fluent in self.initial + self.goals + + [clause for action in self.actions for clause in action.effect if + clause.op[:3] != 'Not']}.items())) + + expansions = [] + for fluent in fluent_list: + for permutation in itertools.permutations(objects, len(fluent.args)): + new_fluent = Expr(fluent.op, *permutation) + if (self.domain and kb.ask(new_fluent) is not False) or not self.domain: + expansions.append(new_fluent) + + return expansions + def expand_actions(self, name=None): """Generate all possible actions with variable bindings for precondition selection heuristic""" + has_domains = all(action.domain for action in self.actions if action.precond) + kb = None + if has_domains: + kb = FolKB(self.initial) + for action in self.actions: + if action.precond: + kb.tell(expr(str(action.domain) + ' ==> ' + str(action))) + objects = set(arg for clause in self.initial for arg in clause.args) expansions = [] action_list = [] @@ -69,27 +112,29 @@ def expand_actions(self, name=None): else: new_args.append(arg) new_expr = Expr(str(action.name), *new_args) - new_preconds = [] - for precond in action.precond: - new_precond_args = [] - for arg in precond.args: - if arg in bindings: - new_precond_args.append(bindings[arg]) - else: - new_precond_args.append(arg) - new_precond = Expr(str(precond.op), *new_precond_args) - new_preconds.append(new_precond) - new_effects = [] - for effect in action.effect: - new_effect_args = [] - for arg in effect.args: - if arg in bindings: - new_effect_args.append(bindings[arg]) - else: - new_effect_args.append(arg) - new_effect = Expr(str(effect.op), *new_effect_args) - new_effects.append(new_effect) - expansions.append(Action(new_expr, new_preconds, new_effects)) + if (has_domains and kb.ask(new_expr) is not False) or ( + has_domains and not action.precond) or not has_domains: + new_preconds = [] + for precond in action.precond: + new_precond_args = [] + for arg in precond.args: + if arg in bindings: + new_precond_args.append(bindings[arg]) + else: + new_precond_args.append(arg) + new_precond = Expr(str(precond.op), *new_precond_args) + new_preconds.append(new_precond) + new_effects = [] + for effect in action.effect: + new_effect_args = [] + for arg in effect.args: + if arg in bindings: + new_effect_args.append(bindings[arg]) + else: + new_effect_args.append(arg) + new_effect = Expr(str(effect.op), *new_effect_args) + new_effects.append(new_effect) + expansions.append(Action(new_expr, new_preconds, new_effects)) return expansions @@ -132,13 +177,14 @@ class Action: eat = Action(expr("Eat(person, food)"), precond, effect) """ - def __init__(self, action, precond, effect): + def __init__(self, action, precond, effect, domain=None): if isinstance(action, str): action = expr(action) self.name = action.op self.args = action.args - self.precond = self.convert(precond) + self.precond = self.convert(precond) if domain is None else self.convert(precond) + self.convert(domain) self.effect = self.convert(effect) + self.domain = domain def __call__(self, kb, args): return self.act(kb, args) @@ -252,19 +298,21 @@ def air_cargo(): >>> """ - return PlanningProblem( - initial='At(C1, SFO) & At(C2, JFK) & At(P1, SFO) & At(P2, JFK) & ' - 'Cargo(C1) & Cargo(C2) & Plane(P1) & Plane(P2) & Airport(SFO) & Airport(JFK)', - goals='At(C1, JFK) & At(C2, SFO)', - actions=[Action('Load(c, p, a)', - precond='At(c, a) & At(p, a) & Cargo(c) & Plane(p) & Airport(a)', - effect='In(c, p) & ~At(c, a)'), - Action('Unload(c, p, a)', - precond='In(c, p) & At(p, a) & Cargo(c) & Plane(p) & Airport(a)', - effect='At(c, a) & ~In(c, p)'), - Action('Fly(p, f, to)', - precond='At(p, f) & Plane(p) & Airport(f) & Airport(to)', - effect='At(p, to) & ~At(p, f)')]) + return PlanningProblem(initial='At(C1, SFO) & At(C2, JFK) & At(P1, SFO) & At(P2, JFK)', + goals='At(C1, JFK) & At(C2, SFO)', + actions=[Action('Load(c, p, a)', + precond='At(c, a) & At(p, a)', + effect='In(c, p) & ~At(c, a)', + domain='Cargo(c) & Plane(p) & Airport(a)'), + Action('Unload(c, p, a)', + precond='In(c, p) & At(p, a)', + effect='At(c, a) & ~In(c, p)', + domain='Cargo(c) & Plane(p) & Airport(a)'), + Action('Fly(p, f, to)', + precond='At(p, f)', + effect='At(p, to) & ~At(p, f)', + domain='Plane(p) & Airport(f) & Airport(to)')], + domain='Cargo(C1) & Cargo(C2) & Plane(P1) & Plane(P2) & Airport(SFO) & Airport(JFK)') def spare_tire(): @@ -288,18 +336,21 @@ def spare_tire(): >>> """ - return PlanningProblem(initial='Tire(Flat) & Tire(Spare) & At(Flat, Axle) & At(Spare, Trunk)', + return PlanningProblem(initial='At(Flat, Axle) & At(Spare, Trunk)', goals='At(Spare, Axle) & At(Flat, Ground)', actions=[Action('Remove(obj, loc)', precond='At(obj, loc)', - effect='At(obj, Ground) & ~At(obj, loc)'), + effect='At(obj, Ground) & ~At(obj, loc)', + domain='Tire(obj)'), Action('PutOn(t, Axle)', - precond='Tire(t) & At(t, Ground) & ~At(Flat, Axle)', - effect='At(t, Axle) & ~At(t, Ground)'), + precond='At(t, Ground) & ~At(Flat, Axle)', + effect='At(t, Axle) & ~At(t, Ground)', + domain='Tire(t)'), Action('LeaveOvernight', precond='', effect='~At(Spare, Ground) & ~At(Spare, Axle) & ~At(Spare, Trunk) & \ - ~At(Flat, Ground) & ~At(Flat, Axle) & ~At(Flat, Trunk)')]) + ~At(Flat, Ground) & ~At(Flat, Axle) & ~At(Flat, Trunk)')], + domain='Tire(Flat) & Tire(Spare)') def three_block_tower(): @@ -323,16 +374,17 @@ def three_block_tower(): True >>> """ - - return PlanningProblem( - initial='On(A, Table) & On(B, Table) & On(C, A) & Block(A) & Block(B) & Block(C) & Clear(B) & Clear(C)', - goals='On(A, B) & On(B, C)', - actions=[Action('Move(b, x, y)', - precond='On(b, x) & Clear(b) & Clear(y) & Block(b) & Block(y)', - effect='On(b, y) & Clear(x) & ~On(b, x) & ~Clear(y)'), - Action('MoveToTable(b, x)', - precond='On(b, x) & Clear(b) & Block(b)', - effect='On(b, Table) & Clear(x) & ~On(b, x)')]) + return PlanningProblem(initial='On(A, Table) & On(B, Table) & On(C, A) & Clear(B) & Clear(C)', + goals='On(A, B) & On(B, C)', + actions=[Action('Move(b, x, y)', + precond='On(b, x) & Clear(b) & Clear(y)', + effect='On(b, y) & Clear(x) & ~On(b, x) & ~Clear(y)', + domain='Block(b) & Block(y)'), + Action('MoveToTable(b, x)', + precond='On(b, x) & Clear(b)', + effect='On(b, Table) & Clear(x) & ~On(b, x)', + domain='Block(b) & Block(x)')], + domain='Block(A) & Block(B) & Block(C)') def simple_blocks_world(): @@ -425,10 +477,14 @@ def shopping_problem(): goals='Have(Milk) & Have(Banana) & Have(Drill)', actions=[Action('Buy(x, store)', precond='At(store) & Sells(store, x)', - effect='Have(x)'), + effect='Have(x)', + domain='Store(store) & Item(x)'), Action('Go(x, y)', precond='At(x)', - effect='At(y) & ~At(x)')]) + effect='At(y) & ~At(x)', + domain='Place(x) & Place(y)')], + domain='Place(Home) & Place(SM) & Place(HW) & Store(SM) & Store(HW) & ' + 'Item(Milk) & Item(Banana) & Item(Drill)') def socks_and_shoes(): @@ -589,6 +645,79 @@ def h(self, subgoal): return float('inf') +def CSPlan(planning_problem, solution_length, CSP_solver=ac_search_solver, arc_heuristic=sat_up): + """ + Planning as Constraint Satisfaction Problem [Section 10.4.3] + """ + + def st(var, stage): + """Returns a string for the var-stage pair that can be used as a variable""" + return str(var) + "_" + str(stage) + + def if_(v1, v2): + """If the second argument is v2, the first argument must be v1""" + + def if_fun(x1, x2): + return x1 == v1 if x2 == v2 else True + + if_fun.__name__ = "if the second argument is " + str(v2) + " then the first argument is " + str(v1) + " " + return if_fun + + def eq_if_not_in_(actset): + """First and third arguments are equal if action is not in actset""" + + def eq_if_not_in(x1, a, x2): + return x1 == x2 if a not in actset else True + + eq_if_not_in.__name__ = "first and third arguments are equal if action is not in " + str(actset) + " " + return eq_if_not_in + + expanded_actions = planning_problem.expand_actions() + fluent_values = planning_problem.expand_fluents() + for horizon in range(solution_length): + act_vars = [st('action', stage) for stage in range(horizon + 1)] + domains = {av: list(map(lambda action: expr(str(action)), expanded_actions)) for av in act_vars} + domains.update({st(var, stage): {True, False} for var in fluent_values for stage in range(horizon + 2)}) + # initial state constraints + constraints = [Constraint((st(var, 0),), is_(val)) + for (var, val) in {expr(str(fluent).replace('Not', '')): + True if fluent.op[:3] != 'Not' else False + for fluent in planning_problem.initial}.items()] + constraints += [Constraint((st(var, 0),), is_(False)) + for var in {expr(str(fluent).replace('Not', '')) + for fluent in fluent_values if fluent not in planning_problem.initial}] + # goal state constraints + constraints += [Constraint((st(var, horizon + 1),), is_(val)) + for (var, val) in {expr(str(fluent).replace('Not', '')): + True if fluent.op[:3] != 'Not' else False + for fluent in planning_problem.goals}.items()] + # precondition constraints + constraints += [Constraint((st(var, stage), st('action', stage)), if_(val, act)) + # st(var, stage) == val if st('action', stage) == act + for act, strps in {expr(str(action)): action for action in expanded_actions}.items() + for var, val in {expr(str(fluent).replace('Not', '')): + True if fluent.op[:3] != 'Not' else False + for fluent in strps.precond}.items() + for stage in range(horizon + 1)] + # effect constraints + constraints += [Constraint((st(var, stage + 1), st('action', stage)), if_(val, act)) + # st(var, stage + 1) == val if st('action', stage) == act + for act, strps in {expr(str(action)): action for action in expanded_actions}.items() + for var, val in {expr(str(fluent).replace('Not', '')): True if fluent.op[:3] != 'Not' else False + for fluent in strps.effect}.items() + for stage in range(horizon + 1)] + # frame constraints + constraints += [Constraint((st(var, stage), st('action', stage), st(var, stage + 1)), + eq_if_not_in_(set(map(lambda action: expr(str(action)), + {act for act in expanded_actions if var in act.effect + or Expr('Not' + var.op, *var.args) in act.effect})))) + for var in fluent_values for stage in range(horizon + 1)] + csp = NaryCSP(domains, constraints) + sol = CSP_solver(csp, arc_heuristic=arc_heuristic) + if sol: + return [sol[a] for a in act_vars] + + def SATPlan(planning_problem, solution_length, SAT_solver=dpll_satisfiable): """ Planning as Boolean satisfiability [Section 10.4.1] diff --git a/requirements.txt b/requirements.txt index 3d8754e71..ce8246bfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +pytest +sortedcontainers networkx==1.11 jupyter pandas diff --git a/tests/test_csp.py b/tests/test_csp.py index a7564a395..6aafa81c8 100644 --- a/tests/test_csp.py +++ b/tests/test_csp.py @@ -24,7 +24,7 @@ def test_csp_unassign(): assert var not in assignment -def test_csp_nconflits(): +def test_csp_nconflicts(): map_coloring_test = MapColoringCSP(list('RGB'), 'A: B C; B: C; C: ') assignment = {'A': 'R', 'B': 'G'} var = 'C' @@ -67,17 +67,16 @@ def test_csp_result(): def test_csp_goal_test(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') state = (('A', '1'), ('B', '3'), ('C', '2')) - assert map_coloring_test.goal_test(state) is True + assert map_coloring_test.goal_test(state) state = (('A', '1'), ('C', '2')) - assert map_coloring_test.goal_test(state) is False + assert not map_coloring_test.goal_test(state) def test_csp_support_pruning(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') map_coloring_test.support_pruning() - assert map_coloring_test.curr_domains == {'A': ['1', '2', '3'], 'B': ['1', '2', '3'], - 'C': ['1', '2', '3']} + assert map_coloring_test.curr_domains == {'A': ['1', '2', '3'], 'B': ['1', '2', '3'], 'C': ['1', '2', '3']} def test_csp_suppose(): @@ -88,8 +87,7 @@ def test_csp_suppose(): removals = map_coloring_test.suppose(var, value) assert removals == [('A', '2'), ('A', '3')] - assert map_coloring_test.curr_domains == {'A': ['1'], 'B': ['1', '2', '3'], - 'C': ['1', '2', '3']} + assert map_coloring_test.curr_domains == {'A': ['1'], 'B': ['1', '2', '3'], 'C': ['1', '2', '3']} def test_csp_prune(): @@ -100,16 +98,14 @@ def test_csp_prune(): map_coloring_test.support_pruning() map_coloring_test.prune(var, value, removals) - assert map_coloring_test.curr_domains == {'A': ['1', '2'], 'B': ['1', '2', '3'], - 'C': ['1', '2', '3']} + assert map_coloring_test.curr_domains == {'A': ['1', '2'], 'B': ['1', '2', '3'], 'C': ['1', '2', '3']} assert removals is None map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') removals = [('A', '2')] map_coloring_test.support_pruning() map_coloring_test.prune(var, value, removals) - assert map_coloring_test.curr_domains == {'A': ['1', '2'], 'B': ['1', '2', '3'], - 'C': ['1', '2', '3']} + assert map_coloring_test.curr_domains == {'A': ['1', '2'], 'B': ['1', '2', '3'], 'C': ['1', '2', '3']} assert removals == [('A', '2'), ('A', '3')] @@ -125,9 +121,9 @@ def test_csp_choices(): assert map_coloring_test.choices(var) == ['1', '2'] -def test_csp_infer_assignement(): +def test_csp_infer_assignment(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') - map_coloring_test.infer_assignment() == {} + assert map_coloring_test.infer_assignment() == {} var = 'A' value = '3' @@ -135,7 +131,7 @@ def test_csp_infer_assignement(): value = '1' map_coloring_test.prune(var, value, None) - map_coloring_test.infer_assignment() == {'A': '2'} + assert map_coloring_test.infer_assignment() == {'A': '2'} def test_csp_restore(): @@ -145,8 +141,7 @@ def test_csp_restore(): map_coloring_test.restore(removals) - assert map_coloring_test.curr_domains == {'A': ['2', '3', '1'], 'B': ['1', '2', '3'], - 'C': ['2', '3']} + assert map_coloring_test.curr_domains == {'A': ['2', '3', '1'], 'B': ['1', '2', '3'], 'C': ['2', '3']} def test_csp_conflicted_vars(): @@ -181,43 +176,95 @@ def test_revise(): Xj = 'B' removals = [] - assert revise(csp, Xi, Xj, removals) is False + assert not revise(csp, Xi, Xj, removals) assert len(removals) == 0 domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) csp.support_pruning() - assert revise(csp, Xi, Xj, removals) is True + assert revise(csp, Xi, Xj, removals) assert removals == [('A', 1), ('A', 3)] def test_AC3(): neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 and y % 2 != 0 + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 and y % 2 != 0 removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert AC3(csp, removals=removals) is False + assert not AC3(csp, removals=removals) - constraints = lambda X, x, Y, y: (x % 2) == 0 and (x + y) == 4 + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert AC3(csp, removals=removals) is True + assert AC3(csp, removals=removals) assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)]) domains = {'A': [2, 4], 'B': [3, 5]} - constraints = lambda X, x, Y, y: int(x) > int(y) + constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y removals = [] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assert AC3(csp, removals=removals) +def test_AC3b(): + neighbors = parse_neighbors('A: B; B: ') + domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 and y % 2 != 0 + removals = [] + + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert not AC3b(csp, removals=removals) + + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 + removals = [] + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert AC3b(csp, removals=removals) + assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or + removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)]) + + domains = {'A': [2, 4], 'B': [3, 5]} + constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y + removals = [] + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert AC3b(csp, removals=removals) + + +def test_AC4(): + neighbors = parse_neighbors('A: B; B: ') + domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 and y % 2 != 0 + removals = [] + + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert not AC4(csp, removals=removals) + + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 + removals = [] + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert AC4(csp, removals=removals) + assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or + removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)]) + + domains = {'A': [2, 4], 'B': [3, 5]} + constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y + removals = [] + csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) + + assert AC4(csp, removals=removals) + + def test_first_unassigned_variable(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') assignment = {'A': '1', 'B': '2'} @@ -246,7 +293,7 @@ def test_num_legal_values(): def test_mrv(): neighbors = parse_neighbors('A: B; B: C; C: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [4], 'C': [0, 1, 2, 3, 4]} - constraints = lambda X, x, Y, y: x % 2 == 0 and (x + y) == 4 + constraints = lambda X, x, Y, y: x % 2 == 0 and x + y == 4 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) assignment = {'A': 0} @@ -302,30 +349,29 @@ def test_forward_checking(): var = 'B' value = 3 assignment = {'A': 1, 'C': '3'} - assert forward_checking(csp, var, value, assignment, None) == True + assert forward_checking(csp, var, value, assignment, None) assert csp.curr_domains['A'] == A_curr_domains assert csp.curr_domains['C'] == C_curr_domains assignment = {'C': 3} - assert forward_checking(csp, var, value, assignment, None) == True + assert forward_checking(csp, var, value, assignment, None) assert csp.curr_domains['A'] == [1, 3] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) csp.support_pruning() assignment = {} - assert forward_checking(csp, var, value, assignment, None) == True + assert forward_checking(csp, var, value, assignment, None) assert csp.curr_domains['A'] == [1, 3] assert csp.curr_domains['C'] == [1, 3] csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4, 7], 'C': [0, 1, 2, 3, 4]} csp.support_pruning() value = 7 assignment = {} - assert forward_checking(csp, var, value, assignment, None) == False + assert not forward_checking(csp, var, value, assignment, None) assert (csp.curr_domains['A'] == [] or csp.curr_domains['C'] == []) @@ -333,12 +379,10 @@ def test_backtracking_search(): assert backtracking_search(australia_csp) assert backtracking_search(australia_csp, select_unassigned_variable=mrv) assert backtracking_search(australia_csp, order_domain_values=lcv) - assert backtracking_search(australia_csp, select_unassigned_variable=mrv, - order_domain_values=lcv) + assert backtracking_search(australia_csp, select_unassigned_variable=mrv, order_domain_values=lcv) assert backtracking_search(australia_csp, inference=forward_checking) assert backtracking_search(australia_csp, inference=mac) - assert backtracking_search(usa_csp, select_unassigned_variable=mrv, - order_domain_values=lcv, inference=mac) + assert backtracking_search(usa_csp, select_unassigned_variable=mrv, order_domain_values=lcv, inference=mac) def test_min_conflicts(): @@ -354,7 +398,7 @@ def test_min_conflicts(): assert min_conflicts(NQueensCSP(3), 1000) is None -def test_nqueens_csp(): +def test_nqueensCSP(): csp = NQueensCSP(8) assignment = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} @@ -378,7 +422,6 @@ def test_nqueens_csp(): assert 2 not in assignment assert 3 not in assignment - assignment = {} assignment = {0: 0, 1: 1, 2: 4, 3: 1, 4: 6} csp.assign(5, 7, assignment) assert len(assignment) == 6 @@ -421,7 +464,7 @@ def test_topological_sort(): Sort, Parents = topological_sort(australia_csp, root) assert Sort == ['NT', 'SA', 'Q', 'NSW', 'V', 'WA'] - assert Parents['NT'] == None + assert Parents['NT'] is None assert Parents['SA'] == 'NT' assert Parents['Q'] == 'SA' assert Parents['NSW'] == 'Q' @@ -437,9 +480,42 @@ def test_tree_csp_solver(): (tcs['NT'] == 'B' and tcs['WA'] == 'R' and tcs['Q'] == 'R' and tcs['NSW'] == 'B' and tcs['V'] == 'R') +def test_ac_solver(): + assert ac_solver(csp_crossword) == {'one_across': 'has', + 'one_down': 'hold', + 'two_down': 'syntax', + 'three_across': 'land', + 'four_across': 'ant'} or {'one_across': 'bus', + 'one_down': 'buys', + 'two_down': 'search', + 'three_across': 'year', + 'four_across': 'car'} + assert ac_solver(two_two_four) == {'T': 7, 'F': 1, 'W': 6, 'O': 5, 'U': 3, 'R': 0, 'C1': 1, 'C2': 1, 'C3': 1} or \ + {'T': 9, 'F': 1, 'W': 2, 'O': 8, 'U': 5, 'R': 6, 'C1': 1, 'C2': 0, 'C3': 1} + assert ac_solver(send_more_money) == {'S': 9, 'M': 1, 'E': 5, 'N': 6, 'D': 7, 'O': 0, 'R': 8, 'Y': 2, + 'C1': 1, 'C2': 1, 'C3': 0, 'C4': 1} + + +def test_ac_search_solver(): + assert ac_search_solver(csp_crossword) == {'one_across': 'has', + 'one_down': 'hold', + 'two_down': 'syntax', + 'three_across': 'land', + 'four_across': 'ant'} or {'one_across': 'bus', + 'one_down': 'buys', + 'two_down': 'search', + 'three_across': 'year', + 'four_across': 'car'} + assert ac_search_solver(two_two_four) == {'T': 7, 'F': 1, 'W': 6, 'O': 5, 'U': 3, 'R': 0, + 'C1': 1, 'C2': 1, 'C3': 1} or \ + {'T': 9, 'F': 1, 'W': 2, 'O': 8, 'U': 5, 'R': 6, 'C1': 1, 'C2': 0, 'C3': 1} + assert ac_search_solver(send_more_money) == {'S': 9, 'M': 1, 'E': 5, 'N': 6, 'D': 7, 'O': 0, 'R': 8, 'Y': 2, + 'C1': 1, 'C2': 1, 'C3': 0, 'C4': 1} + + def test_different_values_constraint(): - assert different_values_constraint('A', 1, 'B', 2) == True - assert different_values_constraint('A', 1, 'B', 1) == False + assert different_values_constraint('A', 1, 'B', 2) + assert not different_values_constraint('A', 1, 'B', 1) def test_flatten(): @@ -482,6 +558,7 @@ def test_make_arc_consistent(): assert make_arc_consistent(Xi, Xj, csp) == [0, 2, 4] + def test_assign_value(): neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} @@ -505,6 +582,7 @@ def test_assign_value(): assignment = {'A': 1} assert assign_value(Xi, Xj, csp, assignment) == 3 + def test_no_inference(): neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4, 5]} @@ -514,7 +592,7 @@ def test_no_inference(): var = 'B' value = 3 assignment = {'A': 1} - assert no_inference(csp, var, value, assignment, None) == True + assert no_inference(csp, var, value, assignment, None) def test_mac(): @@ -526,7 +604,7 @@ def test_mac(): assignment = {'A': 0} csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert mac(csp, var, value, assignment, None) == True + assert mac(csp, var, value, assignment, None) neighbors = parse_neighbors('A: B; B: ') domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]} @@ -536,29 +614,43 @@ def test_mac(): assignment = {'A': 1} csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert mac(csp, var, value, assignment, None) == False + assert not mac(csp, var, value, assignment, None) constraints = lambda X, x, Y, y: x % 2 != 0 and (x + y) == 6 and y % 2 != 0 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert mac(csp, var, value, assignment, None) == True + assert mac(csp, var, value, assignment, None) + def test_queen_constraint(): - assert queen_constraint(0, 1, 0, 1) == True - assert queen_constraint(2, 1, 4, 2) == True - assert queen_constraint(2, 1, 3, 2) == False + assert queen_constraint(0, 1, 0, 1) + assert queen_constraint(2, 1, 4, 2) + assert not queen_constraint(2, 1, 3, 2) def test_zebra(): z = Zebra() - algorithm=min_conflicts -# would take very long + algorithm = min_conflicts + # would take very long ans = algorithm(z, max_steps=10000) - assert ans is None or ans == {'Red': 3, 'Yellow': 1, 'Blue': 2, 'Green': 5, 'Ivory': 4, 'Dog': 4, 'Fox': 1, 'Snails': 3, 'Horse': 2, 'Zebra': 5, 'OJ': 4, 'Tea': 2, 'Coffee': 5, 'Milk': 3, 'Water': 1, 'Englishman': 3, 'Spaniard': 4, 'Norwegian': 1, 'Ukranian': 2, 'Japanese': 5, 'Kools': 1, 'Chesterfields': 2, 'Winston': 3, 'LuckyStrike': 4, 'Parliaments': 5} - -# restrict search space - z.domains = {'Red': [3, 4], 'Yellow': [1, 2], 'Blue': [1, 2], 'Green': [4, 5], 'Ivory': [4, 5], 'Dog': [4, 5], 'Fox': [1, 2], 'Snails': [3], 'Horse': [2], 'Zebra': [5], 'OJ': [1, 2, 3, 4, 5], 'Tea': [1, 2, 3, 4, 5], 'Coffee': [1, 2, 3, 4, 5], 'Milk': [3], 'Water': [1, 2, 3, 4, 5], 'Englishman': [1, 2, 3, 4, 5], 'Spaniard': [1, 2, 3, 4, 5], 'Norwegian': [1], 'Ukranian': [1, 2, 3, 4, 5], 'Japanese': [1, 2, 3, 4, 5], 'Kools': [1, 2, 3, 4, 5], 'Chesterfields': [1, 2, 3, 4, 5], 'Winston': [1, 2, 3, 4, 5], 'LuckyStrike': [1, 2, 3, 4, 5], 'Parliaments': [1, 2, 3, 4, 5]} + assert ans is None or ans == {'Red': 3, 'Yellow': 1, 'Blue': 2, 'Green': 5, 'Ivory': 4, 'Dog': 4, 'Fox': 1, + 'Snails': 3, 'Horse': 2, 'Zebra': 5, 'OJ': 4, 'Tea': 2, 'Coffee': 5, 'Milk': 3, + 'Water': 1, 'Englishman': 3, 'Spaniard': 4, 'Norwegian': 1, 'Ukranian': 2, + 'Japanese': 5, 'Kools': 1, 'Chesterfields': 2, 'Winston': 3, 'LuckyStrike': 4, + 'Parliaments': 5} + + # restrict search space + z.domains = {'Red': [3, 4], 'Yellow': [1, 2], 'Blue': [1, 2], 'Green': [4, 5], 'Ivory': [4, 5], 'Dog': [4, 5], + 'Fox': [1, 2], 'Snails': [3], 'Horse': [2], 'Zebra': [5], 'OJ': [1, 2, 3, 4, 5], + 'Tea': [1, 2, 3, 4, 5], 'Coffee': [1, 2, 3, 4, 5], 'Milk': [3], 'Water': [1, 2, 3, 4, 5], + 'Englishman': [1, 2, 3, 4, 5], 'Spaniard': [1, 2, 3, 4, 5], 'Norwegian': [1], + 'Ukranian': [1, 2, 3, 4, 5], 'Japanese': [1, 2, 3, 4, 5], 'Kools': [1, 2, 3, 4, 5], + 'Chesterfields': [1, 2, 3, 4, 5], 'Winston': [1, 2, 3, 4, 5], 'LuckyStrike': [1, 2, 3, 4, 5], + 'Parliaments': [1, 2, 3, 4, 5]} ans = algorithm(z, max_steps=10000) - assert ans == {'Red': 3, 'Yellow': 1, 'Blue': 2, 'Green': 5, 'Ivory': 4, 'Dog': 4, 'Fox': 1, 'Snails': 3, 'Horse': 2, 'Zebra': 5, 'OJ': 4, 'Tea': 2, 'Coffee': 5, 'Milk': 3, 'Water': 1, 'Englishman': 3, 'Spaniard': 4, 'Norwegian': 1, 'Ukranian': 2, 'Japanese': 5, 'Kools': 1, 'Chesterfields': 2, 'Winston': 3, 'LuckyStrike': 4, 'Parliaments': 5} + assert ans == {'Red': 3, 'Yellow': 1, 'Blue': 2, 'Green': 5, 'Ivory': 4, 'Dog': 4, 'Fox': 1, 'Snails': 3, + 'Horse': 2, 'Zebra': 5, 'OJ': 4, 'Tea': 2, 'Coffee': 5, 'Milk': 3, 'Water': 1, 'Englishman': 3, + 'Spaniard': 4, 'Norwegian': 1, 'Ukranian': 2, 'Japanese': 5, 'Kools': 1, 'Chesterfields': 2, + 'Winston': 3, 'LuckyStrike': 4, 'Parliaments': 5} if __name__ == "__main__": diff --git a/tests/test_planning.py b/tests/test_planning.py index 3062621c1..416eff7ca 100644 --- a/tests/test_planning.py +++ b/tests/test_planning.py @@ -325,6 +325,51 @@ def test_backwardPlan(): expr('Buy(Milk, SM)')] +def test_CSPlan(): + spare_tire_solution = CSPlan(spare_tire(), 3) + assert expr('Remove(Flat, Axle)') in spare_tire_solution + assert expr('Remove(Spare, Trunk)') in spare_tire_solution + assert expr('PutOn(Spare, Axle)') in spare_tire_solution + + cake_solution = CSPlan(have_cake_and_eat_cake_too(), 2) + assert expr('Eat(Cake)') in cake_solution + assert expr('Bake(Cake)') in cake_solution + + air_cargo_solution = CSPlan(air_cargo(), 6) + assert air_cargo_solution == [expr('Load(C1, P1, SFO)'), + expr('Fly(P1, SFO, JFK)'), + expr('Unload(C1, P1, JFK)'), + expr('Load(C2, P1, JFK)'), + expr('Fly(P1, JFK, SFO)'), + expr('Unload(C2, P1, SFO)')] or [expr('Load(C1, P1, SFO)'), + expr('Fly(P1, SFO, JFK)'), + expr('Unload(C1, P1, JFK)'), + expr('Load(C2, P2, JFK)'), + expr('Fly(P2, JFK, SFO)'), + expr('Unload(C2, P2, SFO)')] + + sussman_anomaly_solution = CSPlan(three_block_tower(), 3) + assert expr('MoveToTable(C, A)') in sussman_anomaly_solution + assert expr('Move(B, Table, C)') in sussman_anomaly_solution + assert expr('Move(A, Table, B)') in sussman_anomaly_solution + + blocks_world_solution = CSPlan(simple_blocks_world(), 3) + assert expr('ToTable(A, B)') in blocks_world_solution + assert expr('FromTable(B, A)') in blocks_world_solution + assert expr('FromTable(C, B)') in blocks_world_solution + + shopping_problem_solution = CSPlan(shopping_problem(), 5) + assert shopping_problem_solution == [expr('Go(Home, SM)'), + expr('Buy(Banana, SM)'), + expr('Buy(Milk, SM)'), + expr('Go(SM, HW)'), + expr('Buy(Drill, HW)')] or [expr('Go(Home, HW)'), + expr('Buy(Drill, HW)'), + expr('Go(HW, SM)'), + expr('Buy(Banana, SM)'), + expr('Buy(Milk, SM)')] + + def test_SATPlan(): spare_tire_solution = SATPlan(spare_tire(), 3) assert expr('Remove(Flat, Axle)') in spare_tire_solution @@ -335,6 +380,11 @@ def test_SATPlan(): assert expr('Eat(Cake)') in cake_solution assert expr('Bake(Cake)') in cake_solution + sussman_anomaly_solution = SATPlan(three_block_tower(), 3) + assert expr('MoveToTable(C, A)') in sussman_anomaly_solution + assert expr('Move(B, Table, C)') in sussman_anomaly_solution + assert expr('Move(A, Table, B)') in sussman_anomaly_solution + blocks_world_solution = SATPlan(simple_blocks_world(), 3) assert expr('ToTable(A, B)') in blocks_world_solution assert expr('FromTable(B, A)') in blocks_world_solution @@ -372,8 +422,7 @@ def test_linearize_class(): [expr('Load(C2, P2, JFK)'), expr('Fly(P2, JFK, SFO)'), expr('Load(C1, P1, SFO)'), expr('Fly(P1, SFO, JFK)'), expr('Unload(C1, P1, JFK)'), expr('Unload(C2, P2, SFO)')], [expr('Load(C2, P2, JFK)'), expr('Fly(P2, JFK, SFO)'), expr('Load(C1, P1, SFO)'), expr('Fly(P1, SFO, JFK)'), - expr('Unload(C2, P2, SFO)'), expr('Unload(C1, P1, JFK)')] - ] + expr('Unload(C2, P2, SFO)'), expr('Unload(C1, P1, JFK)')]] assert Linearize(ac).execute() in possible_solutions ss = socks_and_shoes() @@ -382,18 +431,28 @@ def test_linearize_class(): [expr('RightSock'), expr('LeftSock'), expr('LeftShoe'), expr('RightShoe')], [expr('RightSock'), expr('LeftSock'), expr('RightShoe'), expr('LeftShoe')], [expr('LeftSock'), expr('LeftShoe'), expr('RightSock'), expr('RightShoe')], - [expr('RightSock'), expr('RightShoe'), expr('LeftSock'), expr('LeftShoe')] - ] + [expr('RightSock'), expr('RightShoe'), expr('LeftSock'), expr('LeftShoe')]] assert Linearize(ss).execute() in possible_solutions def test_expand_actions(): - assert len(spare_tire().expand_actions()) == 16 - assert len(air_cargo().expand_actions()) == 360 + assert len(spare_tire().expand_actions()) == 9 + assert len(air_cargo().expand_actions()) == 20 assert len(have_cake_and_eat_cake_too().expand_actions()) == 2 assert len(socks_and_shoes().expand_actions()) == 4 assert len(simple_blocks_world().expand_actions()) == 12 - assert len(three_block_tower().expand_actions()) == 36 + assert len(three_block_tower().expand_actions()) == 18 + assert len(shopping_problem().expand_actions()) == 12 + + +def test_expand_feats_values(): + assert len(spare_tire().expand_fluents()) == 10 + assert len(air_cargo().expand_fluents()) == 18 + assert len(have_cake_and_eat_cake_too().expand_fluents()) == 2 + assert len(socks_and_shoes().expand_fluents()) == 4 + assert len(simple_blocks_world().expand_fluents()) == 12 + assert len(three_block_tower().expand_fluents()) == 16 + assert len(shopping_problem().expand_fluents()) == 20 def test_find_open_precondition(): @@ -405,10 +464,10 @@ def test_find_open_precondition(): ss = socks_and_shoes() pop = PartialOrderPlanner(ss) - assert (pop.find_open_precondition()[0] == expr('LeftShoeOn') and pop.find_open_precondition()[2][ - 0].name == 'LeftShoe') or ( - pop.find_open_precondition()[0] == expr('RightShoeOn') and pop.find_open_precondition()[2][ - 0].name == 'RightShoe') + assert (pop.find_open_precondition()[0] == expr('LeftShoeOn') and + pop.find_open_precondition()[2][0].name == 'LeftShoe') or ( + pop.find_open_precondition()[0] == expr('RightShoeOn') and + pop.find_open_precondition()[2][0].name == 'RightShoe') assert pop.find_open_precondition()[1] == pop.finish cp = have_cake_and_eat_cake_too() diff --git a/tests/test_probability.py b/tests/test_probability.py index e4a83ae47..a5d301017 100644 --- a/tests/test_probability.py +++ b/tests/test_probability.py @@ -1,5 +1,3 @@ -import random - import pytest from probability import * @@ -12,7 +10,7 @@ def tests(): assert cpt.p(True, event) == 0.95 event = {'Burglary': False, 'Earthquake': True} assert cpt.p(False, event) == 0.71 - # #enumeration_ask('Earthquake', {}, burglary) + # enumeration_ask('Earthquake', {}, burglary) s = {'A': True, 'B': False, 'C': True, 'D': False} assert consistent_with(s, {}) @@ -166,10 +164,10 @@ def test_elemination_ask(): def test_prior_sample(): random.seed(42) all_obs = [prior_sample(burglary) for x in range(1000)] - john_calls_true = [observation for observation in all_obs if observation['JohnCalls'] == True] - mary_calls_true = [observation for observation in all_obs if observation['MaryCalls'] == True] - burglary_and_john = [observation for observation in john_calls_true if observation['Burglary'] == True] - burglary_and_mary = [observation for observation in mary_calls_true if observation['Burglary'] == True] + john_calls_true = [observation for observation in all_obs if observation['JohnCalls']] + mary_calls_true = [observation for observation in all_obs if observation['MaryCalls']] + burglary_and_john = [observation for observation in john_calls_true if observation['Burglary']] + burglary_and_mary = [observation for observation in mary_calls_true if observation['Burglary']] assert len(john_calls_true) / 1000 == 46 / 1000 assert len(mary_calls_true) / 1000 == 13 / 1000 assert len(burglary_and_john) / len(john_calls_true) == 1 / 46 @@ -179,10 +177,10 @@ def test_prior_sample(): def test_prior_sample2(): random.seed(128) all_obs = [prior_sample(sprinkler) for x in range(1000)] - rain_true = [observation for observation in all_obs if observation['Rain'] == True] - sprinkler_true = [observation for observation in all_obs if observation['Sprinkler'] == True] - rain_and_cloudy = [observation for observation in rain_true if observation['Cloudy'] == True] - sprinkler_and_cloudy = [observation for observation in sprinkler_true if observation['Cloudy'] == True] + rain_true = [observation for observation in all_obs if observation['Rain']] + sprinkler_true = [observation for observation in all_obs if observation['Sprinkler']] + rain_and_cloudy = [observation for observation in rain_true if observation['Cloudy']] + sprinkler_and_cloudy = [observation for observation in sprinkler_true if observation['Cloudy']] assert len(rain_true) / 1000 == 0.476 assert len(sprinkler_true) / 1000 == 0.291 assert len(rain_and_cloudy) / len(rain_true) == 376 / 476 @@ -275,14 +273,12 @@ def test_forward_backward(): umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor) umbrella_evidence = [T, T, F, T, T] - assert (rounder(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == - [[0.6469, 0.3531], [0.8673, 0.1327], [0.8204, 0.1796], [0.3075, 0.6925], - [0.8204, 0.1796], [0.8673, 0.1327]]) + assert rounder(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [ + [0.6469, 0.3531], [0.8673, 0.1327], [0.8204, 0.1796], [0.3075, 0.6925], [0.8204, 0.1796], [0.8673, 0.1327]] umbrella_evidence = [T, F, T, F, T] assert rounder(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [ - [0.5871, 0.4129], [0.7177, 0.2823], [0.2324, 0.7676], [0.6072, 0.3928], - [0.2324, 0.7676], [0.7177, 0.2823]] + [0.5871, 0.4129], [0.7177, 0.2823], [0.2324, 0.7676], [0.6072, 0.3928], [0.2324, 0.7676], [0.7177, 0.2823]] def test_viterbi(): @@ -292,12 +288,10 @@ def test_viterbi(): umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor) umbrella_evidence = [T, T, F, T, T] - assert (rounder(viterbi(umbrellaHMM, umbrella_evidence, umbrella_prior)) == - [0.8182, 0.5155, 0.1237, 0.0334, 0.0210]) + assert rounder(viterbi(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [0.8182, 0.5155, 0.1237, 0.0334, 0.0210] umbrella_evidence = [T, F, T, F, T] - assert (rounder(viterbi(umbrellaHMM, umbrella_evidence, umbrella_prior)) == - [0.8182, 0.1964, 0.053, 0.0154, 0.0042]) + assert rounder(viterbi(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [0.8182, 0.1964, 0.053, 0.0154, 0.0042] def test_fixed_lag_smoothing(): @@ -309,8 +303,7 @@ def test_fixed_lag_smoothing(): umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor) d = 2 - assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, d, - umbrella_evidence, t)) == [0.1111, 0.8889] + assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.1111, 0.8889] d = 5 assert fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t) is None @@ -319,8 +312,7 @@ def test_fixed_lag_smoothing(): e_t = T d = 1 - assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, - d, umbrella_evidence, t)) == [0.9939, 0.0061] + assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.9939, 0.0061] def test_particle_filtering(): @@ -352,7 +344,7 @@ def test_monte_carlo_localization(): def P_motion_sample(kin_state, v, w): """Sample from possible kinematic states. - Returns from a single element distribution (no uncertainity in motion)""" + Returns from a single element distribution (no uncertainty in motion)""" pos = kin_state[:2] orient = kin_state[2] @@ -398,8 +390,7 @@ def P_sensor(x, y): def test_gibbs_ask(): - possible_solutions = ['False: 0.16, True: 0.84', 'False: 0.17, True: 0.83', - 'False: 0.15, True: 0.85'] + possible_solutions = ['False: 0.16, True: 0.84', 'False: 0.17, True: 0.83', 'False: 0.15, True: 0.85'] g_solution = gibbs_ask('Cloudy', dict(Rain=True), sprinkler, 200).show_approx() assert g_solution in possible_solutions diff --git a/utils.py b/utils.py index d0fc7c23a..9db0c020c 100644 --- a/utils.py +++ b/utils.py @@ -86,6 +86,13 @@ def powerset(iterable): return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)))[1:] +def extend(s, var, val): + """Copy dict s and extend it by setting var to val; return copy.""" + s2 = s.copy() + s2[var] = val + return s2 + + # ______________________________________________________________________________ # argmin and argmax