codeforces960G Bandit Blues

Bandit Blues

题目描述
传送门:http://codeforces.com/contest/960/problem/G

题解

首先考虑n,A,B都<1000的版本,可以暴力dp解决。
枚举最高点位置,显然除了最高点外从左边能看到的A-1个点都在最高点左边,从右边能看到的B-1个点都在最高点右边。
对于任意一边,f[i][j]表示总共i个点有j个能从左边被看到。
考虑转移,每次增加一个点,我们强制让它最小。这个最小的点要么放最左边要么放在中间并且不被看见()
得到dp式:f[i][j]=f[i-1][j-1]+f[i-1][j]*(i-1)
这东西其实就是第一类斯特林数的递推式。
ans=Σf[i][A]*f[n-i+1][B]*C(n-1,A-1)=f[n-1][A+B-2]*C(A+B-2,A-1)
这个等式大概就是,先拿掉最大的,然后选A+B-2个让它们被看见,然后选A-1个放左边。
可以证明两者等价。
然后我们需要快速地求f[n-1][A+B-2],显然不能用递推式暴力推。
根据无符号斯特林数的定义,x*(x+1)(x+2)(x+3)……(x+n-1)=Σf[n][i]*x^i
把那个阶乘项用分治fft乘起来,找到第A+B-2项就好了。

代码

#include<bits/stdc++.h>
#define N 300005
#define ll long long
#define mod 998244353
using namespace std;
int n,A,B,L,len,rev[N];
ll t[N],s1[N],s2[N],ans[N];
vector<ll>s;

ll Pow(ll a,ll b)
{
  ll res=1;
  while(b)
  {
    if(b&1)res=res*a%mod;
    a=a*a%mod;b>>=1;
  }
  return res;
}

void NTT(ll *x,int n,int inv)
{
  for(int i=0;i<n;i++)t[rev[i]]=x[i];
  for(int i=0;i<n;i++)x[i]=t[i];
  for(int i=1,d=2;i<=L;i++,d<<=1)
  {
    ll w0=Pow(3,(mod-1)/d),u,v,w;//原根是3
    for(int j=0,k;j<n;j+=d)
      for(k=j,w=1;k<j+(d>>1);k++,w=w*w0%mod)
      {
        u=x[k];v=x[k+(d>>1)]*w%mod;
        x[k]=(u+v)%mod;x[k+(d>>1)]=(u-v+mod)%mod;
      }
  }
  if(inv==-1)
  {
    ll y=Pow(n,mod-2);reverse(x+1,x+n);
    for(int i=0;i<n;i++)x[i]=x[i]*y%mod;
  }
}

void solve(int l,int r,vector<ll>&s)
{
  if(l==r){s.push_back(l);s.push_back(1);return;}
  int mid=l+r>>1,lx=0,rx=0;vector<ll>t1,t2;
  solve(l,mid,t1);solve(mid+1,r,t2);
  for(int i=0;i<t1.size();i++)s1[lx++]=t1[i];lx--;
  for(int i=0;i<t2.size();i++)s2[rx++]=t2[i];rx--;
  for(len=1,L=0;len<=lx+rx;len<<=1,L++);
  for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(L-1);
  for(int i=t1.size();i<len;i++)s1[i]=0;
  for(int i=t2.size();i<len;i++)s2[i]=0;
  NTT(s1,len,1);NTT(s2,len,1);
  for(int i=0;i<len;i++)ans[i]=s1[i]*s2[i]%mod;
  NTT(ans,len,-1);
  for(int i=0;i<=lx+rx;i++)s.push_back(ans[i]);
}

ll C(int n,int m)
{
  ll res=1;
  for(int i=m+1;i<=n;i++)res=res*i%mod;
  for(int i=2;i<=n-m;i++)res=res*Pow(i,mod-2)%mod;
  return res;
}

int main()
{
  scanf("%d%d%d",&n,&A,&B);
  if(!A||!B||A+B>n+1){printf("0\n");return 0;}
  if(n==1)
  {
    if(A==1&&B==1)printf("1\n");
    else printf("0\n");
    return 0;
  }
  solve(0,n-2,s);
  printf("%d\n",s[A+B-2]*C(A+B-2,A-1)%mod);
  return 0;
}

猜你喜欢

转载自blog.csdn.net/wcy_1122/article/details/79934714