【回溯】【leetcode】返回 1 ... n 中所有可能的 k 个数的组合

题目:

给定两个整数 n 和 k,返回 1 ... 中所有可能的 k 个数的组合。

例:

输入: n = 4, k = 2
输出:
[
  [2,4],
  [3,4],
  [2,3],
  [1,2],
  [1,3],
  [1,4],
]

来源:

77. 组合

解题思路:回溯

想用暴力搜索解决,但还是无从下手,这时我们采用回溯的办法解决。

回溯配合递归使用,递归前访问,递归后回溯。访问与回溯是一对相反操作。

代码一:

class Solution {
public:
    vector< vector<int> > result;
    vector<int> path;

    vector< vector<int> > combine(int n, int k) {
        go(n, k, 1);
        return result;
    }

    void go(int n, int k, int start) {
        if (path.size() == k) {
            // 保存结果
            result.push_back(path);
            return;
        }
        for (int i = start; i <= n; i++) {
            path.push_back(i); // 访问
            go(n, k, i+1); // 递归
            path.pop_back(); // 回溯
        }
    }
};

代码一暴力搜索全部可能的组合,包括path.size()<k的情况。对于这种情况,需要剪枝以提升代码效率。

例如,若集合[1,2,3,4],k=3,当start指向3时,后面只剩一个数字4了,已经不足了,此时就没必要对这种情况递归调用了。

剩余的个数:n - start

path中已有的个数:path.size()

当前start:1,(start指向的数字还没有进入path,也没有计算在剩余中)

所以当 剩余的个数 + path中已有的个数 + 当前start < k时,就退出循环,这就是剪枝条件:n - i + path.size() + 1 < k,

优化后的代码二:

class Solution {
public:
    vector< vector<int> > result;
    vector<int> path;

    vector< vector<int> > combine(int n, int k) {
        go(n, k, 1);
        return result;
    }

    void go(int n, int k, int start) {
        if (path.size() == k) {
            result.push_back(path);
            return;
        }
        for (int i = start; i <= n; i++) {
            if (n - i + path.size() + 1 < k) break;
            path.push_back(i);
            go(n, k, i+1);
            path.pop_back();
        }
    }
};

回溯+递归,写出来的代码简单易懂。这之前使用暴力写过一次代码,现在读起来自己都读不懂了,贴出来对比一下,见下面代码。

意思是先指定第一个结果,后面每个结果在前一个结果上+1,就像存在这么一个+1运算符:current = prev + 1。注意数字超了要进位。

/**
 * Return an array of arrays of size *returnSize.
 * The sizes of the arrays are returned as *returnColumnSizes array.
 * Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().
 */
int combine_size(int n, int k) {
    if (k > n / 2) k = n - k;
    int s = 1;
    for (int i = 1; i <= k; i++) {
        s *= n--;
        s /= i;
    }
    return s;
}

int** combine(int n, int k, int* returnSize, int** returnColumnSizes){
    int sz = combine_size(n, k); // 先计算空间大小
    *returnSize = sz;

    // 申请空间
    int **ret = (int**)malloc(sizeof(int*) * sz);
    int *sizes = (int*)malloc(sizeof(int) * sz);
    for (int i = 0; i < sz; i++) {
        int *t = (int*)malloc(sizeof(int) * (k+1));
        t[k] = n + 1;
        ret[i] = t;
        sizes[i] = k;
    }
    *returnColumnSizes = sizes;

    // 初始第一个结果
    for (int i = 0; i < k; i++) {
        ret[0][i] = i + 1;
    }
    int p = 1;
    while (p < sz) {
        // ret[p] = ret[p-1] + 1
        int *pre = ret[p-1];
        int *cur = ret[p];

        // p指向行,而pos指向列,从最后一列算起
        // n=6 k=3, 1,2,3,6 -> 1,2,4,6
        int pos = k - 1;
        while (pos >= 0) {
            if (pre[pos] < pre[pos+1] - 1) {
                break;
            }
            pos--;
        }
        for (int i = 0; i < pos; i++) {
            cur[i] = pre[i];
        }
        cur[pos] = pre[pos] + 1;
        for (int i = pos + 1; i < k; i++) {
            cur[i] = cur[i-1] + 1;
        }
        p++;
    }
    return ret;
}

猜你喜欢

转载自blog.csdn.net/hbuxiaoshe/article/details/114704584