题意: 给出 n 个数,从中选出 k 个数使它们的和最大,求满足条件的选法数量(答案对 1e9+7 取模)。
思路:
- 先对数组进行降序排序。
- 显然只有横跨第 k 个位置的那个值会对答案产生贡献:它在数组中一共出现 w 次,但前 k 个数里只需要其中的 v 个(v < w),此时方案数为 C(w, v);其余被选中的值都是整组全选,只有一种选法。比如样例中的 1 有两个,但是我们只选 1 个,因此答案为 C(2, 1) = 2。 (组合数部分复习之前博客)
代码实现:
#include<bits/stdc++.h>
// Token aliases used throughout this listing.
// `endl` is replaced by the newline character, so `cout << endl` writes '\n'
// without going through std::endl.
#define endl '\n'
#define null NULL
#define ll long long
// NOTE: from here on every `int` token is compiled as `long long`; this is why
// the entry point further down is spelled `signed main()`.
#define int long long
// With the macro above in effect this expands to pair<long long, long long>.
#define pii pair<int, int>
// Helper macros. All arguments are parenthesised so that expressions such as
// ls(a + b) or rp(i, n - 1) expand the way they read.

// Isolates the lowest set bit of x.
#define lowbit(x) ((x) & (-(x)))
// 2*x and 2*x+1 (presumably child indices of node x in a 1-based heap-style tree).
// The previous rs body `x<<1+1` parsed as `x << (1+1)`, i.e. 4*x.
#define ls(x) ((x) << 1)
#define rs(x) (((x) << 1) | 1)
// Zero-fill an array / set every byte of an array to num.
#define me(ar) memset(ar, 0, sizeof ar)
#define mem(ar,num) memset(ar, num, sizeof ar)
// Loop shorthands: rp  -> i = 0 .. n-1
//                  rep -> i = a .. n (ascending, inclusive)
//                  pre -> i = n .. a (descending, inclusive)
// rp previously had a ',' where the first ';' of the for-header belongs and
// could not compile when used.
#define rp(i, n) for(int i = 0; i < (n); i ++)
#define rep(i, a, n) for(int i = (a); i <= (n); i ++)
#define pre(i, n, a) for(int i = (n); i >= (a); i --)
// Expands to three statements: turn off iostream/stdio synchronisation and
// clear the tied stream of cin and of cout. Used as `IOS;` at the top of main.
#define IOS ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
// Four unit offsets as pairs: (1,0), (-1,0), (0,1), (0,-1). Not referenced by
// the code in this listing.
const int way[4][2] = {
{
1, 0}, {
-1, 0}, {
0, 1}, {
0, -1}};
using namespace std;
// 2^31 - 1.
const int inf = 0x7fffffff;
const double PI = acos(-1.0);
const double eps = 1e-6;
// Modulus for all factorial / binomial arithmetic below.
const ll mod = 1e9 + 7;
// Capacity of a[], fact[] and infact[].
const int N = 2e5 + 5;
const int M = 1005;
// Read by main: number of test cases, array length, number of elements to pick.
int t, n, k;
// Input values; main fills indices 1..n.
int a[N];
// value -> number of occurrences in the current test case (cleared per test).
map<int, int> mp;
// fact[i] = i! mod `mod`; infact[i] = modular inverse of fact[i]. Filled by init().
int fact[N], infact[N];
/**
 * Binary (square-and-multiply) modular exponentiation.
 *
 * @param a base; any sign, it is reduced into [0, p) before use
 * @param k exponent; values <= 0 are treated as 0
 * @param p modulus, expected to be positive
 * @return a^k mod p as a value in [0, p); in particular 0 when p == 1
 */
