Understanding Backtracking
An overview
About
This post explains the backtracking algorithm and provides a simple implementation. The goal is to serve as a quick reference for review.
Backtracking is a brute-force approach that builds candidate solutions incrementally, starting from an initial state. When a partial candidate cannot lead to a valid solution, or all of its extensions have been explored, the algorithm steps back to the previous choice and tries another candidate.
Table of Contents
Use Case
Design
Implementation
Complexity
Examples
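The Use Case
Backtracking can be applied to many computational problems that start from an initial state and must satisfy a set of constraints. For example, a sudoku puzzle can be solved by filling one empty cell at a time and backing up whenever a digit breaks a row, column, or box rule.
For efficiency, the algorithm should reject an invalid candidate as close to the root of the search tree as possible, because rejecting a partial candidate prunes every solution that would have extended it.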
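A minimal sketch of the sudoku case follows. It assumes a 9x9 grid stored as a list of lists with 0 marking an empty cell; that representation and the helper names find_empty, is_valid, and solve are illustrative, not from a particular library. Each digit is checked against its row, column, and 3x3 box before the solver recurses, so a bad placement is rejected at the shallowest possible level instead of after the grid is full.

```python
from typing import List, Optional, Tuple

Grid = List[List[int]]  # hypothetical representation: 9x9, 0 marks an empty cell


def find_empty(grid: Grid) -> Optional[Tuple[int, int]]:
    # Return the first empty cell in row-major order, or None if the grid is full.
    for r in range(9):
        for c in range(9):
            if grid[r][c] == 0:
                return r, c
    return None


def is_valid(grid: Grid, r: int, c: int, digit: int) -> bool:
    # Reject the digit early if it already appears in the row, column, or 3x3 box.
    if digit in grid[r]:
        return False
    if any(grid[i][c] == digit for i in range(9)):
        return False
    br, bc = 3 * (r // 3), 3 * (c // 3)
    return all(grid[i][j] != digit
               for i in range(br, br + 3) for j in range(bc, bc + 3))


def solve(grid: Grid) -> bool:
    # Fill the grid in place; return True once a complete valid grid is found.
    cell = find_empty(grid)
    if cell is None:
        return True  # accepted: no empty cells remain
    r, c = cell
    for digit in range(1, 10):
        if is_valid(grid, r, c, digit):
            grid[r][c] = digit      # try this digit
            if solve(grid):
                return True
            grid[r][c] = 0          # undo and try the next digit
    return False  # every digit was rejected: backtrack to the caller
```

solve(grid) fills the grid in place and returns False if no completion exists. Picking the first empty cell is the simplest choice; choosing the cell with the fewest legal digits usually prunes far more, at the cost of extra bookkeeping.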
The Design
Overview
The high-level overview is straightforward.
- There is a function, usually recursive, that evaluates a candidate.
- If the candidate is rejected, the function stops exploring that branch and returns.
- If the candidate is accepted, the function saves it as a solution.
- Then the function calls itself for each candidate obtained by extending the current one by a single step, usually in a for loop.
- After each recursive call returns, that extension is undone, so the candidate is back in its previous state before the next branch is explored.
The pseudocode would look like this:
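```
def backtrack(data, candidate, solutions):
    if rejected(candidate):
        return                                  # prune this whole branch
    if accepted(candidate):
        solutions.add(copy(candidate))          # store a snapshot, not the live object
    for c in get_next(data, candidate):
        candidate = update(candidate, c)        # extend the candidate with c
        backtrack(data, candidate, solutions)   # recurse on the extended candidate
        candidate.remove(c)                     # undo the extension before the next c
```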
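To make the template concrete, here is one possible instantiation for a made-up problem: find every subset of a list of distinct positive integers that sums to a target. The function name subsets_with_sum is hypothetical. The pseudocode's rejected, accepted, get_next, and update steps map to a running-sum check, an equality check, a loop over the indices after the last pick, and an append that is popped off again after the recursive call.

```python
from typing import List


def subsets_with_sum(nums: List[int], target: int) -> List[List[int]]:
    # Assumes distinct positive integers, so a running sum above target can never recover.
    solutions: List[List[int]] = []

    def backtrack(start: int, candidate: List[int], total: int) -> None:
        if total > target:          # rejected: prune this branch
            return
        if total == target:         # accepted: record a copy of the candidate
            solutions.append(candidate.copy())
            return                  # adding more positive numbers would overshoot
        for i in range(start, len(nums)):   # get_next: elements after the last pick
            candidate.append(nums[i])       # update the candidate
            backtrack(i + 1, candidate, total + nums[i])
            candidate.pop()                 # reset before the next branch

    backtrack(0, [], 0)
    return solutions
```

For example, subsets_with_sum([1, 2, 3, 4, 5], 5) returns [[1, 4], [2, 3], [5]]. Because all values are assumed positive, a running sum above the target can never come back down, which is what makes the early rejection safe.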
The Implementation
We’ll implement a solution for the following problem:
Given an array nums of distinct integers, return all the possible permutations. You can return the answer in any order.
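Logically, this is straightforward. For each position current_index, swap the element there with every element at or after it (including itself), recurse on the next position, and then swap back to restore the array. When only the last position remains, it has a single choice, so the current arrangement is recorded. This way, all n! permutations are generated.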
The complete implementation
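```python
from typing import List


def find_permutations(nums: List[int]) -> List[List[int]]:
    result = []

    def util(nums, current_index, result):
        # Only the last position is left, so this arrangement is a full permutation.
        if current_index == len(nums) - 1:
            result.append(nums.copy())
            return
        for index in range(current_index, len(nums)):
            # Choose: place nums[index] at position current_index.
            nums[current_index], nums[index] = nums[index], nums[current_index]
            # Explore: permute the remaining positions.
            util(nums, current_index + 1, result)
            # Un-choose: swap back to restore the previous state.
            nums[current_index], nums[index] = nums[index], nums[current_index]

    util(nums, 0, result)
    return result
```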
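The swap-based version rearranges nums in place. An alternative sketch, closer to the generic template, builds each permutation in a separate candidate list and tracks which elements are already placed with a used array. The function name permutations_with_used is hypothetical, and this is not the method used above.

```python
from typing import List


def permutations_with_used(nums: List[int]) -> List[List[int]]:
    # Builds each permutation in a separate list instead of swapping in place.
    result: List[List[int]] = []
    candidate: List[int] = []
    used = [False] * len(nums)  # used[i] is True while nums[i] is in the candidate

    def backtrack() -> None:
        if len(candidate) == len(nums):   # accepted: a full permutation
            result.append(candidate.copy())
            return
        for i, value in enumerate(nums):
            if used[i]:
                continue                  # skip elements already placed
            used[i] = True
            candidate.append(value)       # choose
            backtrack()                   # explore
            candidate.pop()               # un-choose
            used[i] = False

    backtrack()
    return result
```

permutations_with_used([1, 2, 3]) returns the six permutations in lexicographic order because the input is sorted; the swap-based version produces the same set in a different order. The trade-off is an extra boolean array of length n in exchange for never rearranging the input.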
Complexity Analysis
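In the worst case, backtracking explores the entire search tree, so its running time is exponential (or worse) in the input size; pruning only helps when candidates are rejected early. For the permutation implementation above, there are n! permutations and each one is copied in O(n) time, giving O(n · n!) time and O(n) extra space for the recursion stack, not counting the output.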
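To see the factorial growth concretely, the sketch below re-implements the swap-based util with a call counter and compares it against a closed form: depth d of the recursion is reached n!/(n - d)! times, for d = 0 through n - 1. The names count_calls and predicted_calls are hypothetical helpers for illustration.

```python
from math import factorial


def count_calls(n: int) -> int:
    # Counts how many times the swap-based util runs for an input of length n >= 1.
    nums = list(range(n))
    calls = 0

    def util(current_index: int) -> None:
        nonlocal calls
        calls += 1
        if current_index == len(nums) - 1:
            return
        for index in range(current_index, len(nums)):
            nums[current_index], nums[index] = nums[index], nums[current_index]
            util(current_index + 1)
            nums[current_index], nums[index] = nums[index], nums[current_index]

    util(0)
    return calls


def predicted_calls(n: int) -> int:
    # Depth d of the recursion is reached n! / (n - d)! times, for d = 0 .. n - 1.
    return sum(factorial(n) // factorial(n - d) for d in range(n))
```

For n = 3 this gives 1 + 3 + 6 = 10 calls, and for n = 4 it gives 1 + 4 + 12 + 24 = 41. The n! calls at the deepest level dominate the total, and since each of them copies an n-element list, this is where an O(n · n!) bound for the implementation comes from.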
Examples
These LeetCode problems are good practice for recognizing and solving backtracking questions.
https://leetcode.com/problems/n-queens
https://leetcode.com/problems/letter-combinations-of-a-phone-number
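For the second problem, one possible sketch follows the same choose, explore, un-choose pattern with one level of recursion per digit. The keypad mapping is the standard phone layout; the function name, and the assumptions that the input contains only the digits 2 through 9 and that an empty string yields an empty list, should be checked against the problem's own constraints.

```python
from typing import List

# Standard phone keypad mapping from digits to letters.
KEYPAD = {
    "2": "abc", "3": "def", "4": "ghi", "5": "jkl",
    "6": "mno", "7": "pqrs", "8": "tuv", "9": "wxyz",
}


def letter_combinations(digits: str) -> List[str]:
    # Assumption: an empty input yields no combinations.
    if not digits:
        return []
    result: List[str] = []
    candidate: List[str] = []

    def backtrack(position: int) -> None:
        if position == len(digits):       # accepted: one letter chosen per digit
            result.append("".join(candidate))
            return
        for letter in KEYPAD[digits[position]]:
            candidate.append(letter)      # choose a letter for this digit
            backtrack(position + 1)       # move on to the next digit
            candidate.pop()               # undo before trying the next letter

    backtrack(0)
    return result
```

For example, letter_combinations("23") returns ["ad", "ae", "af", "bd", "be", "bf", "cd", "ce", "cf"]. No candidate is ever rejected here, so the search visits the whole tree; the backtracking structure is still what keeps a single shared candidate list correct across branches.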