Rev 409 | Blame | Compare with Previous | Last modification | View Log | RSS feed
#!/usr/bin/env python
__author__ = "Ira W. Snyder (devel@irasnyder.com)"
__copyright__ = "Copyright (c) 2006, Ira W. Snyder (devel@irasnyder.com)"
__license__ = "GNU GPL v2 (or, at your option, any later version)"
from PyCompat import *
import sys
import copy
import time
#
# Domain class.
#
# This will hold the domain for each place in the SudokuPuzzle class.
#
class SudokuDomain(object):
    """The set of candidate values still possible for one sudoku cell.

    Attributes:
        domain -- list of the values the cell may still take.
        used   -- flag initialised to False; the solver sets it to True once
                  the cell's single value has been pruned from its peers.
    """

    DEFAULT = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    def __init__(self, domain=None):
        """Create a domain.

        domain -- optional iterable of candidate values.  When omitted
                  (None) the full DEFAULT domain is used.  An explicitly
                  empty iterable yields an empty domain.
        """
        if domain is None:
            self.domain = self.DEFAULT[:]
        else:
            # Copy so that later changes to the caller's list cannot leak
            # into this domain (and vice versa).
            self.domain = list(domain)
        self.used = False

    def __repr__(self):
        """Return 'E' for an empty domain, the value for a singleton,
        and a single space for an undecided cell."""
        size = len(self.domain)
        if size == 0:
            return 'E'
        if size == 1:
            return str(self.domain[0])
        return ' '

    def remove_value(self, value, strict=False):
        """Remove value from the domain.

        Returns True if the value was present (and has been removed),
        False if it was not present.  With strict=True a missing value
        raises ValueError instead of returning False.
        """
        if value not in self.domain:
            if strict:
                raise ValueError('%r is not in the domain' % (value,))
            return False
        self.domain = [v for v in self.domain if v != value]
        return True

    def set_value(self, value):
        """Collapse the domain to the single given value."""
        self.domain = [value]

    def is_singleton(self):
        """Return True if exactly one candidate value remains."""
        return self.get_len() == 1

    def is_empty(self):
        """Return True if no candidate values remain."""
        return self.get_len() == 0

    def get_len(self):
        """Return the number of candidate values remaining."""
        return len(self.domain)

    def get_value(self):
        """Return the sole remaining value.

        Raises ValueError unless the domain is a singleton.
        """
        if not self.is_singleton():
            raise ValueError('domain is not a singleton')
        return self.domain[0]
