@@ -98,8 +98,15 @@ def goal_test(self, state):
98
98
99
99
## These are for constraint propagation
100
100
101
+ def support_pruning (self ):
102
+ """Make sure we can prune values from domains. (We want to pay
103
+ for this only if we use it.)"""
104
+ if self .curr_domains is None :
105
+ self .curr_domains = dict ((v , self .domains [v ][:]) for v in self .vars )
106
+
101
107
def suppose (self , var , value ):
102
108
"Start accumulating inferences from assuming var=value."
109
+ self .support_pruning ()
103
110
removals = [(var , a ) for a in self .curr_domains [var ] if a != value ]
104
111
self .curr_domains [var ] = [value ]
105
112
return removals
@@ -118,6 +125,12 @@ def restore(self, removals):
118
125
for B , b in removals :
119
126
self .curr_domains [B ].append (b )
120
127
128
+ def infer_assignment (self ):
129
+ "Return the partial assignment implied by the current inferences."
130
+ self .support_pruning ()
131
+ return dict ((v , self .curr_domains [v ][0 ])
132
+ for v in self .vars if 1 == len (self .curr_domains [v ]))
133
+
121
134
## This is for min_conflicts search
122
135
123
136
def conflicted_vars (self , current ):
@@ -232,6 +245,7 @@ def AC3(csp, queue=None, removals=None):
232
245
"""[Fig. 5.7]"""
233
246
if queue is None :
234
247
queue = [(Xi , Xk ) for Xi in csp .vars for Xk in csp .neighbors [Xi ]]
248
+ csp .support_pruning ()
235
249
while queue :
236
250
(Xi , Xj ) = queue .pop ()
237
251
if remove_inconsistent_values (csp , Xi , Xj , removals ):
@@ -418,6 +432,73 @@ def display(self, assignment):
418
432
print str (self .nconflicts (var , val , assignment ))+ ch ,
419
433
print
420
434
435
+ #______________________________________________________________________________
436
+ # Sudoku
437
+
438
+ import itertools , re
439
+
440
+ def flatten (seqs ): return sum (seqs , [])
441
+
442
+ easy1 = '..3.2.6..9..3.5..1..18.64....81.29..7.......8..67.82....26.95..8..2.3..9..5.1.3..'
443
+ harder1 = '4173698.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......'
444
+
445
+ class Sudoku (CSP ):
446
+ """A Sudoku problem.
447
+ The box grid is a 3x3 array of boxes, each a 3x3 array of cells.
448
+ Each cell holds a digit in 1..9. In each box, all digits are
449
+ different; the same for each row and column as a 9x9 grid.
450
+ >>> e = Sudoku(easy1)
451
+ >>> e.display(e.infer_assignment())
452
+ . . 3 | . 2 . | 6 . .
453
+ 9 . . | 3 . 5 | . . 1
454
+ . . 1 | 8 . 6 | 4 . .
455
+ ------+-------+------
456
+ . . 8 | 1 . 2 | 9 . .
457
+ 7 . . | . . . | . . 8
458
+ . . 6 | 7 . 8 | 2 . .
459
+ ------+-------+------
460
+ . . 2 | 6 . 9 | 5 . .
461
+ 8 . . | 2 . 3 | . . 9
462
+ . . 5 | . 1 . | 3 . .
463
+ >>> AC3(e); e.display(e.infer_assignment())
464
+ 4 8 3 | 9 2 1 | 6 5 7
465
+ 9 6 7 | 3 4 5 | 8 2 1
466
+ 2 5 1 | 8 7 6 | 4 9 3
467
+ ------+-------+------
468
+ 5 4 8 | 1 3 2 | 9 7 6
469
+ 7 2 9 | 5 6 4 | 1 3 8
470
+ 1 3 6 | 7 9 8 | 2 4 5
471
+ ------+-------+------
472
+ 3 7 2 | 6 8 9 | 5 1 4
473
+ 8 1 4 | 2 5 3 | 7 6 9
474
+ 6 9 5 | 4 1 7 | 3 8 2
475
+ >>> h = Sudoku(harder1)
476
+ >>> None != backtracking_search(h, select_unassigned_variable=mrv, inference=forward_checking)
477
+ True
478
+ """
479
+ R3 = range (3 )
480
+ Cell = itertools .count ().next
481
+ bgrid = [[[[Cell () for x in R3 ] for y in R3 ] for bx in R3 ] for by in R3 ]
482
+ boxes = flatten ([map (flatten , brow ) for brow in bgrid ])
483
+ rows = flatten ([map (flatten , zip (* brow )) for brow in bgrid ])
484
+ units = map (set , boxes + rows + zip (* rows ))
485
+ neighbors = dict ([(v , set .union (* [u for u in units if v in u ]) - set ([v ]))
486
+ for v in flatten (rows )])
487
+
488
+ def __init__ (self , grid ):
489
+ squares = re .findall (r'\d|\.' , grid )
490
+ domains = dict ((var , [int (ch )] if ch .isdigit () else range (1 , 10 ))
491
+ for var , ch in zip (flatten (self .rows ), squares ))
492
+ CSP .__init__ (self , None , domains , self .neighbors ,
493
+ different_values_constraint )
494
+
495
+ def display (self , assignment ):
496
+ def show_box (box ): return [' ' .join (map (show_cell , row )) for row in box ]
497
+ def show_cell (cell ): return str (assignment .get (cell , '.' ))
498
+ def abut (lines1 , lines2 ): return map (' | ' .join , zip (lines1 , lines2 ))
499
+ print '\n ------+-------+------\n ' .join (
500
+ '\n ' .join (reduce (abut , map (show_box , brow ))) for brow in self .bgrid )
501
+
421
502
#______________________________________________________________________________
422
503
# The Zebra Puzzle
423
504
0 commit comments