diff --git a/SudokuSolver.py b/SudokuSolver.py index 78d31eb..361d5f8 100644 --- a/SudokuSolver.py +++ b/SudokuSolver.py @@ -1,54 +1,44 @@ class SudokuSolver: @staticmethod def solve(puzzle): - sudoku_dict = {} - r = 'ABCDEFGHI' - c = '123456789' + row, col = SudokuSolver.find_empty_cell(puzzle) + if row == -1 and col == -1: + return True + + for num in range(1, 10): + if SudokuSolver.is_valid_number(puzzle, row, col, num): + puzzle[row][col] = num + + if SudokuSolver.solve(puzzle): + return True + + puzzle[row][col] = 0 + + return False + + @staticmethod + def find_empty_cell(grid): for i in range(9): for j in range(9): - sudoku_dict[r[i]+c[j]] = str(puzzle[i][j]) if puzzle[i][j] != 0 else c - square = [[x+y for x in i for y in j] for i in ('ABC','DEF','GHI') for j in ('123','456','789')] - peers = {} - for key in sudoku_dict.keys(): - value = [i for i in square if key in i][0] - row = [[x+y for x in i for y in j][0] for i in key[0] for j in c] - col = [[x+y for x in i for y in j][0] for i in r for j in key[1]] - peers[key] = set(x for x in value+row+col if x != key) + if grid[i][j] == 0: + return i, j + return -1, -1 + + @staticmethod + def is_valid_number(grid, row, col, num): for i in range(9): - sudoku_dict = SudokuSolver.Check(sudoku_dict,peers) - sudoku_dict = SudokuSolver.search(sudoku_dict, peers) - solution = [] - for i in r: - solution.append([]) - for j in c: - solution[r.find(i)].append(int(sudoku_dict[i+j])) - return solution + if grid[row][i] == num: + return False - @staticmethod - def Check(sudoku_dict, peers): - for k,v in sudoku_dict.items(): - if len(v) == 1: - for s in peers[k]: - sudoku_dict[s] = sudoku_dict[s].replace(v,'') - if len(sudoku_dict[s])==0: - return False - return sudoku_dict + for i in range(9): + if grid[i][col] == num: + return False - @staticmethod - def search(sudoku_dict,peers): - if SudokuSolver.Check(sudoku_dict,peers)==False: - return False - if all(len(sudoku_dict[s]) == 1 for s in sudoku_dict.keys()): - return sudoku_dict - n,s = min((len(sudoku_dict[s]), s) for s in sudoku_dict.keys() if len(sudoku_dict[s]) > 1) - res = [] - for value in sudoku_dict[s]: - new_sudoku_dict = sudoku_dict.copy() - new_sudoku_dict[s] = value - ans = SudokuSolver.search(new_sudoku_dict, peers) - if ans: - res.append(ans) - if len(res) > 1: - raise Exception("Error") - elif len(res) == 1: - return res[0] + box_row = row - row % 3 + box_col = col - col % 3 + for i in range(3): + for j in range(3): + if grid[box_row + i][box_col + j] == num: + return False + + return True \ No newline at end of file diff --git a/main.py b/main.py index 99c5fff..805b5c9 100644 --- a/main.py +++ b/main.py @@ -55,13 +55,13 @@ def check_sudoku_rule(table, row, col, value): sudoku_table = create_sudoku_table() -try: - console = Console() - console.print("\n[bold yellow]Solving Sudoku...[/bold yellow]") - start = time.process_time() - solved_puzzle = SudokuSolver.solve(sudoku_table) - in_time = time.process_time() - start +console = Console() +console.print("\n[bold yellow]Solving Sudoku...[/bold yellow]") +start = time.process_time() +status = SudokuSolver.solve(sudoku_table) +in_time = time.process_time() - start +if status == True: console.print(f"\n[bold green]Finish! in {in_time} [/bold green]") - print_sudoku_table(solved_puzzle, console, clear=False) -except Exception: + print_sudoku_table(sudoku_table, console, clear=False) +elif status == False: console.print("\n[bold red]Failed to solve![/bold red]") diff --git a/tests/test_SudokuSolver.py b/tests/test_SudokuSolver.py new file mode 100644 index 0000000..e797b47 --- /dev/null +++ b/tests/test_SudokuSolver.py @@ -0,0 +1,45 @@ +import sys +sys.path.append('..') + +from SudokuSolver import SudokuSolver + +def test_solve_sudoku(): + # Test a valid Sudoku grid + grid = [ + [5, 3, 0, 0, 7, 0, 0, 0, 0], + [6, 0, 0, 1, 9, 5, 0, 0, 0], + [0, 9, 8, 0, 0, 0, 0, 6, 0], + [8, 0, 0, 0, 6, 0, 0, 0, 3], + [4, 0, 0, 8, 0, 3, 0, 0, 1], + [7, 0, 0, 0, 2, 0, 0, 0, 6], + [0, 6, 0, 0, 0, 0, 2, 8, 0], + [0, 0, 0, 4, 1, 9, 0, 0, 5], + [0, 0, 0, 0, 8, 0, 0, 7, 9] + ] + assert SudokuSolver.solve(grid) + assert is_valid_solution(grid) + + # Test an invalid Sudoku grid + grid = [ + [6, 3, 0, 0, 7, 0, 0, 0, 0], + [6, 0, 0, 1, 9, 5, 0, 0, 0], + [0, 9, 8, 0, 0, 0, 0, 6, 0], + [8, 0, 0, 0, 6, 0, 0, 0, 3], + [4, 0, 0, 8, 0, 3, 0, 0, 1], + [7, 0, 0, 0, 2, 0, 0, 0, 6], + [0, 6, 0, 0, 0, 0, 2, 8, 0], + [0, 0, 0, 4, 1, 9, 0, 0, 5], + [0, 0, 0, 0, 8, 0, 0, 0, 9] + ] + assert not SudokuSolver.solve(grid) + +def is_valid_solution(grid): + # Check if each row, column, and 3x3 box contains all numbers from 1 to 9 + for i in range(9): + row_nums = set(grid[i]) + col_nums = set(grid[j][i] for j in range(9)) + box_nums = set(grid[i//3*3+j//3][i%3*3+j%3] for j in range(9)) + if row_nums != set(range(1, 10)) or col_nums != set(range(1, 10)) or box_nums != set(range(1, 10)): + return False + return True +