LOJ #3044. 「ZJOI2019」Minimax 搜索 动态DP+概率

神仙题! 

我的动态DP写的好丑...

因为写的是 LCT 版,所以要对每个 splay 维护的链记录链底编号(因为动态 DP 中链底要特判)  

code: 

#include <bits/stdc++.h>   
#define ll long long 
#define N 300008    
#define mod 998244353 
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;     

int si[N],an[N],n,W,det,CUR;              
int dep[N],hd[N],to[N<<1],nex[N<<1],leaf[N],fa[N],inv[N],edges;    

void add(int u,int v) 
{ 
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;   
}  

inline int qpow(int x,int y) 
{
    int tmp=1;   
    for(;y;y>>=1,x=(ll)x*x%mod)  
        if(y&1) tmp=(ll)tmp*x%mod;   
    return tmp;     
}
inline int INV(int x) { return qpow(x,mod-2); }

// 特殊数字处理    
struct spe
{   
    int x,y;   
    spe(int a=1,int b=0) { x=a,y=b; }                    
    spe operator*(const spe b) const 
    {     
        spe c;    
        c.x=(ll)x*b.x%mod;    
        c.y=y+b.y;       
        return c;    
    }      
    spe operator/(const spe b) const 
    {
        spe c;   
        c.x=(ll)x*INV(b.x)%mod;   
        c.y=y-b.y;      
        return c;    
    }   
    inline void modify(int v) 
    {
        if(v) x=v,y=0;      
        else x=1,y=1;     
    }      
    inline int val() { return y?0:x; }                      
}G[N],G1[N];          

// 矩阵乘法    
struct matrix
{
    int a[2][2];     
    matrix() { memset(a,0,sizeof(a)); }
    int *operator[](int x) { return a[x]; }                       
    matrix operator*(matrix b) const 
    {
        matrix c;     
        for(int i=0;i<2;++i) 
        {
            for(int j=0;j<2;++j) 
                for(int k=0;k<2;++k)    
                    (c[i][j]+=(ll)a[i][k]*b[k][j]%mod)%=mod; 
        }
        return c;     
    }
}tmp[N],F[N],F1[N],tmp1[N];           

// 动态树    
#define lson s[x].ch[0] 
#define rson s[x].ch[1]   

int sta[N];    
struct data { int ch[2],f,r; }s[N];      
inline int get(int x) { return s[s[x].f].ch[1]==x; }   
inline int isr(int x) { return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x; }

void pushup(int x) 
{   
    s[x].r=x;   
    // 叶节点特判   
    if(leaf[x])   
    {    
        F[x][0][0]=1,F[x][0][1]=F[x][1][1]=0;     
        F[x][1][0]=G[x].val();    
        F1[x][0][0]=1,F1[x][0][1]=F1[x][1][1]=0;   
        F1[x][1][0]=G1[x].val();    
        if(lson) 
        {
            F[x]=F[lson]*F[x];    
            F1[x]=F1[lson]*F1[x];    
        }
    }   
    else if(CUR==x) 
    {
        F[x][0][0]=1,F[x][0][1]=F[x][1][1]=0;
        F1[x][0][0]=1,F1[x][0][1]=F1[x][1][1]=0;         
        if(dep[x]%2==1) 
        {
            F[x][1][0]=(1-G[x].val()+mod)%mod;   
            F1[x][1][0]=G1[x].val();   
        }
        if(dep[x]%2==0) 
        {
            F[x][1][0]=G[x].val();   
            F1[x][1][0]=(1-G1[x].val()+mod)%mod;   
        }
        if(lson) F[x]=F[lson]*F[x],F1[x]=F1[lson]*F1[x];    
    }  
    else 
    { 
        F[x]=tmp[x];   
        F1[x]=tmp1[x];   
        if(lson) 
        {
            F[x]=F[lson]*F[x];     
            F1[x]=F1[lson]*F1[x];   
        }
        if(rson)
        {
            s[x].r=s[rson].r;    
            F[x]=F[x]*F[rson];     
            F1[x]=F1[x]*F1[rson];   
        }
    }
}          
void rotate(int x) 
{
    int old=s[x].f,fold=s[old].f,which=get(x); 
    if(!isr(old)) 
        s[fold].ch[s[fold].ch[1]==old]=x;      
    s[old].ch[which]=s[x].ch[which^1]; 
    if(s[old].ch[which]) 
        s[s[old].ch[which]].f=old;      
    s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;   
    pushup(old),pushup(x);    
} 
void splay(int x) 
{
    int u=x,fa; 
    for(;!isr(u);u=s[u].f);
    CUR=s[u].r;            
    for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))   
        if(s[fa].f!=u)   
            rotate(get(fa)==get(x)?fa:x);    
    CUR=0;    
}    
void Access(int x) 
{
    for(int y=0;x;y=x,x=s[x].f) 
    {
        splay(x);                   
        if(rson) 
        {     
            spe t,t2;    
            if(dep[x]%2==1) 
            {
                t.modify(1-F[rson][1][0]+mod);  
                t2.modify(F1[rson][1][0]); 
            }
            if(dep[x]%2==0) 
            {
                t.modify(F[rson][1][0]);    
                t2.modify(1-F1[rson][1][0]+mod);    
            }
            G[x]=G[x]*t;    
            G1[x]=G1[x]*t2;   
        }
        if(y) 
        {            
            spe t,t2;                 
            if(dep[x]%2==1) 
            {
                t.modify(1-F[y][1][0]+mod);   
                t2.modify(F1[y][1][0]); 
            }
            if(dep[x]%2==0) 
            {
                t.modify(F[y][1][0]);     
                t2.modify(1-F1[y][1][0]+mod); 
            }
            G[x]=G[x]/t;           
            G1[x]=G1[x]/t2;  
        }
        // 取 max        
        if(dep[x]%2==1) 
        {                
            tmp[x][1][1]=G[x].val();      
            tmp[x][1][0]=(1-G[x].val()+mod)%mod;       

            tmp1[x][1][1]=G1[x].val(); 
            tmp1[x][1][0]=0; 
        } 
        else 
        {   
            tmp[x][1][1]=G[x].val();    
            tmp[x][1][0]=0;  

            tmp1[x][1][1]=G1[x].val();  
            tmp1[x][1][0]=(1-G1[x].val()+mod)%mod;     
        }  
        rson=y;
        if(!y) CUR=x;    
        pushup(x),CUR=0;   
    }
}
#undef lson 
#undef rson  

