[Data Structures and Algorithms] Finding the minimum number of perfect squares that sum to n

1. Problem statement

Given a positive integer n, find perfect squares (such as 1, 4, 9, 16, …) whose sum equals n, using as few perfect squares as possible. The answer is that minimum count.

  • Example one
	Input: n = 12
	Output: 3
	Explanation: 12 = 4 + 4 + 4
  • Example two
	Input: n = 13
	Output: 2
	Explanation: 13 = 4 + 9
2. Example algorithms
① Brute-force enumeration [time limit exceeded]
  • This problem asks for the smallest number of perfect squares whose sum is n. We rephrase it as:
    Given a list of perfect squares and a positive integer n, find a combination of numbers from the list that sums to n and uses the fewest perfect squares.
  • Note: the perfect squares in the list may be reused.
  • Stated this way, it is a combination problem, and an intuitive solution is brute-force enumeration: enumerate all possible combinations and pick the one with the fewest perfect squares.
  • We can express the problem with the following formula:
    numSquares(n) = min(numSquares(n−k) + 1), ∀k ∈ square numbers
  • This formula translates directly into a recursive solution.
  • The example algorithm is as follows:
import math

class Solution(object):
    def numSquares(self, n):
        square_nums = [i**2 for i in range(1, int(math.sqrt(n))+1)]

        def minNumSquares(k):
            """ recursive solution """
            # bottom cases: find a square number
            if k in square_nums:
                return 1
            min_num = float('inf')

            # Find the minimal value among all possible solutions
            for square in square_nums:
                if k < square:
                    break
                new_num = minNumSquares(k-square) + 1
                min_num = min(min_num, new_num)
            return min_num

        return minNumSquares(n)
  • This solution works for small positive integers n. For medium-sized numbers (such as 55), however, it quickly exceeds the time limit.
  • The cause is that the recursion recomputes the same intermediate results many times; for larger n, the deep recursion may also overflow the stack.
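To make the cost of the plain recursion concrete, the sketch below instruments the same recursive function with a call counter. The names count_calls and min_num_squares are hypothetical and exist only for this illustration; the logic mirrors the recursive solution above.

```python
import math

def count_calls(n):
    """Return (answer, number of recursive calls) for the plain recursion."""
    square_nums = [i * i for i in range(1, int(math.sqrt(n)) + 1)]
    calls = 0

    def min_num_squares(k):
        nonlocal calls
        calls += 1
        # base case: k itself is a perfect square
        if k in square_nums:
            return 1
        best = float('inf')
        for square in square_nums:
            if k < square:
                break
            best = min(best, min_num_squares(k - square) + 1)
        return best

    return min_num_squares(n), calls

# Compare how the call count grows as n increases.
for n in (10, 20, 30):
    print(n, count_calls(n))
```

Running it for a few increasing values of n shows the number of calls growing much faster than n itself, which is why the time limit is exceeded.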
② Dynamic programming
  • Brute-force enumeration exceeds the time limit for a simple reason: it recomputes the same intermediate solutions over and over. The earlier formula is still valid; we just need a better way to implement it:
    numSquares(n) = min(numSquares(n−k) + 1), ∀k ∈ square numbers
  • You may have noticed that this problem resembles the Fibonacci number problem. As with Fibonacci numbers, there are more efficient ways to compute the solution than plain recursion.
  • One way to avoid the stack overflow in recursion is dynamic programming (DP), which reuses the results of intermediate solutions to compute the final solution.
  • To calculate numSquares(n), first calculate the values for the numbers smaller than n, that is,
    numSquares(n−k) ∀k ∈ square numbers. If the solution for n−k is already stored somewhere, there is no need to recurse.
  • The algorithm idea is as follows:
    • Almost all dynamic programming solutions start by creating a one-dimensional or multi-dimensional array dp that stores the intermediate sub-solutions; usually the last value of the array is the final solution. Note that we add a fictitious element dp[0] = 0 to simplify the logic, which helps when the remainder (n−k) is itself a perfect square.
    • We also pre-compute a list of the perfect squares not greater than n (i.e. square_nums).
    • In the main step, we loop from 1 to n and calculate the solution for each number i (i.e. numSquares(i)). In each iteration, we store the result of numSquares(i) in dp[i].
    • At the end of the loop, we return the last element of the array as the result.
    • [Figure: computing numSquares(4) and numSquares(5), i.e. dp[4] and dp[5]]
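The steps above can be traced by hand for the two values in the figure: dp[4] = min(dp[4−1] + 1, dp[4−4] + 1) = min(4, 1) = 1, and dp[5] = min(dp[5−1] + 1, dp[5−4] + 1) = min(2, 2) = 2. The short sketch below fills the same table and prints it; the helper name dp_table is an assumption made for this illustration.

```python
import math

def dp_table(n):
    """Return dp, where dp[i] is the fewest perfect squares that sum to i."""
    square_nums = [i * i for i in range(1, int(math.sqrt(n)) + 1)]
    dp = [float('inf')] * (n + 1)
    dp[0] = 0  # fictitious base case
    for i in range(1, n + 1):
        for square in square_nums:
            if i < square:
                break
            dp[i] = min(dp[i], dp[i - square] + 1)
    return dp

print(dp_table(5))  # [0, 1, 2, 3, 1, 2]
```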
  • The following is a sample implementation. The Python solution took about 3500 ms, faster than 50% of the submissions at the time. Note that this result applies to Python 2 only; for reasons that are unclear, the same code takes longer to run under Python 3.
import math

class Solution(object):
    def numSquares(self, n):
        """
        :type n: int
        :rtype: int
        """
        # start at 1: the square 0 can never contribute to a sum
        square_nums = [i**2 for i in range(1, int(math.sqrt(n))+1)]
        
        dp = [float('inf')] * (n+1)
        # bottom case
        dp[0] = 0
        
        for i in range(1, n+1):
            for square in square_nums:
                if i < square:
                    break
                dp[i] = min(dp[i], dp[i-square] + 1)
        
        return dp[-1]
  • The Java implementation is as follows:
import java.util.Arrays;

class Solution {

  public int numSquares(int n) {
    int dp[] = new int[n + 1];
    Arrays.fill(dp, Integer.MAX_VALUE);
    // base case
    dp[0] = 0;

    // pre-calculate the square numbers.
    int max_square_index = (int) Math.sqrt(n) + 1;
    int square_nums[] = new int[max_square_index];
    for (int i = 1; i < max_square_index; ++i) {
      square_nums[i] = i * i;
    }

    for (int i = 1; i <= n; ++i) {
      for (int s = 1; s < max_square_index; ++s) {
        if (i < square_nums[s])
          break;
        dp[i] = Math.min(dp[i], dp[i - square_nums[s]] + 1);
      }
    }
    return dp[n];
  }
}
  • Complexity analysis
    • Time complexity: O(n·√n). The main step is a nested loop: the outer loop runs n iterations, and the inner loop runs at most √n iterations.
    • Space complexity: O(n), using a one-dimensional array dp.
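As an alternative to the bottom-up table, the same recurrence can stay in recursive form if intermediate results are cached. The sketch below is not part of the original solutions; it assumes Python 3 and uses functools.lru_cache for the cache, and the names num_squares_memo and solve are hypothetical.

```python
import math
from functools import lru_cache

def num_squares_memo(n):
    """Top-down version of numSquares(n) = min(numSquares(n - k) + 1)."""
    square_nums = [i * i for i in range(1, int(math.sqrt(n)) + 1)]

    @lru_cache(maxsize=None)
    def solve(k):
        if k == 0:
            return 0  # same role as dp[0] = 0
        best = k  # upper bound: k copies of the square 1
        for square in square_nums:
            if square > k:
                break
            best = min(best, solve(k - square) + 1)
        return best

    return solve(n)

print(num_squares_memo(12))  # 3
print(num_squares_memo(13))  # 2
```

Each value of k is computed only once, but the recursion depth can still reach n, so for large n the bottom-up array remains the safer choice.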
③ Greedy enumeration
  • The recursive solution is a simple and intuitive way to understand the problem, and we can still solve the problem with recursion. To improve on the brute-force enumeration, we add a greedy strategy to the recursion and restructure the enumeration as follows:
    Enumerate combinations by size, starting from a single number and moving to combinations of more numbers. As soon as we find a combination that sums to n, it is the smallest one, because the combinations are enumerated greedily from small to large.
  • To explain this, first define a function is_divided_by(n, count). It returns a boolean indicating whether n can be written as the sum of exactly count perfect squares, instead of returning the exact size of the combination as numSquares(n) does.
    numSquares(n) = min{ count ∈ [1, 2, …, n] : is_divided_by(n, count) is true }
  • Unlike the recursive function numSquares(n), the recursion in is_divided_by(n, count) reaches its base case (count == 1) much faster.
  • With this reformulation the recursion depth is bounded by count, which significantly reduces the risk of stack overflow.
    [Figure: decomposition performed by is_divided_by(n, count) for n = 5 and count = 2]
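The decomposition for n = 5 and count = 2 can also be followed as a list of calls. The sketch below records every (n, count) pair that the recursion visits; trace_is_divided_by is a hypothetical helper written for this illustration, and it uses a sorted list of squares so that the order of the calls is deterministic.

```python
def trace_is_divided_by(n, count):
    """Return (result, list of (n, count) calls made by the recursion)."""
    square_nums = [i * i for i in range(1, int(n ** 0.5) + 1)]
    calls = []

    def is_divided_by(m, c):
        calls.append((m, c))
        if c == 1:
            # base case: is the remainder itself a perfect square?
            return m in square_nums
        for k in square_nums:
            if is_divided_by(m - k, c - 1):
                return True
        return False

    return is_divided_by(n, count), calls

print(trace_is_divided_by(5, 1))  # (False, [(5, 1)])
print(trace_is_divided_by(5, 2))  # (True, [(5, 2), (4, 1)])
```

For count = 2 the recursion subtracts 1, reaches the remainder 4 at count == 1, and stops there because 4 is a perfect square.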
  • Algorithm description
    • First prepare a list of the perfect squares not greater than n (called square_nums).
    • In the main loop, iterate the combination size (called count) from 1 to n, and check whether n can be written as the sum of count perfect squares, that is, is_divided_by(n, count).
    • The function is_divided_by(n, count) can be implemented recursively, as described above.
    • In the base case, count == 1, we only need to check whether n itself is a perfect square. This can be checked against square_nums, that is, n ∈ square_nums. If square_nums is stored as a set, the lookup runs faster than computing int(sqrt(n)) squared and comparing it with n.
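The two ways of testing the base case can be put side by side. This is a minimal sketch with hypothetical function names; it assumes Python 3.8 or later for math.isqrt, which computes the integer square root exactly, whereas a float-based sqrt can be imprecise for very large integers.

```python
import math

def is_square_by_set(n, square_nums):
    """Membership test against a precomputed set of perfect squares."""
    return n in square_nums

def is_square_by_isqrt(n):
    """Arithmetic test: n is a perfect square if isqrt(n) squared equals n."""
    if n < 0:
        return False
    root = math.isqrt(n)
    return root * root == n

limit = 50
square_nums = {i * i for i in range(1, math.isqrt(limit) + 1)}
# Both checks agree for every value from 1 to limit.
print(all(is_square_by_set(k, square_nums) == is_square_by_isqrt(k)
          for k in range(1, limit + 1)))  # True
```

The set lookup pays the cost of building square_nums once and then answers each query with a single hash lookup, which suits the many repeated checks made in the base case.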
  • The Python implementation is as follows:
class Solution:
    def numSquares(self, n):
        
        def is_divided_by(n, count):
            """
                return: true if "n" can be decomposed into "count" number of perfect square numbers.
                e.g. n=12, count=3:  true.
                     n=12, count=2:  false
            """
            if count == 1:
                return n in square_nums
            
            for k in square_nums:
                if is_divided_by(n - k, count - 1):
                    return True
            return False

        square_nums = set([i * i for i in range(1, int(n**0.5)+1)])
    
        for count in range(1, n+1):
            if is_divided_by(n, count):
                return count
  • The Java implementation is as follows:
import java.util.HashSet;
import java.util.Set;

class Solution {
  Set<Integer> square_nums = new HashSet<Integer>();

  protected boolean is_divided_by(int n, int count) {
    // base case: n itself must be a perfect square
    if (count == 1) {
      return square_nums.contains(n);
    }

    for (Integer square : square_nums) {
      if (is_divided_by(n - square, count - 1)) {
        return true;
      }
    }
    return false;
  }

  public int numSquares(int n) {
    this.square_nums.clear();

    for (int i = 1; i * i <= n; ++i) {
      this.square_nums.add(i * i);
    }

    int count = 1;
    for (; count <= n; ++count) {
      if (is_divided_by(n, count))
        return count;
    }
    return count;
  }
}
④ Greedy + BFS (breadth first search)
  • In the greedy enumeration above, the call stack traces out an N-ary tree in which each node represents one call of the is_divided_by(n, count) function.
  • Based on this observation, the original problem can be rephrased as follows: given an N-ary tree in which each node holds the remainder left after subtracting a combination of perfect squares from n, find a node in the tree that satisfies two conditions:
    (1) The value of the node (i.e. the remainder) is itself a perfect square.
    (2) Among all nodes that satisfy condition (1), the node is the closest to the root.


  • In method ③, because of the greedy strategy, the N-ary tree is effectively built level by level from top to bottom. Here we traverse it explicitly in BFS (breadth first search) order. Each level of the tree enumerates combinations of the same size.
  • The traversal order is BFS rather than DFS (depth first search), because BFS exhausts every way of decomposing n into a fixed number of perfect squares before it explores any combination that needs more elements.
  • Algorithm analysis:
    • First, we prepare a list of the perfect squares not greater than n (i.e. square_nums).
    • Then we create a queue variable, which holds all the remainders enumerated at the current level.
    • In the main loop, we iterate over the queue. In each iteration, we check whether the remainder is a perfect square. If it is not, we subtract each candidate perfect square from it to get new remainders and add them to next_queue for the next level. As soon as we meet a remainder that is a perfect square, we leave the loop, which means the solution has been found.
  • In a typical BFS algorithm, the queue is usually an array or a list. Here a set is used instead, to eliminate duplicate remainders within the same level. In practice, this small trick can make the code run about 5 times faster.
  • Take numSquares(7) as an example of how the queue evolves level by level:
    [Figure: queue contents at each level for numSquares(7)]
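The queue layout for numSquares(7) can also be reproduced in code. The sketch below follows the algorithm described above but additionally records the sorted contents of the queue at each level; the helper name bfs_levels is an assumption made for this illustration.

```python
def bfs_levels(n):
    """Return (answer, list of the sorted queue contents at each level)."""
    square_nums = [i * i for i in range(1, int(n ** 0.5) + 1)]
    levels = []
    queue = {n}
    while queue:
        levels.append(sorted(queue))
        next_queue = set()  # a set removes duplicate remainders
        for remainder in queue:
            for square in square_nums:
                if remainder == square:
                    # the depth of this node is the answer
                    return len(levels), levels
                if remainder < square:
                    break
                next_queue.add(remainder - square)
        queue = next_queue
    return len(levels), levels

print(bfs_levels(7))  # (4, [[7], [3, 6], [2, 5], [1, 4]])
```

At the third level the remainder 2 is produced twice (from 6 − 4 and from 3 − 1) but stored once, which is the redundancy that the set removes. The search stops at the fourth level because that level contains perfect squares, so 7 = 4 + 1 + 1 + 1.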
  • The Python algorithm is as follows:
class Solution:
    def numSquares(self, n):

        # list of square numbers that are less than or equal to `n`
        square_nums = [i * i for i in range(1, int(n**0.5)+1)]
    
        level = 0
        queue = {n}
        while queue:
            level += 1
            #! Important: use set() instead of list() to eliminate the redundancy,
            # which would even provide a 5-times speedup, 200ms vs. 1000ms.
            next_queue = set()
            # construct the queue for the next level
            for remainder in queue:
                for square_num in square_nums:    
                    if remainder == square_num:
                        return level  # find the node!
                    elif remainder < square_num:
                        break
                    else:
                        next_queue.add(remainder - square_num)
            queue = next_queue
        return level
  • The Java algorithm is as follows:
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

class Solution {
  public int numSquares(int n) {

    ArrayList<Integer> square_nums = new ArrayList<Integer>();
    for (int i = 1; i * i <= n; ++i) {
      square_nums.add(i * i);
    }

    Set<Integer> queue = new HashSet<Integer>();
    queue.add(n);

    int level = 0;
    while (queue.size() > 0) {
      level += 1;
      // a set removes duplicate remainders within the same level
      Set<Integer> next_queue = new HashSet<Integer>();

      for (Integer remainder : queue) {
        for (Integer square : square_nums) {
          if (remainder.equals(square)) {
            return level;
          } else if (remainder < square) {
            break;
          } else {
            next_queue.add(remainder - square);
          }
        }
      }
      queue = next_queue;
    }
    return level;
  }
}

Origin blog.csdn.net/Forever_wj/article/details/109193291