int qmi(int a, int k, int p)
{
    // Normalise the base first: every product below is then smaller than p*p,
    // instead of depending on how large a value the caller passed in.
    a %= p;
    if (a < 0) a += p;

    long long base = a;
    long long res = 1 % p;  // `1 % p` so that k == 0 with p == 1 yields 0
    while (k > 0) {         // `> 0` rather than `!= 0`: a negative k must not loop forever
        if (k & 1) res = res * base % p;
        base = base * base % p;
        k >>= 1;
    }
    return res;
}
void init()
{
fact[0] = infact[0] = 1;
for(int i = 1; i < N; i ++){
//表示i的阶乘
fact[i] = (ll)fact[i - 1] * i % mod;
//表示i的阶乘的逆元
infact[i] = (ll)infact[i - 1] * qmi(i, mod - 2, mod) % mod;
}
}
// Comparator for a descending sort: yields 1 when x is strictly greater than y
// (x should be placed before y), 0 otherwise.
int cmp(int x, int y){
    if (y < x) {
        return 1;
    }
    return 0;
}
/**
 * Reads t test cases; for each one reads n, k and n values, then prints the
 * number of ways (mod `mod`) to choose k values with the largest possible sum.
 *
 * After sorting in descending order, the array is walked one group of equal
 * values at a time. Groups that fit entirely inside the first k positions are
 * taken whole (one way). The group that crosses position k contributes
 * C(group size, slots still free).
 */
signed main()
{
    IOS;
    init();
    cin >> t;
    while (t--) {
        cin >> n >> k;
        mp.clear();
        for (int idx = 1; idx <= n; ++idx) {
            cin >> a[idx];
            ++mp[a[idx]];
        }

        // Taking every element leaves no choice at all.
        if (n == k) {
            cout << 1 << endl;
            continue;
        }

        sort(a + 1, a + n + 1, cmp);

        int ans = 1;
        int head = 1;  // first index of the group of equal values being examined
        while (head <= k) {
            int cnt = mp[a[head]];
            if (head + cnt - 1 <= k) {
                // Whole group lies within the first k positions: skip past it.
                head += cnt;
                continue;
            }
            // Group crosses position k: choose the remaining slots out of cnt copies.
            int v = k - head + 1;
            ans = ans * fact[cnt] * infact[v] % mod * infact[cnt - v] % mod;
            break;
        }
        cout << ans << endl;
    }
    return 0;
}
代码优化: 既然只有排序后横跨第 k 个位置的那个值会影响答案,那么直接看 a[k] 即可:若 a[k] == a[k+1],设前 k 个数中等于 a[k] 的个数为 v、a[k] 的总出现次数为 w,答案为 C(w, v);否则答案为 1。
#include<bits/stdc++.h>
// Token aliases used throughout this listing.
// `endl` is replaced by the newline character, so `cout << endl` writes '\n'
// without going through std::endl.
#define endl '\n'
#define null NULL
#define ll long long
// NOTE: from here on every `int` token is compiled as `long long`; this is why
// the entry point further down is spelled `signed main()`.
#define int long long
// With the macro above in effect this expands to pair<long long, long long>.
#define pii pair<int, int>
// Helper macros. All arguments are parenthesised so that expressions such as
// ls(a + b) or rp(i, n - 1) expand the way they read.

