「Luogu P5368 [PKUSC2018]真实排名」

PKUSC签到题

题目大意

给出一个长度为 \(N\) 的序列,序列中有 \(K\) 个数会乘二,对于每个数计算在乘二后大于等于这个数的个数与乘二前没有发生变化的方案数.

分析

思路很清晰,可以将答案分为两个部分计算

当前位置的数没有乘二时

当前位置没有乘二,所以所有大于等于自己的元素是否乘二每有影响,如果一个数小于这个数的一半(不可以等于)那么这个数如果乘二也不会产生影响.于是可以计算出大于等于这个数的个数 \(+\) 小于这个数一半的数的个数.接着只需要通过组合数就可以计算你出来了.

当前位置的数乘二时

用一张图来解释起来比较简单:(下图序列为 \(a=[1,3,3,4,6,7]\))

在这里插入图片描述

假如开始时的排完序后的序列是上面这个样子,对于第二个数大于等于它的数的个数为 \(5\) (包括自己).现在需要将它的高度翻倍:

在这里插入图片描述
可以发现大于等于它的数的个数只剩下了一个,为了保证大于等于这个数的个数不变,在当前数乘二后必须将大于等于这个数小于这个数乘二的数都乘上二.

在这里插入图片描述

虽然最后的序列可能不再是有序的,但是对于这个数大于等于它的数的个数没有改变,可以发现原来小于它的数这时如果乘二并不会影响,大于这个数两倍的数乘二自然也不会产生影响啦,于是又可以通到组合数直接计算了.

组合数部分(会的可以跳过)

看到绝大多数的题解都不会涉及这部分内容,但是为了保证大多数人可以看懂,还是来写一下.

\(N\) 个数中选 \(M\) 个数的方案数: \(C_{M}^{N} = \frac{N!}{M!(N-M)!}\) 但是在本题中 \(N,M\) 都是 \(1 \times 10^5\) 级别自然不可以暴力计算,可以发现在这个公式由三个阶乘组成,于是自然会想到预处理阶乘,然后计算逆元.这样就可以做到 \(O(N\log_2N)\) 预处理阶乘和逆元,\(O(1)\) 计算,但是\(O(N\log_2N)\)还是有点慢了(虽然在本题中应该可以过),在一些 \(2 \times 10^6 \leq N\) (差不多)时这个时间复杂度就很容易TLE了,\(a\) 的逆元可以理解为 \(\frac{1}{a}\) 所以说 \(\frac{1}{i!}=\frac{1}{(i+1)!} \times i\),这样就可以先处理出 \(N!\) 的逆元,接着只需要 \(O(N)\) 的时间复杂度就可以计算出逆元了.

代码实现

