Subversion Repositories programming

Rev

Rev 409 | Blame | Compare with Previous | Last modification | View Log | RSS feed

#!/usr/bin/env python

__author__    = "Ira W. Snyder (devel@irasnyder.com)"
__copyright__ = "Copyright (c) 2006, Ira W. Snyder (devel@irasnyder.com)"
__license__   = "GNU GPL v2 (or, at your option, any later version)"

from PyCompat import *

import sys
import copy
import time

#
# Domain class.
#
# Holds the domain (the list of candidate values) for one cell of a SudokuPuzzle.
#

class SudokuDomain (object):
        """Candidate values (the "domain") for a single sudoku cell.

        Attributes:
                domain -- list of values this cell may still take.
                used   -- starts out False; solve() sets it to True once this
                          cell's singleton value has been pruned from its
                          neighbours.
        """

        # Values every cell may take before any constraint is applied.
        DEFAULT = [1, 2, 3, 4, 5, 6, 7, 8, 9]

        def __init__ (self, domain=None):
                """Create a domain.

                domain -- optional iterable of starting values.  None (the
                          default) means every value in DEFAULT.  An empty
                          iterable yields an empty domain (previously it was
                          silently replaced by DEFAULT).  The values are copied,
                          so the caller's list is never aliased.
                """
                if domain is None:
                        self.domain = self.DEFAULT[:]
                else:
                        self.domain = list(domain)

                self.used = False

        def __repr__ (self):
                """'E' for an empty domain, the value for a singleton, else ' '."""
                if len(self.domain) == 0:
                        return 'E'
                elif len(self.domain) == 1:
                        return str(self.domain[0])
                else:
                        return ' '

        def remove_value (self, value, strict=False):
                """Remove value from the domain.

                Returns True if value was present (and has now been removed),
                False if it was not.  With strict=True a missing value raises
                ValueError instead of returning False.
                """
                if value not in self.domain:
                        if strict:
                                raise ValueError('%r is not in domain %r' % (value, self.domain))
                        return False

                # Rebind rather than mutate in place, so any other reference to
                # the old list object is left untouched.
                self.domain = [v for v in self.domain if v != value]
                return True

        def set_value (self, value):
                """Collapse the domain to the single value given."""
                self.domain = [value]

        def is_singleton (self):
                """Return True when exactly one candidate remains."""
                return self.get_len () == 1

        def is_empty (self):
                """Return True when no candidate remains (a contradiction)."""
                return self.get_len () == 0

        def get_len (self):
                """Return the number of remaining candidates."""
                return len(self.domain)

        def get_value (self):
                """Return the only remaining candidate.

                Raises ValueError unless the domain is a singleton.
                """
                if not self.is_singleton ():
                        raise ValueError('domain %r is not a singleton' % (self.domain,))

                return self.domain[0]

#
# SudokuPuzzle class.
#
# This will hold all of the current domains for each position in a sudoku
# puzzle, and allow each value to be retrieved by its row,col pair.
# This access can be done by using SudokuPuzzle[0,0] to access the first
# element, and SudokuPuzzle[8,8] to access the last element.
#