void dfs(int x,int ff) 
{ 
    fa[x]=ff,dep[x]=dep[ff]+1,leaf[x]=1;         
    if(dep[x]%2==1) an[x]=-N; 
    else an[x]=N; 
    for(int i=hd[x];i;i=nex[i]) 
    {
        int y=to[i];   
        if(y==ff) continue;     
        leaf[x]=0,dfs(y,x),si[x]+=si[y];   
        if(dep[x]%2==1) 
            an[x]=max(an[x],an[y]);   
        else 
            an[x]=min(an[x],an[y]);   
    }     
    if(leaf[x]==1) si[x]=1,an[x]=x;    
} 

void dfs2(int x) 
{    
    s[x].f=fa[x];               
    G[x].modify(1);  
    G1[x].modify(1);          
    for(int i=hd[x];i;i=nex[i])  
    {
        int y=to[i]; 
        if(y==fa[x]) continue;   
        dfs2(y);        
        spe t,t2;    
        // 取 max   
        if(dep[x]%2==1) 
        {          
            t.modify((1-F[y][1][0]+mod)%mod);     
            t2.modify(F1[y][1][0]);  
            G[x]=G[x]*t;           
            G1[x]=G1[x]*t2;     
        } 
        // 取 min  
        else
        {         
            t.modify(F[y][1][0]);     
            t2.modify((1-F1[y][1][0]+mod)%mod);   
            G[x]=G[x]*t;   
            G1[x]=G1[x]*t2;    
        }
    }             
    if(leaf[x]) 
    {
        if(x==W) G[x].modify(0),G1[x].modify(0);  
        else 
        {
            if(x>W) 
            {    
                G[x].modify(1); 
                G1[x].modify(0);  
                if(x-det<W) G1[x].modify(inv[2]);     
            } 
            if(x<W) 
            {     
                G[x].modify(0);   
                if(x+det>W) G[x].modify(inv[2]);  
                G1[x].modify(1);      
            }
        }
    }  
    else
    { 
        tmp[x][0][0]=1;   
        tmp[x][0][1]=0;   
        tmp[x][1][1]=G[x].val();   

        tmp1[x][0][0]=1; 
        tmp1[x][0][1]=0; 
        tmp1[x][1][1]=G1[x].val(); 

        if(dep[x]%2==0) tmp[x][1][0]=0,tmp1[x][1][0]=(1-G1[x].val()+mod)%mod; 
        if(dep[x]%2==1) tmp[x][1][0]=(1-G[x].val()+mod)%mod,tmp1[x][1][0]=0;   
    }
    CUR=x,pushup(x),CUR=0;       
}

int Ans[N];  
int out[N];  

int main() 
{
    // setIO("input");      
    inv[0]=1;   
    for(int i=1;i<N;++i)  
        inv[i]=INV(i);   
    int L,R; 
    scanf("%d%d%d",&n,&L,&R);  
    for(int i=1;i<n;++i) 
    {
        int x,y;  
        scanf("%d%d",&x,&y);   
        add(x,y),add(y,x);  
    }
    dfs(1,0),W=an[1];  
    for(int i=W;i;i=fa[i]) --si[i];    
    det=1;   
    dfs2(1);      
    int co=qpow(2,si[1]);   
    for(det=1;det<n;++det) 
    {
        int p=W+1-det;         
        if(p>0&&leaf[p]&&p!=W) 
        {
            Access(p),splay(p);     
            G[p].modify(inv[2]);   
            pushup(p);     
        }
        int q=W-1+det;  
        if(q<=n&&leaf[q]&&q!=W) 
        {
            Access(q),splay(q);  
            G1[q].modify(inv[2]); 
            pushup(q); 
        }
        Access(1),splay(1);      

        Ans[det]=(ll)(F[1][1][0]+F1[1][1][0]-(ll)F[1][1][0]*F1[1][1][0]%mod+mod)%mod; 
        Ans[det]=(ll)Ans[det]*co%mod;    

        int cu=(det==1?co:0);   

        out[det]=(ll)(Ans[det]-Ans[det-1]+cu+mod)%mod;   

    }     
    out[n]=(ll)(co-Ans[n-1]+mod-1)%mod;  
    for(int i=L;i<=R;++i) printf("%d ",out[i]);    
    return 0; 
}

  

猜你喜欢

转载自www.cnblogs.com/guangheli/p/12799727.html
今日推荐