#include<bits/stdc++.h>
#define REP(i,first,last) for(int i=first;i<=last;++i)
#define DOW(i,first,last) for(int i=first;i>=last;--i)
using namespace std;
const int maxN=3e5+7;
const int mod=998244353;
int Mod(long long a)//写一个Mod函数
{
    a%=mod;
    a+=mod;
    a%=mod;
    return a;
}
long long Inv(long long a,int b=mod-2)//计算一个数的逆元,其实就是一个快速幂
{
    long long result=1;
    while(b)
    {
        if(b&1)
        {
            result*=a;
            result%=mod;
        }
        a*=a;
        a%=mod;
        b/=2;
    }
    return result;
}
long long fac[maxN];//计算阶乘
long long inv[maxN];//计算逆元
int N,M;
int arr[maxN*2];
int sor[maxN*3];//用来离散化
int sum[maxN*3];
int tot=0;
int k;
map<int,int>Hash;//用来离散化
//因为比较懒,于是就先了一颗权值线段树来维护
struct SegmentTree
{
    int sum;
}sgt[maxN*4];
//线段树标准define
#define LSON (now<<1)
#define RSON (now<<1|1)
#define MIDDLE ((left+right)>>1)
#define LEFT LSON,left,MIDDLE
#define RIGHT RSON,MIDDLE+1,right
void PushUp(int now)//合并左右子树的值
{
    sgt[now].sum=sgt[LSON].sum+sgt[RSON].sum;
}
void Build(int now=1,int left=1,int right=tot)//建树
{
    if(left==right)//叶节点直接赋值
    {
        sgt[now].sum=sum[left];
        return;
    }
    Build(LEFT);
    Build(RIGHT);
    PushUp(now);//合并
}
//本题不需要修改
int QueryBigger(int num,int now=1,int left=1,int right=tot)
//计算大于等于的数的个数
{
    if(num<=left)
    {
        return sgt[now].sum;
    }
    if(right<num)
    {
        return 0;
    }
    return QueryBigger(num,LEFT)+QueryBigger(num,RIGHT);
}
int QuerySmaller(int num,int now=1,int left=1,int right=tot)
//计算小于等于的数的个数
{
    if(right<=num)
    {
        return sgt[now].sum;
    }
    if(num<left)
    {
        return 0;
    }
    return QuerySmaller(num,LEFT)+QuerySmaller(num,RIGHT);
}
int QuerySmaller_(int num,int now=1,int left=1,int right=tot)
//计算小于的数的个数
{
    if(right<num)
    {
        return sgt[now].sum;
    }
    if(num<=left)
    {
        return 0;
    }
    return QuerySmaller_(num,LEFT)+QuerySmaller_(num,RIGHT);
}
//用完就清空define,防止与其他地方冲突
#undef LSON
#undef RSON
#undef MIDDLE
#undef LEFT
#undef RIGHT
int main()
{
    scanf("%d%d",&N,&k);
    int cnt=0;
    fac[1]=1;
    REP(i,2,N)//预处理阶乘
    {
        fac[i]=Mod(fac[i-1]*i);
    }
    inv[N]=Inv(fac[N]);//计算N!的逆元
    DOW(i,N-1,0)
    {
        inv[i]=Mod(inv[i+1]*(i+1));//通过优化以后的方法O(N)计算所有阶乘的逆元
    }
    REP(i,1,N)
    {
        scanf("%d",&arr[i]);//对于每个数它与它的两倍和一半会在操作中用到
        sor[++cnt]=arr[i];
        sor[++cnt]=arr[i]/2;
        sor[++cnt]=arr[i]*2;
    }
    //离散化部分
    sort(sor+1,sor+1+cnt);
    sor[0]=-114514233;
    REP(i,1,cnt)
    {
        if(sor[i]!=sor[i-1])
        {
            Hash[sor[i]]=++tot;
        }
    }
    REP(i,1,N)//将数放入
    {
        sum[Hash[arr[i]]]++;
    }
    Build();//建树
    long long answer;//计算答案
    int all;//当前数不乘二时有多少数乘二不会造成影响
    int must;//计算如果当前数乘二时有多少数必须乘二
    int p,q;//计算时需要用到的一些变量
    REP(i,1,N)
    {
        all=QueryBigger(Hash[arr[i]]);//大于等于自身的数的个数肯定不会影响
        if(arr[i]&1)all+=QuerySmaller(Hash[arr[i]/2]);//这里需要根据这个数奇偶性查询乘二后仍然小于这个数的个数
        else all+=QuerySmaller_(Hash[arr[i]/2]);
        all-=1;//将自己减去
        answer=Mod(Mod(fac[all]*inv[k])*inv[all-k]);//通过组合数公式计算方案数
        if(arr[i]==0)must=1;else//需要特判一下0
        must=QuerySmaller_(Hash[arr[i]*2])-QuerySmaller_(Hash[arr[i]]);//必须乘二的数的个数为小于这个数乘二的数的个数-小于这个数的数的个数
        if(must<=k)//如果可以做到全部乘二
        {
            p=N-must;//乘二不会造成影响的数的个数
            q=k-must;//还可以有几个数乘二
            answer+=Mod(Mod(fac[p]*inv[q])*inv[p-q]);//计算方案数
        }
        printf("%lld\n",answer%mod);//输出答案
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Sxy_Limit/p/12327199.html