// Isolates the lowest set bit of x.
#define lowbit(x) ((x) & (-(x)))
// 2*x and 2*x+1 (presumably child indices of node x in a 1-based heap-style tree).
// The previous rs body `x<<1+1` parsed as `x << (1+1)`, i.e. 4*x.
#define ls(x) ((x) << 1)
#define rs(x) (((x) << 1) | 1)
// Zero-fill an array / set every byte of an array to num.
#define me(ar) memset(ar, 0, sizeof ar)
#define mem(ar,num) memset(ar, num, sizeof ar)
// Loop shorthands: rp  -> i = 0 .. n-1
//                  rep -> i = a .. n (ascending, inclusive)
//                  pre -> i = n .. a (descending, inclusive)
// rp previously had a ',' where the first ';' of the for-header belongs and
// could not compile when used.
#define rp(i, n) for(int i = 0; i < (n); i ++)
#define rep(i, a, n) for(int i = (a); i <= (n); i ++)
#define pre(i, n, a) for(int i = (n); i >= (a); i --)
// Expands to three statements: turn off iostream/stdio synchronisation and
// clear the tied stream of cin and of cout. Used as `IOS;` at the top of main.
#define IOS ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
// Four unit offsets as pairs: (1,0), (-1,0), (0,1), (0,-1). Not referenced by
// the code in this listing.
const int way[4][2] = {
{
1, 0}, {
-1, 0}, {
0, 1}, {
0, -1}};
using namespace std;
// 2^31 - 1.
const int inf = 0x7fffffff;
const double PI = acos(-1.0);
const double eps = 1e-6;
// Modulus for all factorial / binomial arithmetic below.
const ll mod = 1e9 + 7;
// Capacity of a[], fact[] and infact[].
const int N = 2e5 + 5;
const int M = 1005;
// Read by main: number of test cases, array length, number of elements to pick.
int t, n, k;
// Input values; main fills indices 1..n.
int a[N];
// value -> number of occurrences in the current test case (cleared per test).
map<int, int> mp;
// fact[i] = i! mod `mod`; infact[i] = modular inverse of fact[i]. Filled by init().
int fact[N], infact[N];
/**
 * Binary (square-and-multiply) modular exponentiation.
 *
 * @param a base; any sign, it is reduced into [0, p) before use
 * @param k exponent; values <= 0 are treated as 0
 * @param p modulus, expected to be positive
 * @return a^k mod p as a value in [0, p); in particular 0 when p == 1
 */
int qmi(int a, int k, int p)
{
    // Normalise the base first: every product below is then smaller than p*p,
    // instead of depending on how large a value the caller passed in.
    a %= p;
    if (a < 0) a += p;

    long long base = a;
    long long res = 1 % p;  // `1 % p` so that k == 0 with p == 1 yields 0
    while (k > 0) {         // `> 0` rather than `!= 0`: a negative k must not loop forever
        if (k & 1) res = res * base % p;
        base = base * base % p;
        k >>= 1;
    }
    return res;
}
void init()
{
fact[0] = infact[0] = 1;
for(int i = 1; i < N; i ++){
//表示i的阶乘
fact[i] = (ll)fact[i - 1] * i % mod;
//表示i的阶乘的逆元
infact[i] = (ll)infact[i - 1] * qmi(i, mod - 2, mod) % mod;
}
}
// Comparator for a descending sort: yields 1 when x is strictly greater than y
// (x should be placed before y), 0 otherwise.
int cmp(int x, int y){
    if (y < x) {
        return 1;
    }
    return 0;
}
/**
 * Reads t test cases; for each one reads n, k and n values, then prints the
 * number of ways (mod `mod`) to choose k values with the largest possible sum.
 *
 * After sorting in descending order only the value at position k matters.
 * If it also appears at position k+1, the answer is C(w, v), where w is the
 * total number of occurrences of that value and v is how many of them lie in
 * positions 1..k. Otherwise the selection is forced and the answer is 1.
 */
signed main()
{
    IOS;
    init();
    cin >> t;
    while (t--) {
        cin >> n >> k;
        mp.clear();
        for (int i = 1; i <= n; i++) {
            cin >> a[i];
            mp[a[i]]++;
        }

        // Taking every element leaves no choice at all (also guarantees that
        // a[k + 1] below is a real element of this test case).
        if (n == k) {
            cout << 1 << endl;
            continue;
        }

        sort(a + 1, a + n + 1, cmp);

        int ans = 1;
        if (a[k] == a[k + 1]) {
            // pos = last index before position k holding a different value,
            // or 0 when positions 1..k are all equal to a[k]. It is initialised
            // explicitly and the scan stops at index 1: the previous version
            // used the never-written a[0] as a sentinel, which left pos
            // uninitialised (and read a[-1]) whenever a[k] was 0.
            int pos = 0;
            for (int i = k; i >= 1; i--) {
                if (a[i] != a[k]) {
                    pos = i;
                    break;
                }
            }
            int v = k - pos, w = mp[a[k]];
            ans = fact[w] * infact[v] % mod * infact[w - v] % mod;
        }
        cout << ans << endl;
    }
    return 0;
}