Author: | Tom Hall |
---|---|
Date: | 11pm Tues Dec 4th |
+----------+----------+----------+
| A1 A2 A3 | A4 A5 A6 | A7 A8 A9 |
| B1 B2 B3 | B4 B5 B6 | B7 B8 B9 |
| C1 C2 C3 | C4 C5 C6 | C7 C8 C9 |
+----------+----------+----------+
| D1 D2 D3 | D4 D5 D6 | D7 D8 D9 |
| E1 E2 E3 | E4 E5 E6 | E7 E8 E9 |
| F1 F2 F3 | F4 F5 F6 | F7 F8 F9 |
+----------+----------+----------+
| G1 G2 G3 | G4 G5 G6 | G7 G8 G9 |
| H1 H2 H3 | H4 H5 H6 | H7 H8 H9 |
| I1 I2 I3 | I4 I5 I6 | I7 I8 I9 |
+----------+----------+----------+
def cross(A, B):
    "Return every concatenation x+y with x from A and y from B, in row-major order."
    return [x + y for x in A for y in B]

rows = 'ABCDEFGHI'
cols = '123456789'
digits = '123456789'
## All 81 squares, 'A1' through 'I9', row by row.
squares = cross(rows, cols)
## The 27 units, in a fixed order: 9 columns, then 9 rows, then 9 3x3 boxes.
unitlist = ([cross(rows, c) for c in cols] +
            [cross(r, cols) for r in rows] +
            [cross(rs, cs)
             for rs in ('ABC', 'DEF', 'GHI')
             for cs in ('123', '456', '789')])
## units[s]: the units (column, row, box) that contain square s, in unitlist order.
units = dict((sq, [unit for unit in unitlist if sq in unit]) for sq in squares)
## peers[s]: every other square that shares at least one unit with s.
peers = dict((sq, set(sum(units[sq], [])) - set([sq])) for sq in squares)
def parse_grid(grid):
    """Turn a grid string into a dict mapping each square to its candidate digits.

    Only '0', '.', '-' and '1'-'9' are significant; all other characters are
    skipped. '0', '.' and '-' mark an empty square. Each given digit is placed
    with assign(); returns False as soon as a placement is contradictory.
    """
    cells = [ch for ch in grid if ch in '0.-123456789']
    # Start fully unconstrained: every square may hold any digit.
    values = dict((sq, digits) for sq in squares)
    for sq, ch in zip(squares, cells):
        if ch not in digits:
            continue  # an empty-square marker: nothing to place
        if not assign(values, sq, ch):
            return False  # this given digit conflicts with the others
    return values
def printboard(values):
    """Print values as a 2-D grid; used for debugging.

    Each square shows its remaining candidate digits, centred in a column one
    character wider than the square with the most candidates. '|' and a dashed
    line separate the 3x3 boxes. values may also be False (the failure result
    of parse_grid or search). In that case a short note is printed instead of
    raising TypeError. Returns None.
    """
    if values is False:
        print('(no solution)')
        return
    width = 1 + max(len(values[s]) for s in squares)
    line = '\n' + '+'.join(['-' * (width * 3)] * 3)
    for r in rows:
        # Single-argument print(...) behaves the same under Python 2 and 3.
        print(''.join(values[r + c].center(width) + ('|' if c in '36' else '')
                      for c in cols) + (line if r in 'CF' else ''))
    print('')
>>> grid = """
003020600
900305001
001806400
008102900
700000008
006708200
002609500
800203009
005010300"""
>>> printboard(parse_grid(grid))
4 8 3 |9 2 1 |6 5 7
9 6 7 |3 4 5 |8 2 1
2 5 1 |8 7 6 |4 9 3
------+------+------
5 4 8 |1 3 2 |9 7 6
7 2 9 |5 6 4 |1 3 8
1 3 6 |7 9 8 |2 4 5
------+------+------
3 7 2 |6 8 9 |5 1 4
8 1 4 |2 5 3 |7 6 9
6 9 5 |4 1 7 |3 8 2
def assign(values, s, d):
    """Fix square s to digit d by eliminating every other candidate from values[s].

    Each elimination propagates through eliminate(). Returns values (mutated
    in place) on success, or False at the first elimination that fails.
    """
    for other in values[s]:
        if other == d:
            continue
        if not eliminate(values, s, other):
            return False
    return values
