bzoj 2616 SPOJ PERIODNI——笛卡尔树+树形DP

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=2616

把相同高度的连续一段合成一个位置(可能不需要?),用前缀和维护宽度。

然后每次找区间里最低的那个点(ST表)作为根,递归左右孩子,构建笛卡尔树。

dp[ cr ][ j ] 表示在 cr 的子树里选择 j 个点的方案数。

自己本来写的是同时枚举 cr 这个点、ls 、rs 各贡献了多少个车,结果TLE。

看看题解,发现这样比较好(至多 \( n^3 \) ),就是先 \( dp[ cr ][ j ] = \sum dp[ ls ][ k ] * dp[ rs ][ j-k ] ),然后再枚举 cr 的贡献,形如 \( dp[ cr ][ j ] = \sum dp[ cr ][ j-k ] * C_{h}^{k} * C_{w-(j-k)}^{k} * k! \) ,其中 w 表示 cr 这个点的宽,h 表示 cr 这个点的高。

注意那里还要乘一个 \( k! \) 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int Mn(int a,int b){return a<b?a:b;}
int Mx(int a,int b){return a>b?a:b;}

const int N=505,K=10,M=1e6+5,mod=1e9+7;
int upt(int x){if(x>=mod)x-=mod;if(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,ht[N],jc[M],jcn[M],dp[N][N];
int bin[K],lg[N],st[N][K],s[N];
struct Dt{
  int bh,mx;
  Dt(int b=0,int m=0):bh(b),mx(m) {}
};
void init(int mx)
{
  jc[0]=1;for(int i=1;i<=mx;i++)jc[i]=(ll)jc[i-1]*i%mod;
  jcn[mx]=pw(jc[mx],mod-2);
  for(int i=mx-1;i>=0;i--)jcn[i]=(ll)jcn[i+1]*(i+1)%mod;

  for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1;
  bin[0]=1;for(int i=1;i<=lg[n];i++)bin[i]=bin[i-1]<<1;
  for(int i=1;i<=n;i++)st[i][0]=i;
  for(int t=1;t<=lg[n];t++)
    for(int i=1;i+bin[t]-1<=n;i++)
      {
    int u=st[i][t-1], v=st[i+bin[t-1]][t-1];
    if(ht[u]<ht[v])st[i][t]=u;
    else st[i][t]=v;
      }
}
int C(int n,int m)
{
  if(n<m)return 0;//
  return (ll)jc[n]*jcn[m]%mod*jcn[n-m]%mod;
}
int get(int l,int r)
{
  int d=lg[r-l+1];
  int u=st[l][d], v=st[r-bin[d]+1][d];
  if(ht[u]<ht[v])return u; else return v;
}
Dt solve(int l,int r,int pr)
{
  if(l>r)return Dt(0,0);
  int cr=get(l,r), w=s[r]-s[l-1], h=ht[cr]-pr;
  Dt Ls=solve(l,cr-1,ht[cr]); int ls=Ls.bh,m1=Ls.mx;
  Dt Rs=solve(cr+1,r,ht[cr]); int rs=Rs.bh,m2=Rs.mx;
  for(int i=1,l1=m1+m2;i<=l1;i++)
    for(int j=Mx(0,i-m2),l2=Mn(i,m1);j<=l2;j++)
      {
    dp[cr][i]=(dp[cr][i]+(ll)dp[ls][j]*dp[rs][i-j])%mod;
      }
  int lm=Mn(h,w), mx=m1+m2+Mn(h,w-m1-m2); dp[cr][0]=1;
  for(int i=mx;i;i--)
    for(int j=1,l1=Mn(i,lm);j<=l1;j++)
      {
    dp[cr][i]=(dp[cr][i]+
           (ll)dp[cr][i-j]*C(h,j)%mod*C(w-i+j,j)%mod*jc[j])%mod;
      }
  return Dt(cr,mx);
  /*
  printf("(%d,%d)cr=%d w=%d h=%d\n",l,r,cr,w,h);
  printf("  ls=%d m1=%d rs=%d m2=%d\n",ls,m1,rs,m2);
  int mx=m1+m2+Mn(h,w-m1-m2);
  printf("  mx=%d\n",mx);
  for(int i=1;i<=mx;i++)
    {
      printf("  i=%d\n",i);
      for(int j1=0,l1=Mn(m1,i);j1<=l1;j1++)
    for(int j2=Mx(0,i-j1-Mn(h,w-j1)),l2=Mn(i-j1,m2);j2<=l2;j2++)
      {
        int ret=(ll)dp[ls][j1]*dp[rs][j2]%mod;
        int k=i-j1-j2;
        ret=(ll)ret*C(h,k)%mod*C(w-j1-j2,k)%mod*jc[k]%mod;//jc[k]
        dp[cr][i]=upt(dp[cr][i]+ret);
        printf("    j1=%d j2=%d k=%d (dp[%d]=%d)\n"
           ,j1,j2,k,i,dp[cr][i]);
      }
      printf("  dp[%d]=%d\n",i,dp[cr][i]);
    }
  dp[cr][0]=1; return Dt(cr,mx);
  */
}
int main()
{
  int tn,tm,mx=0;
  scanf("%d%d",&tn,&tm);
  for(int i=1,d,lst=0;i<=tn;i++,lst=d)
    {
      scanf("%d",&d); mx=Mx(mx,d);
      if(d!=lst) ht[++n]=d, s[n]=s[n-1]+1;
      else s[n]++;
    }
  init(mx); dp[0][0]=1;///
  Dt Rt=solve(1,n,0);
  printf("%d\n",dp[Rt.bh][tm]);
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/Narh/p/10427195.html