AcWing 93. 递归实现组合型枚举
题目描述
从 1~n 这 n 个整数中随机选出 m 个,输出所有可能的选择方案。
输入格式
两个整数 n,m ,在同一行用空格隔开。
输出格式
按照从小到大的顺序输出所有方案,每行1个。
首先,同一行内的数升序排列,相邻两个数用一个空格隔开。
其次,对于两个不同的行,对应下标的数一一比较,字典序较小的排在前面(例如1 3 5 7排在1 3 6 8前面)。
数据范围
n>0 ,
0≤m≤n ,
n+(n−m)≤25
输入样例:
5 3
输出样例:
1 2 3
1 2 4
1 2 5
1 3 4
1 3 5
1 4 5
2 3 4
2 3 5
2 4 5
3 4 5
思考题:如果要求使用非递归方法,该怎么做呢?
解法一,看成求部分全排列
类似于枚举全排列,只是全排列递归终点是数组中恰好有n个数,此时是数组中恰好有m个数结束
每次尝试在第cnt的位置上放未枚举过的数
剪枝优化:假设将当前位置之后的所有数全部加入数组中仍然不能达到m个数,则之后的情况不必枚举,直接结束。
C++代码实现
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 27;
bool st[N];//表示第i位是否选择过,true表示选择过
int num[N];//存储cnt个数
int n, m;
//回溯求解组合数,在1--beg-1中已经枚举了cnt个数
void dfs(int beg, int cnt) {
//剪枝
if(n-beg+1+cnt < m) return;
//递归结束,输出结果
if(cnt == m){//恰好m个输出
for(int i = 0; i < m; ++i) {
printf("%d ", num[i]);
}
puts("");
return ;
}
//未结束,继续
for(int i = beg; i <= n; ++i) {
if(!st[i]) {//枚举第cnt个数
//选择数i
st[i] = true;
num[cnt] = i;
dfs(i+1, cnt+1);
//不选择数i
st[i] = false;
//dfs(endNum+1, cnt);
}
}
}
int main() {
cin>>n>>m;
dfs(1, 0);
return 0;
}
解法二,看成求大小为m的子集
每个数只有两种选择,选或者不选
当集合中恰好有m个数就结束
C++代码实现
#include <iostream>
#include <cstdio>
using namespace std;
int n, m;
//回溯求解组合数,在0--beg-1中已经枚举了cnt个数
void dfs(int beg, int cnt, int st) {
//剪枝
if(n-beg+cnt < m) return;
//递归结束,输出结果
if(cnt == m){//恰好m个输出
for(int i = 0; i < n; ++i) {
if(st>>i&1) printf("%d ", i+1);
}
puts("");
return ;
}
//未结束,继续
//选择第beg个数
dfs(beg+1, cnt+1, st|(1<<beg));
//不选
dfs(beg+1, cnt, st);
}
int main() {
cin>>n>>m;
dfs(0, 0, 0);
return 0;
}
思考题:非递归做法
先枚举所有的集合,将元素个数为m的子集保存下来,根据字典序排序后输出
存在问题:按照什么顺序排序可以使按字典序输出呢?
①暴力比较做法:两个数从低位枚举,找到第一个不同的数,谁为1谁排在前面,枚举n位,时间复杂度O(n)
②位运算优化比较:考虑到寻找的第个不同的位,可以用异或求出所有不同的位,再对异或和用lowbit运算获取最低位1,相与不为0的即排在前面,时间复杂度O(1)
C++代码实现
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int low_bit(int x) {
return x&(-x);
}
//暴力比较做法:两个数从低位枚举,找到第一个不同的数,谁为1谁排在前面,枚举n位,时间复杂度O(n)
//位运算优化比较:考虑到寻找的第个不同的位,可以用异或求出所有不同的位,
//再对异或和用lowbit运算获取最低位1,相与不为0的即排在前面
bool cmp(int a, int b) {
int t = a^b;
t = low_bit(t);
return a&t;
}
vector<int> ans;
int main() {
int n, m;
scanf("%d%d", &n, &m);
//先找到所有长度为m的二进制集合存入ans中
for(int i = 1; i < 1<<n; ++i) {
int cnt = 0;
for(int j = 0; j < n; ++j)
if(i>>j&1) cnt++;
if(cnt == m) ans.push_back(i);
}
//对集合进行排序从而按字典序输出
//比较的原则是有低位的1排在前面
sort(ans.begin(), ans.end(), cmp);
for(int i = 0; i < ans.size(); ++i) {
for(int j = 0; j < n; ++j) {
if(ans[i]>>j&1) printf("%d ", j+1);
}
puts("");
}
return 0;
}