diff --git a/csp.py b/csp.py index 1e97d7780..8c5ecde3d 100644 --- a/csp.py +++ b/csp.py @@ -14,12 +14,13 @@ class CSP(search.Problem): """This class describes finite-domain Constraint Satisfaction Problems. A CSP is specified by the following inputs: - variables A list of variables; each is atomic (e.g. int or string). + variables A list of variables; each is atomic (e.g. int or string). domains A dict of {var:[possible_value, ...]} entries. neighbors A dict of {var:[var,...]} that for each variable lists the other variables that participate in constraints. constraints A function f(A, a, B, b) that returns true if neighbors A, B satisfy the constraint when they have values A=a, B=b + In the textbook and in most mathematical definitions, the constraints are specified as explicit pairs of allowable values, but the formulation here is easier to express and more compact for @@ -29,7 +30,7 @@ class CSP(search.Problem): problem, that's all there is. However, the class also supports data structures and methods that help you - solve CSPs by calling a search function on the CSP. Methods and slots are + solve CSPs by calling a search function on the CSP. Methods and slots are as follows, where the argument 'a' represents an assignment, which is a dict of {var:val} entries: assign(var, val, a) Assign a[var] = val; do other bookkeeping @@ -307,8 +308,9 @@ def tree_csp_solver(csp): """[Figure 6.11]""" assignment = {} root = csp.variables[0] - X, parent = topological_sort(csp.variables, root) - for Xj in reversed(X): + root = 'NT' + X, parent = topological_sort(csp, root) + for Xj in reversed(X[1:]): if not make_arc_consistent(parent[Xj], Xj, csp): return None for Xi in X: @@ -318,8 +320,43 @@ def tree_csp_solver(csp): return assignment -def topological_sort(xs, x): - raise NotImplementedError +def topological_sort(X, root): + """Returns the topological sort of X starting from the root. + + Input: + X is a list with the nodes of the graph + N is the dictionary with the neighbors of each node + root denotes the root of the graph. + + Output: + stack is a list with the nodes topologically sorted + parents is a dictionary pointing to each node's parent + + Other: + visited shows the state (visited - not visited) of nodes + + """ + nodes = X.variables + neighbors = X.neighbors + + visited = defaultdict(lambda: False) + + stack = [] + parents = {} + + build_topological(root, None, neighbors, visited, stack, parents) + return stack, parents + +def build_topological(node, parent, neighbors, visited, stack, parents): + """Builds the topological sort and the parents of each node in the graph""" + visited[node] = True + + for n in neighbors[node]: + if(not visited[n]): + build_topological(n, node, neighbors, visited, stack, parents) + + parents[node] = parent + stack.insert(0,node) def make_arc_consistent(Xj, Xk, csp): diff --git a/tests/test_csp.py b/tests/test_csp.py index 5bed85c05..803dede74 100644 --- a/tests/test_csp.py +++ b/tests/test_csp.py @@ -274,6 +274,18 @@ def test_universal_dict(): def test_parse_neighbours(): assert parse_neighbors('X: Y Z; Y: Z') == {'Y': ['X', 'Z'], 'X': ['Y', 'Z'], 'Z': ['X', 'Y']} +def test_topological_sort(): + root = 'NT' + Sort, Parents = topological_sort(australia,root) + + assert Sort == ['NT','SA','Q','NSW','V','WA'] + assert Parents['NT'] == None + assert Parents['SA'] == 'NT' + assert Parents['Q'] == 'SA' + assert Parents['NSW'] == 'Q' + assert Parents['V'] == 'NSW' + assert Parents['WA'] == 'SA' + if __name__ == "__main__": pytest.main()