#
# SudokuPuzzle class.
#
# This will hold all of the current domains for each position in a sudoku
# puzzle, and allow each value to be retrieved by its row,col pair.
# This access can be done by using SudokuPuzzle[0,0] to access the first
# element, and SudokuPuzzle[8,8] to access the last element.
#
class SudokuPuzzle(object):
    """A 9x9 sudoku grid of cell domains.

    The cells are kept row-major in a flat list of 81 entries and are
    addressed with a (row, col) pair: puzzle[0, 0] is the first cell and
    puzzle[8, 8] is the last one.
    """

    def __init__(self, puzzle=None, printing_enabled=True):
        """Create a puzzle.

        puzzle           -- optional existing SudokuPuzzle whose state is
                            copied; when omitted every cell starts with a
                            fresh default SudokuDomain.
        printing_enabled -- when False, print_domain_changed() and
                            print_generic() produce no output.
        """
        if puzzle is not None:
            # Copy each cell, not just the list, so the two puzzles do not
            # end up sharing (and mutating) the same domain objects.
            self.__state = [copy.deepcopy(cell) for cell in puzzle.__state]
        else:
            self.__state = [SudokuDomain() for _ in range(81)]
        self.printing_enabled = printing_enabled

    def __getitem__(self, key):
        """Return the cell stored at key, a (row, col) pair."""
        row, col = key
        return self.__state[row * 9 + col]

    def __setitem__(self, key, value):
        """Store value as the cell at key, a (row, col) pair."""
        row, col = key
        self.__state[row * 9 + col] = value
        return value

    def __repr__(self):
        """Return the grid as text, with '=' rules between bands of three
        rows and '|' separators between groups of three columns."""
        rule = '=' * 25
        lines = []
        for row in range(9):
            if row % 3 == 0:
                lines.append(rule)
            parts = []
            for col in range(9):
                if col % 3 == 0:
                    parts.append('| ')
                parts.append('%s ' % str(self[row, col]))
            parts.append('|')
            lines.append(''.join(parts))
        lines.append(rule)
        return '\n'.join(lines)

    def __iter__(self):
        """Iterate over all 81 cells in row-major order."""
        # iter() must be handed an iterator object, not the bound
        # __iter__ method of the underlying list.
        return iter(self.__state)

    def valid_index(self, index):
        """Return True if index is a legal row/column number (0..8)."""
        return not (index < 0 or index > 8)

    def get_row(self, row, col):
        """Return the other eight cells of the row containing (row, col).

        The cell at (row, col) itself is excluded.
        Raises ValueError if either index is outside 0..8.
        """
        if not self.valid_index(row) or not self.valid_index(col):
            raise ValueError('bad index: (%r, %r)' % (row, col))
        return [self[row, c] for c in range(9) if c != col]

    def get_col(self, row, col):
        """Return the other eight cells of the column containing (row, col).

        The cell at (row, col) itself is excluded.
        Raises ValueError if either index is outside 0..8.
        """
        if not self.valid_index(row) or not self.valid_index(col):
            raise ValueError('bad index: (%r, %r)' % (row, col))
        return [self[r, col] for r in range(9) if r != row]

    def get_upper_left(self, row, col):
        """Return the (row, col) of the upper-left cell of the 3x3 box
        that contains self[row, col]."""
        # Floor division keeps the result an integer even where '/'
        # means true division.
        return (row // 3 * 3, col // 3 * 3)

    def get_small_square(self, row, col):
        """Return the other eight cells of the 3x3 box containing (row, col).

        The cell at (row, col) itself is excluded.
        Raises ValueError if either index is outside 0..8, matching
        get_row() and get_col(); without the check a negative index would
        silently wrap around through list indexing.
        """
        if not self.valid_index(row) or not self.valid_index(col):
            raise ValueError('bad index: (%r, %r)' % (row, col))
        (ul_row, ul_col) = self.get_upper_left(row, col)
        cells = []
        for r in range(ul_row, ul_row + 3):
            for c in range(ul_col, ul_col + 3):
                if not (r == row and c == col):
                    cells.append(self[r, c])
        return cells

    def prune(self, row, col, value):
        """Remove value from every peer of (row, col): the rest of its
        row, its column and its 3x3 box.  Each actual removal is reported
        through print_domain_changed()."""
        peer_groups = (
            self.get_row(row, col),
            self.get_col(row, col),
            self.get_small_square(row, col),
        )
        for group in peer_groups:
            for cell in group:
                # remove_value() is truthy only when the value was present,
                # so cells shared by two groups are reported once.
                if cell.remove_value(value):
                    self.print_domain_changed(cell, value)

    def puzzle_is_solved(self):
        """Return True if every cell is a singleton."""
        for cell in self.__state:
            if not cell.is_singleton():
                return False
        return True

    def puzzle_is_failed(self):
        """Return True if any cell has an empty domain."""
        for cell in self.__state:
            if cell.is_empty():
                return True
        return False

    def print_domain_changed(self, value, removed_val):
        """Report that removed_val was removed from the cell object value.

        The cell's coordinates are found by identity search over the grid.
        Does nothing when printing is disabled.
        """
        if not self.printing_enabled:
            return
        for i in range(9):
            for j in range(9):
                if self[i, j] is value:
                    print('removed %d from (%d, %d) -> %s' %
                          (removed_val, i, j, value.domain))

    def print_generic(self, s):
        """Print s unless printing is disabled."""
        if not self.printing_enabled:
            return
        print(s)
def solve(puzzle):
    """Solve puzzle by constraint propagation plus domain splitting.

    Every decided (singleton) cell has its value pruned from its row,
    column and box until nothing changes any more.  If the puzzle is then
    neither failed nor solved, the undecided cell with the fewest
    candidates is split in two halves and each half is solved recursively
    on a deep copy.

    puzzle -- the SudokuPuzzle to solve; it is pruned in place.

    Returns the solved puzzle object (puzzle itself, or one of the copies
    made while splitting) on success, or False if no solution exists.
    """
    # Print the puzzle we're trying to solve
    puzzle.print_generic('\nTrying to solve:')
    puzzle.print_generic('%s\n' % str(puzzle))

    # The main Arc Consistency Algorithm implementation: keep sweeping the
    # grid until a full pass prunes nothing new.
    changed = True
    while changed:
        changed = False
        for i in range(9):
            for j in range(9):
                e = puzzle[i, j]
                # 'used' marks cells whose value was already propagated,
                # so each decided cell is pruned exactly once.
                if e.is_singleton() and not e.used:
                    puzzle.prune(i, j, e.get_value())
                    e.used = True
                    changed = True

    # Propagation has reached a fixpoint, so every decided cell has been
    # checked against its peers.  Check if pruning emptied any cell.
    if puzzle.puzzle_is_failed():
        puzzle.print_generic('Puzzle failed, null entry created!')
        return False

    # Check if we're finished with the puzzle
    if puzzle.puzzle_is_solved():
        puzzle.print_generic('Puzzle finished! wheee!')
        return puzzle

    # Find the smallest undecided cell; it is the best candidate to split.
    # The puzzle is neither failed nor solved here, so at least one cell
    # with two or more candidates exists.  Ties keep the first cell found
    # in row-major order.
    size = None
    smallest_rc = None
    for i in range(9):
        for j in range(9):
            e = puzzle[i, j]
            if not e.is_singleton():
                if size is None or e.get_len() < size:
                    size = e.get_len()
                    smallest_rc = (i, j)

    # SPLIT TIME!
    (r, c) = smallest_rc
    # Floor division keeps the slice index an integer even where '/'
    # means true division.
    middle = puzzle[r, c].get_len() // 2
    lhalf = puzzle[r, c].domain[:middle]
    rhalf = puzzle[r, c].domain[middle:]

    # Split Message
    puzzle.print_generic('splitting at %s (size=%d) -> %s and %s' %
                         (smallest_rc, size, lhalf, rhalf))

    # Try the "left" half first, then the "right" half.  Each attempt
    # works on its own deep copy so a failed branch leaves no trace.
    for half in (lhalf, rhalf):
        branch = copy.deepcopy(puzzle)
        branch[r, c] = SudokuDomain(half)
        solved = solve(branch)
        # If it solved correctly, then return it back up
        if solved:
            return solved

    # Both splits at this level failed, time to work our way back up the tree
    puzzle.print_generic('Both splits at this level failed, back up!')
    return False
def main():
    """Read a puzzle from standard input, solve it and print the result.

    Nine lines are read, one per row, each holding whitespace-separated
    tokens.  A token that is an integer from 1 to 9 fixes the cell at that
    position; any other token (such as 'e') leaves the cell empty.  Tokens
    beyond the ninth on a line are ignored.
    """
    # raw_input() only exists on Python 2; fall back to input() elsewhere.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    s = SudokuPuzzle()
    print('Enter a row at a time. Use \'e\' for empty squares.')
    for i in range(9):
        tokens = read_line('line %d: ' % i).split()
        # Cells are stored row-major without bounds checks, so a tenth
        # token would land in the next row; only the first nine are used.
        for col, token in enumerate(tokens[:9]):
            try:
                value = int(token)
            except ValueError:
                # Not a number: this is an empty square.
                continue
            # Anything outside 1..9 can never be a valid cell value, so it
            # is treated as an empty square as well.
            if 1 <= value <= 9:
                s[i, col].set_value(value)

    tstart = time.time()
    solution = solve(s)
    tend = time.time()

    # solve() returns False when the puzzle has no solution.
    if solution:
        print('\nThe solution is:')
        print(str(solution))
    else:
        print('\nNo solution exists for the puzzle as entered.')
    print('Took %s seconds to solve' % str(tend - tstart))
# Start the interactive solver only when this file is executed as a
# script; importing it as a module just defines the classes and functions.
if __name__ == "__main__":
    main()