The `eliminate` function does most of the real work:
def eliminate(values, s, d):
    """Eliminate d from values[s] and propagate the consequences.

    Two propagation rules fire:
      1. If s is left with a single value d2, eliminate d2 from every peer of s.
      2. If a unit of s is left with only one square that can hold d, assign d there.
    Returns values on success, or False on a contradiction: s has no values
    left, or some unit has no place left for d. values is mutated in place
    in both cases.
    """
    if d not in values[s]:
        return values ## Already eliminated
    values[s] = values[s].replace(d, '')
    if len(values[s]) == 0:
        return False ## Contradiction: removed last value
    elif len(values[s]) == 1:
        ## If there is only one value (d2) left in square, remove it from peers
        d2, = values[s]
        if not all(eliminate(values, s2, d2) for s2 in peers[s]):
            return False
    ## Now check the places where d appears in the units of s
    for u in units[s]:
        # Use a fresh name, not s: under Python 2 a list comprehension's loop
        # variable leaks into the function and would overwrite the parameter s.
        dplaces = [place for place in u if d in values[place]]
        if len(dplaces) == 0:
            return False ## Contradiction: no place left for d in this unit
        elif len(dplaces) == 1:
            # d can only be in one place in unit; assign it there
            if not assign(values, dplaces[0], d):
                return False
    return values
def search(values):
    """Solve by depth-first search, propagating constraints at every step.

    Returns a solved values dict, or False if values is already False or
    no choice of the remaining candidates leads to a solution.
    """
    if values is False:
        return False  # an earlier assign/parse already hit a contradiction
    if all(len(values[sq]) == 1 for sq in squares):
        return values  # every square is down to a single digit: solved
    # Choose the unfilled square with the fewest candidates (ties go to the
    # lexicographically smallest name), then try each of its digits in turn,
    # each on its own copy of values.
    pivot = min((len(values[sq]), sq) for sq in squares if len(values[sq]) > 1)[1]
    for digit in values[pivot]:
        outcome = search(assign(values.copy(), pivot, digit))
        if outcome:
            return outcome
    return False

def some(seq):
    "Return the first true element of seq (consuming it lazily), or False if none is true."
    for item in seq:
        if item:
            break
    else:
        return False
    return item
Peter Norvig
www.norvig.com
## Solve Every Sudoku Puzzle

## Throughout this program we have:
##   r is a row,    e.g. 'A'
##   c is a column, e.g. '3'
##   s is a square, e.g. 'A3'
##   d is a digit,  e.g. '9'
##   u is a unit,   e.g. ['A1','B1','C1','D1','E1','F1','G1','H1','I1']
##   g is a grid,   e.g. 81 non-blank chars, e.g. starting with '.18...7...
##   values is a dict of possible values, e.g. {'A1':'123489', 'A2':'8', ...}

def cross(A, B):
    "Cross product of elements in A and elements in B, e.g. cross('AB','12') == ['A1','A2','B1','B2']."
    return [a+b for a in A for b in B]

rows = 'ABCDEFGHI'
cols = '123456789'
digits = '123456789'
squares = cross(rows, cols)
unitlist = ([cross(rows, c) for c in cols] +
            [cross(r, cols) for r in rows] +
            [cross(rs, cs) for rs in ('ABC','DEF','GHI') for cs in ('123','456','789')])
units = dict((s, [u for u in unitlist if s in u]) for s in squares)
peers = dict((s, set(s2 for u in units[s] for s2 in u if s2 != s)) for s in squares)

