Codeforces E. Advertising Agency (#697 Div.3) (思维 / 数学<组合数>)

传送门

题意:给出 n 个数,从中选出 k 个数,使这 k 个数的和最大,求这样的选法共有多少种(答案对 1e9+7 取模)。
(此处原为一张题面/样例截图,图片未能保留。)
思路:

  • 先对数组进行降序排序。
  • 显然,只有排序后跨越第 k 位的那个值会影响答案:它出现了多次,但只有一部分落在前 k 个数中。设它共出现 w 次,其中 v 次落在前 k 个数里,答案就是组合数 C(w, v)。比如样例中的 1 有两个,但我们只选 1 个,因此答案为 C(2, 1) = 2。(组合数部分可复习之前的博客)

代码实现:

#include<bits/stdc++.h>
// Competitive-programming shorthands used throughout this template.
// `endl` becomes a plain newline character, so `cout << endl` does not flush the stream.
#define endl '\n'
// `null` is spelled NULL; nothing in this program uses it.
#define null NULL
// Short alias for long long.
#define ll long long
// From here on every `int` token is really `long long`; this is why the entry
// point of this program is declared `signed main()` rather than `int main()`.
#define int long long
// Pair of `int`s; because `int` is a macro for long long in this file, this is a pair of long long.
#define pii pair<int, int>
// Lowest set bit of x. Argument and result are fully parenthesized so that
// lowbit(a + b) expands to ((a + b) & (-(a + b))) instead of (a + b & (-a + b)).
#define lowbit(x) ((x) & (-(x)))
// Child indices 2x and 2x + 1 of node x. The old `x<<1+1` parsed as x<<2,
// because + binds tighter than <<; both macros now parenthesize everything.
#define ls(x) ((x) << 1)
#define rs(x) (((x) << 1) + 1)
// Zero-fill / byte-fill an array (sizeof must see the array itself, not a pointer).
#define me(ar) memset((ar), 0, sizeof(ar))
#define mem(ar,num) memset((ar), (num), sizeof(ar))
// i over [0, n) ascending (was written with a comma instead of ';' and did not compile).
#define rp(i, n) for(int i = 0; i < (n); i ++)
// i over [a, n] ascending.
#define rep(i, a, n) for(int i = (a); i <= (n); i ++)
// i over [a, n] descending, starting at n.
#define pre(i, n, a) for(int i = (n); i >= (a); i --)
// Fast iostreams: detach from C stdio and untie cin/cout.
#define IOS ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
// The four axis-aligned unit offsets, one {first, second} pair per row:
// (+1, 0), (-1, 0), (0, +1), (0, -1). Template leftover; this program never reads it.
const int way[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
using namespace std;
// Template constants; of these, only `mod` and `N` are used by this solution.
const int inf = 2147483647;   // 0x7fffffff, the largest 32-bit signed value
const double PI = acos(-1.0);
const double eps = 1e-6;
const ll mod = 1000000007;    // every answer is reported modulo this prime
const int N = 200005;         // capacity of a[], fact[] and infact[]
const int M = 1005;

// Current input: t test cases; each has n values of which k must be chosen.
int t;
int n;
int k;
// a[1..n]: values of the current test case (main() sorts them in place).
int a[N];
// mp[v]: number of occurrences of value v in the current test case.
map<int, int> mp;
// fact[i] = i! and infact[i] = (i!)^-1, both modulo `mod`; filled once by init().
int fact[N];
int infact[N];

// Modular exponentiation by repeated squaring: returns a^k mod p, in [0, p).
//   a - base. It is reduced into [0, p) first, so negative or very large bases work and
//       a * a cannot overflow 64 bits as long as p < 2^31.
//   k - exponent, expected to be >= 0. A negative k no longer loops forever
//       (arithmetic shift keeps -1 at -1); the loop is skipped and 1 % p is returned.
//   p - modulus, expected to be >= 1.
// Uses O(log k) multiplications.
int qmi(int a, int k, int p)
{
    a %= p;
    if (a < 0) a += p;
    // 1 % p rather than 1, so that p == 1 correctly yields 0.
    int res = 1 % p;
    while (k > 0) {
        if (k & 1) res = static_cast<long long>(res) * a % p;
        a = static_cast<long long>(a) * a % p;
        k >>= 1;
    }
    return res;
}

void init()
{
    
    
    fact[0] = infact[0] = 1;
    for(int i = 1; i < N; i ++){
    
    
        //表示i的阶乘
        fact[i] = (ll)fact[i - 1] * i % mod;
        //表示i的阶乘的逆元
        infact[i] = (ll)infact[i - 1] * qmi(i, mod - 2, mod) % mod;
    }
}

// Strict "greater than" ordering for sort(): with it, a[1..n] ends up in
// non-increasing order. Returns 1 when x must be placed before y, otherwise 0.
int cmp(int x, int y)
{
    const bool x_goes_first = y < x;
    return x_goes_first;
}

// Reads t test cases. Each gives n values a[1..n] and k. For each test case, prints the
// number of ways to choose k of the values so that their sum is maximal, modulo `mod`.
//
// After a descending sort, every value larger than the k-th one must be taken and every
// smaller one skipped. Only the group of equal values that straddles position k leaves a
// choice: if it has w copies and v of them lie within the first k positions, the answer is
// C(w, v) = w! / (v! (w - v)!).
signed main()
{
    IOS;

    init();
    cin >> t;
    while(t --){
        cin >> n >> k;
        mp.clear();
        for(int i = 1; i <= n; i ++){
            cin >> a[i];
            mp[a[i]] ++;
        }
        // Taking every value leaves exactly one choice.
        if(n==k){
            cout << 1 << endl;
            continue;
        }
        sort(a+1, a+n+1, cmp);
        int ans = 1;
        for(int i = 1; i <= k; i ++){
            // The sort makes all w copies of a[i] occupy positions i .. i + w - 1.
            int w = mp[a[i]];
            if(i + w - 1 <= k){
                // The whole group lies in the first k positions: no choice, skip past it.
                i += w - 1;
                continue;
            }
            // This group straddles position k: exactly v of its w copies are taken.
            int v = k - i + 1;
            // Reduce after every multiplication so no intermediate exceeds (mod - 1)^2.
            ans = ans * fact[w] % mod * infact[v] % mod * infact[w - v] % mod;
            break;
        }
        cout << ans << endl;
    }

    return 0;
}

代码优化:既然只需考虑跨越第 k 位的那个值,我们直接锁定 a[k] 即可:若 a[k] != a[k+1],答案为 1;否则统计 a[k] 在前 k 个数中出现的次数 v 和总次数 w,答案为 C(w, v)。

#include<bits/stdc++.h>
// Competitive-programming shorthands used throughout this template.
// `endl` becomes a plain newline character, so `cout << endl` does not flush the stream.
#define endl '\n'
// `null` is spelled NULL; nothing in this program uses it.
#define null NULL
// Short alias for long long.
#define ll long long
// From here on every `int` token is really `long long`; this is why the entry
// point of this program is declared `signed main()` rather than `int main()`.
#define int long long
// Pair of `int`s; because `int` is a macro for long long in this file, this is a pair of long long.
#define pii pair<int, int>
// Lowest set bit of x. Argument and result are fully parenthesized so that
// lowbit(a + b) expands to ((a + b) & (-(a + b))) instead of (a + b & (-a + b)).
#define lowbit(x) ((x) & (-(x)))
// Child indices 2x and 2x + 1 of node x. The old `x<<1+1` parsed as x<<2,
// because + binds tighter than <<; both macros now parenthesize everything.
#define ls(x) ((x) << 1)
#define rs(x) (((x) << 1) + 1)
// Zero-fill / byte-fill an array (sizeof must see the array itself, not a pointer).
#define me(ar) memset((ar), 0, sizeof(ar))
#define mem(ar,num) memset((ar), (num), sizeof(ar))
// i over [0, n) ascending (was written with a comma instead of ';' and did not compile).
#define rp(i, n) for(int i = 0; i < (n); i ++)
// i over [a, n] ascending.
#define rep(i, a, n) for(int i = (a); i <= (n); i ++)
// i over [a, n] descending, starting at n.
#define pre(i, n, a) for(int i = (n); i >= (a); i --)
// Fast iostreams: detach from C stdio and untie cin/cout.
#define IOS ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
// The four axis-aligned unit offsets, one {first, second} pair per row:
// (+1, 0), (-1, 0), (0, +1), (0, -1). Template leftover; this program never reads it.
const int way[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
using namespace std;
// Template constants; of these, only `mod` and `N` are used by this solution.
const int inf = 2147483647;   // 0x7fffffff, the largest 32-bit signed value
const double PI = acos(-1.0);
const double eps = 1e-6;
const ll mod = 1000000007;    // every answer is reported modulo this prime
const int N = 200005;         // capacity of a[], fact[] and infact[]
const int M = 1005;

// Current input: t test cases; each has n values of which k must be chosen.
int t;
int n;
int k;
// a[1..n]: values of the current test case (main() sorts them in place).
int a[N];
// mp[v]: number of occurrences of value v in the current test case.
map<int, int> mp;
// fact[i] = i! and infact[i] = (i!)^-1, both modulo `mod`; filled once by init().
int fact[N];
int infact[N];

// Modular exponentiation by repeated squaring: returns a^k mod p, in [0, p).
//   a - base. It is reduced into [0, p) first, so negative or very large bases work and
//       a * a cannot overflow 64 bits as long as p < 2^31.
//   k - exponent, expected to be >= 0. A negative k no longer loops forever
//       (arithmetic shift keeps -1 at -1); the loop is skipped and 1 % p is returned.
//   p - modulus, expected to be >= 1.
// Uses O(log k) multiplications.
int qmi(int a, int k, int p)
{
    a %= p;
    if (a < 0) a += p;
    // 1 % p rather than 1, so that p == 1 correctly yields 0.
    int res = 1 % p;
    while (k > 0) {
        if (k & 1) res = static_cast<long long>(res) * a % p;
        a = static_cast<long long>(a) * a % p;
        k >>= 1;
    }
    return res;
}

void init()
{
    
    
    fact[0] = infact[0] = 1;
    for(int i = 1; i < N; i ++){
    
    
        //表示i的阶乘
        fact[i] = (ll)fact[i - 1] * i % mod;
        //表示i的阶乘的逆元
        infact[i] = (ll)infact[i - 1] * qmi(i, mod - 2, mod) % mod;
    }
}

// Strict "greater than" ordering for sort(): with it, a[1..n] ends up in
// non-increasing order. Returns 1 when x must be placed before y, otherwise 0.
int cmp(int x, int y)
{
    const bool x_goes_first = y < x;
    return x_goes_first;
}

// Reads t test cases. Each gives n values a[1..n] and k. For each test case, prints the
// number of ways to choose k of the values so that their sum is maximal, modulo `mod`.
//
// After a descending sort, only the value a[k] can leave a choice. If a[k + 1] differs from
// it, every copy of a[k] is already among the first k positions and the answer is 1.
// Otherwise, with w copies of a[k] in total and v of them within the first k positions,
// the answer is C(w, v) = w! / (v! (w - v)!).
signed main()
{
    IOS;

    init();
    cin >> t;
    while(t --){
        cin >> n >> k;
        mp.clear();
        for(int i = 1; i <= n; i ++){
            cin >> a[i];
            mp[a[i]] ++;
        }
        // Taking every value leaves one choice. This also keeps the a[k + 1] read below
        // inside the current test case (it would otherwise hit a stale a[n + 1]).
        if(n==k){
            cout << 1 << endl;
            continue;
        }
        sort(a+1, a+n+1, cmp);
        int ans = 1;
        if(a[k]==a[k+1]){
            // pos = index just before the run of copies of a[k] that contains position k,
            // or 0 if that run starts at index 1. Initializing pos and stopping at index 1
            // removes the dependence on the never-written a[0] acting as a sentinel.
            int pos = 0;
            for(int i = k; i >= 1; i --){
                if(a[i]!=a[k]){
                    pos = i;
                    break;
                }
            }
            int v = k-pos, w = mp[a[k]];
            // Reduce after every multiplication so no intermediate exceeds (mod - 1)^2.
            ans = ans * fact[w] % mod * infact[v] % mod * infact[w - v] % mod;
        }
        cout << ans << endl;
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/Satur9/article/details/113562042
今日推荐