class SudokuPuzzle (object):
        """A 9x9 sudoku grid holding one domain object per cell.

        Cells are addressed with a 0-based (row, col) tuple: puzzle[0,0] is the
        top-left cell and puzzle[8,8] the bottom-right.  Internally the 81 cells
        are stored row-major in a flat list.
        """

        def __init__ (self, puzzle=None, printing_enabled=True):
                """Create a puzzle.

                puzzle           -- optional SudokuPuzzle to copy.  Its cells are
                                    deep-copied, so the new puzzle can be pruned
                                    without disturbing the original.  When None,
                                    every cell starts as a full SudokuDomain().
                printing_enabled -- when False, the print_* helpers are silent.
                """
                if puzzle is not None:
                        # A plain slice would share the cell objects between both
                        # puzzles, so pruning one would also prune the other.
                        self.__state = [copy.deepcopy(cell) for cell in puzzle.__state]
                else:
                        self.__state = [SudokuDomain() for _ in range(81)]

                self.printing_enabled = printing_enabled

        def _offset (self, key):
                """Validate a (row, col) key and return its index in the flat state.

                Raises IndexError when row or col is outside 0..8.  Without this
                check an out-of-range column would silently address a cell of a
                neighbouring row.
                """
                row, col = key
                if not (0 <= row <= 8 and 0 <= col <= 8):
                        raise IndexError('cell (%r, %r) is outside the 9x9 grid' % (row, col))
                return row * 9 + col

        def __getitem__ (self, key):
                """Return the cell at key = (row, col)."""
                return self.__state [self._offset (key)]

        def __setitem__ (self, key, value):
                """Replace the cell at key = (row, col) with value."""
                self.__state [self._offset (key)] = value
                return value

        def __repr__ (self):
                """Render the grid: '=' rules between bands of three rows, '|'
                between boxes, and each cell shown via its own str()."""
                rule = '=' * 25
                lines = []

                for i in range(9):
                        if i % 3 == 0:
                                lines.append(rule)

                        line = ''
                        for j in range(9):
                                if j % 3 == 0:
                                        line += '| '
                                line += '%s ' % str(self[i,j])

                        lines.append(line + '|')

                lines.append(rule)
                return '\n'.join(lines)

        def __iter__ (self):
                """Iterate over the 81 cells in row-major order."""
                # Must return an iterator; returning the bound __iter__ method
                # (as before) made iter(puzzle) raise TypeError.
                return iter(self.__state)

        def valid_index (self, index):
                """Return True if index is a legal row/column number (0..8)."""
                return 0 <= index <= 8

        def _check_cell (self, row, col):
                """Raise ValueError unless (row, col) lies inside the grid."""
                if not self.valid_index (row) or not self.valid_index (col):
                        raise ValueError('bad cell index (%r, %r)' % (row, col))

        def get_row (self, row, col):
                """Return the eight other cells of row `row`, i.e. every cell in
                that row except (row, col).  Raises ValueError for a bad index."""
                self._check_cell (row, col)
                return [self[row,i] for i in range(9) if i != col]

        def get_col (self, row, col):
                """Return the eight other cells of column `col`, i.e. every cell in
                that column except (row, col).  Raises ValueError for a bad index."""
                self._check_cell (row, col)
                return [self[i,col] for i in range(9) if i != row]

        def get_upper_left (self, row, col):
                """Return (row, col) of the upper-left cell of the 3x3 box that
                contains self[row,col]."""
                # Floor division keeps the result an int even under true division.
                return (row // 3 * 3, col // 3 * 3)

        def get_small_square (self, row, col):
                """Return the eight other cells of the 3x3 box containing
                (row, col), excluding (row, col) itself.  Raises ValueError for a
                bad index."""
                self._check_cell (row, col)
                (ul_row, ul_col) = self.get_upper_left (row, col)

                return [self[i,j]
                        for i in range(ul_row, ul_row + 3)
                        for j in range(ul_col, ul_col + 3)
                        if not (i == row and j == col)]

        def prune (self, row, col, value):
                """Remove all occurrences of value from every cell sharing a row,
                column or 3x3 box with (row, col), excluding (row, col) itself.

                Each cell that actually lost the value is reported through
                print_domain_changed().
                """
                peers = self.get_row (row, col) + self.get_col (row, col) + \
                        self.get_small_square (row, col)

                for e in peers:
                        if e.remove_value (value):
                                self.print_domain_changed (e, value)

        def puzzle_is_solved (self):
                """Return True when every cell's domain is a singleton."""
                for cell in self.__state:
                        if not cell.is_singleton ():
                                return False

                return True

        def puzzle_is_failed (self):
                """Return True when some cell's domain is empty."""
                for cell in self.__state:
                        if cell.is_empty ():
                                return True

                return False

        def print_domain_changed (self, value, removed_val):
                """Report that removed_val was dropped from the cell object value.

                The cell's position is found by identity; nothing is printed
                when printing is disabled.
                """
                if not self.printing_enabled:
                        return

                for index, cell in enumerate(self.__state):
                        if cell is value:
                                i, j = divmod(index, 9)
                                print('removed %d from (%d, %d) -> %s' % \
                                        (removed_val, i, j, value.domain))

        def print_generic (self, s):
                """Print s unless printing is disabled."""
                if not self.printing_enabled:
                        return

                print(s)


def _propagate (puzzle):
        """Run arc consistency on puzzle in place.

        Repeatedly takes every singleton cell not yet marked used, prunes its
        value from the cell's row, column and box via puzzle.prune(), and marks
        it used, until a full pass changes nothing.

        Returns False as soon as a pruning step leaves some cell empty,
        True otherwise.
        """
        changed = True

        while changed:
                changed = False
                for i in range(9):
                        for j in range(9):
                                e = puzzle[i,j]
                                if e.is_singleton () and not e.used:
                                        puzzle.prune (i, j, e.get_value ())
                                        e.used = True
                                        changed = True

                                        # Check if we failed during the last pruning pass
                                        if puzzle.puzzle_is_failed ():
                                                return False

        return True


def _smallest_unsolved (puzzle):
        """Return (row, col) of the non-singleton cell with the fewest candidates.

        Ties go to the first such cell in row-major order.  Returns None when
        every cell is a singleton.
        """
        best = None
        best_len = None

        for i in range(9):
                for j in range(9):
                        e = puzzle[i,j]
                        if not e.is_singleton ():
                                n = e.get_len ()
                                if best_len is None or n < best_len:
                                        best, best_len = (i, j), n

        return best


def solve (puzzle):
        """Solve puzzle using arc consistency plus recursive domain splitting.

        puzzle -- a SudokuPuzzle; it is pruned in place.

        Returns a solved SudokuPuzzle (puzzle itself, or one of the deep copies
        made while splitting), or False when no solution exists from this
        point of the search.
        """
        # Print the puzzle we're trying to solve
        puzzle.print_generic ('\nTrying to solve:')
        puzzle.print_generic ('%s\n' % str(puzzle))

        # The main Arc Consistency Algorithm implementation.  The extra
        # puzzle_is_failed() check catches cells that were already empty on
        # entry: pruning alone never reports those, and such a cell would
        # otherwise be picked below and "split" into two empty halves.
        if not _propagate (puzzle) or puzzle.puzzle_is_failed ():
                puzzle.print_generic ('Puzzle failed, null entry created!')
                return False

        # Check if we're finished with the puzzle
        if puzzle.puzzle_is_solved ():
                puzzle.print_generic ('Puzzle finished! wheee!')
                return puzzle

        # Not solved and not failed, so some cell has >= 2 candidates.  The
        # smallest such domain is the best candidate to split (fewest branches).
        (r, c) = _smallest_unsolved (puzzle)
        domain = puzzle[r,c].domain
        size = len(domain)
        half = size // 2    # floor division keeps the slice index an int
        lhalf = domain[:half]
        rhalf = domain[half:]

        # Split Message
        puzzle.print_generic ('splitting at %s (size=%d) -> %s and %s' %
                        ((r, c), size, lhalf, rhalf))

        # Try the "left" half first, then the "right" half, each on its own copy.
        for part in (lhalf, rhalf):
                branch = copy.deepcopy (puzzle)
                branch[r,c] = SudokuDomain (part)
                result = solve (branch)

                # If it solved correctly, then return it back up
                if result:
                        return result

        # Both splits at this level failed, time to work our way back up the tree
        puzzle.print_generic ('Both splits at this level failed, back up!')
        return False



def main ():
        """Read a puzzle from standard input, solve it and print the result.

        Each of the nine input lines holds one row's cells separated by
        whitespace.  An integer 1-9 fixes that cell's value; any other token
        (conventionally 'e') leaves the cell undecided.  Tokens beyond the
        ninth on a line are ignored.
        """
        # raw_input only exists on Python 2; fall back to input() elsewhere.
        try:
                read_line = raw_input
        except NameError:
                read_line = input

        s = SudokuPuzzle ()

        print('Enter a row at a time. Use \'e\' for empty squares.')
        for i in range(9):
                tokens = read_line('line %d: ' % i).split ()

                # Only the first nine tokens belong to this row; extra tokens
                # would otherwise index past the end of the row.
                for col, token in enumerate(tokens[:9]):
                        try:
                                value = int(token)
                        except ValueError:
                                continue    # 'e' or any non-number: leave the cell open

                        # Values outside 1-9 (e.g. 0) cannot appear in a
                        # solution; treat them as empty rather than pinning
                        # the cell to an impossible value.
                        if 1 <= value <= 9:
                                s[i,col].set_value (value)

        tstart = time.time ()
        solution = solve (s)
        tend = time.time ()

        if solution:
                print('\nThe solution is:')
                print(str(solution))
        else:
                print('\nNo solution exists for this puzzle.')
        print('')
        print('Took %s seconds to solve' % str(tend - tstart))

# Run the interactive solver only when executed as a script, not on import.
if __name__ == '__main__':
        main()