def search(values):
    "Using depth-first search and propagation, try all possible values."
    if values is False:
        return False ## Failed earlier
    if all(len(values[s]) == 1 for s in squares):
        return values ## Solved!
    ## Choose the unfilled square s with the fewest possibilities
    _,s = min((len(values[s]), s) for s in squares if len(values[s]) > 1)
    return some(search(assign(values.copy(), s, d)) for d in values[s])

def assign(values, s, d):
    "Eliminate all the other values (except d) from values[s] and propagate."
    if all(eliminate(values, s, d2) for d2 in values[s] if d2 != d):
        return values
    else:
        return False

def eliminate(values, s, d):
    """Eliminate d from values[s] and propagate the consequences.

    Two propagation rules fire:
      1. If s is left with a single value d2, eliminate d2 from every peer of s.
      2. If a unit of s is left with only one square that can hold d, assign d there.
    Returns values on success, or False on a contradiction. values is mutated
    in place in both cases.
    """
    if d not in values[s]:
        return values ## Already eliminated
    values[s] = values[s].replace(d,'')
    if len(values[s]) == 0:
        return False ## Contradiction: removed last value
    elif len(values[s]) == 1:
        ## If there is only one value (d2) left in square, remove it from peers
        d2, = values[s]
        if not all(eliminate(values, s2, d2) for s2 in peers[s]):
            return False
    ## Now check the places where d appears in the units of s
    for u in units[s]:
        # A fresh name, not s: a Python 2 list comprehension leaks its loop
        # variable and would overwrite the parameter s.
        dplaces = [place for place in u if d in values[place]]
        if len(dplaces) == 0:
            return False ## Contradiction: no place left for d in this unit
        elif len(dplaces) == 1:
            # d can only be in one place in unit; assign it there
            if not assign(values, dplaces[0], d):
                return False
    return values

def parse_grid(grid):
    "Given a string of 81 digits (or .0-), return a dict of {cell:values}"
    grid = [c for c in grid if c in '0.-123456789']
    values = dict((s, digits) for s in squares) ## Each square can be any digit
    for s,d in zip(squares, grid):
        if d in digits and not assign(values, s, d):
            return False
    return values

def solve_file(filename, sep='\n', action=lambda x: x):
    """Parse a file into a sequence of 81-char descriptions and solve them.

    Each non-blank entry (split on sep) is parsed, searched, and passed to
    action. Returns the list of action results and prints how many were solved.
    """
    # open() + try/finally (not file(), which is Python-2-only) so the handle
    # is always closed.
    f = open(filename)
    try:
        text = f.read()
    finally:
        f.close()
    # Skip blank entries (e.g. an empty line between puzzles): parse_grid would
    # treat them as empty grids and they would be "solved" and counted.
    grids = [grid for grid in text.strip().split(sep) if grid.strip()]
    results = [action(search(parse_grid(grid))) for grid in grids]
    print("## Got %d out of %d" % (
          sum((r is not False) for r in results), len(results)))
    return results

def printboard(values):
    """Print values as a 2-D grid (used for debugging) and return values.

    Returning values lets printboard serve as solve_file's action. A failed
    result (False) is reported and passed through unchanged instead of raising
    TypeError, so solve_file's success count still works.
    """
    if values is False:
        print('(no solution)')
        return False
    width = 1+max(len(values[s]) for s in squares)
    line = '\n' + '+'.join(['-'*(width*3)]*3)
    for r in rows:
        # and/or is safe here because the middle operands ('|', line) are
        # always true.
        print(''.join(values[r+c].center(width) + (c in '36' and '|' or '')
                      for c in cols) + (r in 'CF' and line or ''))
    print('')
    return values

def all(seq):
    "Return True if every element of seq is true (same semantics as the builtin all())."
    for e in seq:
        if not e:
            return False
    return True

def some(seq):
    "Return some element of seq that is true, or False if there is none."
    for e in seq:
        if e:
            return e
    return False

if __name__ == '__main__':
    solve_file("top10.txt", action=printboard)