Improve sudoku solver algo(backtracking)

This commit is contained in:
Sirin Puenggun 2023-07-07 23:01:09 +07:00
parent 2add2a44dc
commit 130f05aeaa
3 changed files with 89 additions and 54 deletions

View File

@ -1,54 +1,44 @@
class SudokuSolver:
@staticmethod
def solve(puzzle):
sudoku_dict = {}
r = 'ABCDEFGHI'
c = '123456789'
row, col = SudokuSolver.find_empty_cell(puzzle)
if row == -1 and col == -1:
return True
for num in range(1, 10):
if SudokuSolver.is_valid_number(puzzle, row, col, num):
puzzle[row][col] = num
if SudokuSolver.solve(puzzle):
return True
puzzle[row][col] = 0
return False
@staticmethod
def find_empty_cell(grid):
for i in range(9):
for j in range(9):
sudoku_dict[r[i]+c[j]] = str(puzzle[i][j]) if puzzle[i][j] != 0 else c
square = [[x+y for x in i for y in j] for i in ('ABC','DEF','GHI') for j in ('123','456','789')]
peers = {}
for key in sudoku_dict.keys():
value = [i for i in square if key in i][0]
row = [[x+y for x in i for y in j][0] for i in key[0] for j in c]
col = [[x+y for x in i for y in j][0] for i in r for j in key[1]]
peers[key] = set(x for x in value+row+col if x != key)
if grid[i][j] == 0:
return i, j
return -1, -1
@staticmethod
def is_valid_number(grid, row, col, num):
for i in range(9):
sudoku_dict = SudokuSolver.Check(sudoku_dict,peers)
sudoku_dict = SudokuSolver.search(sudoku_dict, peers)
solution = []
for i in r:
solution.append([])
for j in c:
solution[r.find(i)].append(int(sudoku_dict[i+j]))
return solution
@staticmethod
def Check(sudoku_dict, peers):
for k,v in sudoku_dict.items():
if len(v) == 1:
for s in peers[k]:
sudoku_dict[s] = sudoku_dict[s].replace(v,'')
if len(sudoku_dict[s])==0:
if grid[row][i] == num:
return False
return sudoku_dict
@staticmethod
def search(sudoku_dict,peers):
if SudokuSolver.Check(sudoku_dict,peers)==False:
for i in range(9):
if grid[i][col] == num:
return False
if all(len(sudoku_dict[s]) == 1 for s in sudoku_dict.keys()):
return sudoku_dict
n,s = min((len(sudoku_dict[s]), s) for s in sudoku_dict.keys() if len(sudoku_dict[s]) > 1)
res = []
for value in sudoku_dict[s]:
new_sudoku_dict = sudoku_dict.copy()
new_sudoku_dict[s] = value
ans = SudokuSolver.search(new_sudoku_dict, peers)
if ans:
res.append(ans)
if len(res) > 1:
raise Exception("Error")
elif len(res) == 1:
return res[0]
box_row = row - row % 3
box_col = col - col % 3
for i in range(3):
for j in range(3):
if grid[box_row + i][box_col + j] == num:
return False
return True

16
main.py
View File

@ -55,13 +55,13 @@ def check_sudoku_rule(table, row, col, value):
sudoku_table = create_sudoku_table()
try:
console = Console()
console.print("\n[bold yellow]Solving Sudoku...[/bold yellow]")
start = time.process_time()
solved_puzzle = SudokuSolver.solve(sudoku_table)
in_time = time.process_time() - start
console = Console()
console.print("\n[bold yellow]Solving Sudoku...[/bold yellow]")
start = time.process_time()
status = SudokuSolver.solve(sudoku_table)
in_time = time.process_time() - start
if status == True:
console.print(f"\n[bold green]Finish! in {in_time} [/bold green]")
print_sudoku_table(solved_puzzle, console, clear=False)
except Exception:
print_sudoku_table(sudoku_table, console, clear=False)
elif status == False:
console.print("\n[bold red]Failed to solve![/bold red]")

View File

@ -0,0 +1,45 @@
import sys
sys.path.append('..')
from SudokuSolver import SudokuSolver
def test_solve_sudoku():
# Test a valid Sudoku grid
grid = [
[5, 3, 0, 0, 7, 0, 0, 0, 0],
[6, 0, 0, 1, 9, 5, 0, 0, 0],
[0, 9, 8, 0, 0, 0, 0, 6, 0],
[8, 0, 0, 0, 6, 0, 0, 0, 3],
[4, 0, 0, 8, 0, 3, 0, 0, 1],
[7, 0, 0, 0, 2, 0, 0, 0, 6],
[0, 6, 0, 0, 0, 0, 2, 8, 0],
[0, 0, 0, 4, 1, 9, 0, 0, 5],
[0, 0, 0, 0, 8, 0, 0, 7, 9]
]
assert SudokuSolver.solve(grid)
assert is_valid_solution(grid)
# Test an invalid Sudoku grid
grid = [
[6, 3, 0, 0, 7, 0, 0, 0, 0],
[6, 0, 0, 1, 9, 5, 0, 0, 0],
[0, 9, 8, 0, 0, 0, 0, 6, 0],
[8, 0, 0, 0, 6, 0, 0, 0, 3],
[4, 0, 0, 8, 0, 3, 0, 0, 1],
[7, 0, 0, 0, 2, 0, 0, 0, 6],
[0, 6, 0, 0, 0, 0, 2, 8, 0],
[0, 0, 0, 4, 1, 9, 0, 0, 5],
[0, 0, 0, 0, 8, 0, 0, 0, 9]
]
assert not SudokuSolver.solve(grid)
def is_valid_solution(grid):
# Check if each row, column, and 3x3 box contains all numbers from 1 to 9
for i in range(9):
row_nums = set(grid[i])
col_nums = set(grid[j][i] for j in range(9))
box_nums = set(grid[i//3*3+j//3][i%3*3+j%3] for j in range(9))
if row_nums != set(range(1, 10)) or col_nums != set(range(1, 10)) or box_nums != set(range(1, 10)):
return False
return True