bzoj5250 九省联考 秘密袭击【树上背包+拉格朗日插值+线段树合并】

解题思路:

第一个想法是枚举第 k 大的值,把大于的记为1,小于的记为0,问题就转化为树上联通块大小等于 k 的个数。

稍微转化一下,我们统计树上联通块第 k 大大等于 i 的个数,不妨记为 a i ,那么

a n s = i = 1 W i ( a i a i + 1 )

而因为这样计算每个大等于 i 的方案在 a 1 a i 中都会被算一次,恰好被计算了 i 次,所以

a n s = i = 1 W a i

a i 也可以等价转化成求联通块中大等于 i 大等于 k 个的方案数。
f u , i , j 表示表示 u 的子树中包含 u 且大等于 i 的个数为 j 的联通块个数,那么直接做树上背包就是 O ( n 3 ) 的复杂度,卡常据说可以过。

将第三维用生成函数的形式表达,即: F u , i = j = 0 n f u , i , j x j
考虑普通的树上背包dp用多项式乘法形式表现,即有转移:

F u , i = ( Π ( F s o n , i + 1 ) ) { 1 i > d u x i d u

再设 G u , i u 子树内 F v , i 之和,那么答案就是 i = 1 W G 1 , i k 次项之后所有系数之和。
如果暴力维护多项式乘法是 O ( n 4 ) O ( n 3 l o g n ) 的。

注意到答案最后是个 n 次多项式,考虑插 n + 1 x 的值进去,转化成点值来算,这样乘法就是 O ( 1 ) 的,而且转移 F 时连续一段 i 乘的是相同的值,所以可以用线段树维护。

转移流程大概如下:

  • ( f , g ) = ( 1 , 0 ) //初始化,线段树整体覆盖
  • ( f , g ) = ( f ( f v + 1 ) , g + g v ) //线段树合并
  • ( f , g ) = ( f x 0 , g ) //线段树区间修改
  • ( f , g ) = ( f , g + f ) //线段树整体修改

最后 f , g 会变为 a f + b , c f + d 的形式,类似线段树维护乘法和加法,我们维护 a , b , c , d 四个值的转移。

最后 i = 1 W G 1 , i 也是个多项式,所以我们只需要把点值对应项相加后求一遍系数就行了。

时间复杂度是 O ( n 2 l o g W ) 的,然而没有比暴力快……汗

#include<bits/stdc++.h>
using namespace std;
int getint()
{
    int i=0,f=1;char c;
    for(c=getchar();c!='-'&&(c<'0'||c>'9');c=getchar());
    if(c=='-')f=-1,c=getchar();
    for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
    return i*f;
}
const int N=2005,mod=64123;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
int Pow(int x,int y)
{
    int res=1;
    for(;y;y>>=1,x=mul(x,x))
        if(y&1)res=mul(res,x);
    return res;
}
int n,k,W,d[N];vector<int>e[N];
int tot,pool_top,pool[N*50],rt[N],lc[N*50],rc[N*50];
int inv[N],yc[N],c[N],g[N],f[N];
struct data
{
    int a,b,c,d;
    data():a(1),b(0),c(0),d(0){}
    data(int _a,int _b,int _c,int _d):a(_a),b(_b),c(_c),d(_d){}
    inline friend data operator * (const data &a,const data &b)
    {
        return data(mul(a.a,b.a),
                    add(mul(b.a,a.b),b.b),
                    add(mul(b.c,a.a),a.c),
                    add(mul(b.c,a.b),add(a.d,b.d)));
    }
}tag[N*50];
inline int newnode()
{
    int x=pool_top?pool[pool_top--]:++tot;
    tag[x]=data(),lc[x]=rc[x]=0;
    return x;
}
void del(int &x)
{
    if(!x)return;
    del(lc[x]),del(rc[x]);
    pool[++pool_top]=x,x=0;
}
void pushdown(int x)
{
    if(!lc[x])lc[x]=newnode();
    if(!rc[x])rc[x]=newnode();
    tag[lc[x]]=tag[lc[x]]*tag[x];
    tag[rc[x]]=tag[rc[x]]*tag[x];
    tag[x]=data();
}
void modify(int &k,int l,int r,int x,int y,data tg)
{
    if(!k)k=newnode();
    if(x<=l&&r<=y){tag[k]=tag[k]*tg;return;}
    pushdown(k);int mid=l+r>>1;
    if(x<=mid)modify(lc[k],l,mid,x,y,tg);
    if(y>mid)modify(rc[k],mid+1,r,x,y,tg);
}
void merge(int &x,int &y)
{
    if(!x)swap(x,y);
    if(!y)return;
    if(!lc[x]&&!rc[x])swap(x,y);
    if(!lc[y]&&!rc[y])
    {
        tag[x].a=mul(tag[x].a,tag[y].b);
        tag[x].b=mul(tag[x].b,tag[y].b);
        tag[x].d=add(tag[x].d,tag[y].d);
        return;
    }
    pushdown(x),pushdown(y);
    merge(lc[x],lc[y]),merge(rc[x],rc[y]);
}
void Init(int u,int fa,int x0)
{
    modify(rt[u],1,W,1,W,data(0,1,0,0));
    for(int i=0;i<e[u].size();i++)
    {
        int v=e[u][i];
        if(v==fa)continue;
        Init(v,u,x0);
        merge(rt[u],rt[v]);
        del(rt[v]);
    }
    if(d[u])modify(rt[u],1,W,1,d[u],data(x0,0,0,0));
    modify(rt[u],1,W,1,W,data(1,0,1,0));
    modify(rt[u],1,W,1,W,data(1,1,0,0));
}
void get_val(int k,int l,int r,int i)
{
    if(l==r){yc[i]=add(yc[i],tag[k].d);return;}
    pushdown(k);
    int mid=l+r>>1;
    get_val(lc[k],l,mid,i),get_val(rc[k],mid+1,r,i);
}
void div(int *a,int *b,int x0)
{
    for(int i=0;i<=n+1;i++)c[i]=a[i];
    for(int i=n+1;i>=1;i--)
    {
        b[i-1]=c[i];
        c[i-1]=add(c[i-1],mul(c[i],x0)),c[i]=0;
    }
}
int get_ans()
{
    int ans=0;
    for(int i=1;i<=n+1;i++)inv[i]=Pow(i,mod-2);
    g[0]=1;
    for(int i=1;i<=n+1;i++)
        for(int j=n+1;j>=0;j--)
        {
            g[j]=mul(g[j],mod-i);
            if(j)g[j]=add(g[j],g[j-1]);
        }
    for(int i=1;i<=n+1;i++)
    {
        div(g,f,i);int res=0;
        for(int j=k;j<=n;j++)res=add(res,f[j]);
        for(int j=1;j<=n+1;j++) if(i!=j) 
            res=(i>j?mul(res,inv[i-j]):mul(res,mod-inv[j-i]));
        res=mul(res,yc[i]),ans=add(ans,res);    
    }       
    return ans;
}
int main()
{
    //freopen("lx.in","r",stdin);
    n=getint(),k=getint(),W=getint();
    for(int i=1;i<=n;i++)d[i]=getint();
    for(int i=1;i<n;i++)
    {
        int x=getint(),y=getint();
        e[x].push_back(y),e[y].push_back(x);
    }
    for(int i=1;i<=n+1;i++)
    {
         Init(1,0,i);
         get_val(rt[1],1,W,i);
         del(rt[1]);
    }
    cout<<get_ans();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/cdsszjj/article/details/79993647