[NOIP2018]保卫王国(树形dp+倍增)

我的倍增解法吊打动态 \(dp\) 全局平衡二叉树没学过

先讲 \(NOIP\) 范围内的倍增解法。

我们先考虑只有一个点取/不取怎么做。

\(f[x][0/1]\) 表示取/不取 \(x\) 后,\(x\) 子树内的最小权覆盖集,\(g[x][0/1]\) 表示取/不取 \(x\) 后,除 \(x\) 子树的最小权覆盖集。那么这两个数组可以 \(O(n)\) 预处理出来。

\[f[x][0]+=f[y][1]\]

\[f[x][1]+=min(f[y][0],f[y][1])\]

\[g[y][0]=g[x][1]+f[x][1]-min(f[y][0],f[y][1])\]

\[g[y][1]=min(g[y][0],g[x][0]+f[x][0]-f[y][1])\]

那么我们可以 \(a\) 表示 \(x\) 结点的状态,那么 \(ans=f[x][a]+g[x][a]\)

现在我们考虑两个点取/不取怎么做。

我们发现每次影响的只有两点 \(lca\) 的子树内,所以考虑倍增。

我们用 \(anc\) 表示 \(x\) 结点上跳 \(2^i\) 层的祖先,那么 \(w[x][i][0/1][0/1]\) 表示 \(x\) 取/不取,\(anc\) 取/不取,\(anc\) 子树 \(-\) \(x\) 子树的最小权覆盖集,这个数组我们可以 \(O(n\log n)\) 预处理出来。

我们每次枚举 \(x\)\(anc\) 的四种状态,然后再枚举 \(x\) 结点上跳 \(2^{i-1}\) 层的祖先的状态,然后直接取个 \(min\) 就可以了。

for(int u=0;u<2;u++)
    for(int v=0;v<2;v++){
        w[i][j][u][v]=inf;
        for(int k=0;k<2;k++)
            w[i][j][u][v]=min(w[i][j][u][v],w[i][j-1][u][k]+w[tmp][j-1][k][v]);
    }

然后再倍增。我们每次想处理 \(w\) 数组一样一直将 \(x\) 结点和 \(y\) 结点向上跳,然后统计答案。

时间复杂度 \(O(n\log n)\)

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=100000+10;
const ll inf=0x7f7f7f7f7f7f;
int n,m,val[maxn],dep[maxn],fa[maxn][18],head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
ll f[maxn][2],g[maxn][2],w[maxn][18][2][2];char op[10];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}

inline void addedge(int x,int y){
    to[++tot]=y;
    nxt[tot]=head[x];
    head[x]=tot;
}

void dfs1(int x,int Fa){
    dep[x]=dep[Fa]+1;
    fa[x][0]=Fa;f[x][1]=val[x];
    for(int i=head[x],y;i;i=nxt[i]){
        y=to[i];
        if(y==Fa) continue;
        dfs1(y,x);
        f[x][0]+=f[y][1];
        f[x][1]+=min(f[y][0],f[y][1]);
    }
}

void dfs2(int x){
    for(int i=head[x],y;i;i=nxt[i]){
        y=to[i];
        if(y==fa[x][0]) continue;
        g[y][0]=g[x][1]+f[x][1]-min(f[y][0],f[y][1]);
        g[y][1]=min(g[y][0],g[x][0]+f[x][0]-f[y][1]);
        dfs2(y);
    }
}

ll solve(int a,int x,int b,int y){
    if(dep[x]<dep[y]) swap(x,y),swap(a,b);
    ll nx[2],ny[2],tx[2]={inf,inf},ty[2]={inf,inf};
    tx[a]=f[x][a];ty[b]=f[y][b];
    for(int i=17;i>=0;i--)
        if(dep[fa[x][i]]>=dep[y]){
            nx[0]=nx[1]=inf;
            for(int j=0;j<2;j++)
                for(int k=0;k<2;k++)
                    nx[j]=min(nx[j],tx[k]+w[x][i][k][j]);
            tx[0]=nx[0];tx[1]=nx[1];x=fa[x][i];
        }
    if(x==y) return tx[b]+g[y][b];
    for(int i=17;i>=0;i--)
        if(fa[x][i]!=fa[y][i]){
            nx[0]=nx[1]=ny[0]=ny[1]=inf;
            for(int j=0;j<2;j++)
                for(int k=0;k<2;k++){
                    nx[j]=min(nx[j],tx[k]+w[x][i][k][j]);
                    ny[j]=min(ny[j],ty[k]+w[y][i][k][j]);
                }
            tx[0]=nx[0];tx[1]=nx[1];x=fa[x][i];
            ty[0]=ny[0];ty[1]=ny[1];y=fa[y][i];
        }
    int lca=fa[x][0];
    ll ans1=f[lca][0]-f[x][1]-f[y][1]+tx[1]+ty[1]+g[lca][0];
    ll ans2=f[lca][1]-min(f[x][0],f[x][1])-min(f[y][0],f[y][1])+min(tx[0],tx[1])+min(ty[0],ty[1])+g[lca][1];
    return min(ans1,ans2);
}

int main()
{
    n=read(),m=read();scanf("%s",op);
    int a,x,b,y,tmp;
    for(int i=1;i<=n;i++) val[i]=read();
    for(int i=1;i<n;i++){
        x=read(),y=read();
        addedge(x,y);addedge(y,x);
    }
    dfs1(1,0);dfs2(1);
    for(int i=1;i<=n;i++){
        tmp=fa[i][0];
        w[i][0][0][0]=inf;
        w[i][0][0][1]=f[tmp][1]-min(f[i][0],f[i][1]);
        w[i][0][1][0]=f[tmp][0]-f[i][1];
        w[i][0][1][1]=w[i][0][0][1];
    }
    for(int j=1;j<=17;j++)
        for(int i=1;i<=n;i++){
            tmp=fa[i][j-1];
            if(fa[tmp][j-1]){
                fa[i][j]=fa[tmp][j-1];
                for(int u=0;u<2;u++)
                    for(int v=0;v<2;v++){
                        w[i][j][u][v]=inf;
                        for(int k=0;k<2;k++)
                            w[i][j][u][v]=min(w[i][j][u][v],w[i][j-1][u][k]+w[tmp][j-1][k][v]);
                    }
            }
        }
    while(m--){
        x=read(),a=read(),y=read(),b=read();
        if(!a&&!b&&(x==fa[y][0]||y==fa[x][0])){
            printf("-1\n");
            continue;
        }
        printf("%lld\n",solve(a,x,b,y));
    }
    return 0;
}

然后就是 \(O(8n\log^2 n)\) 的树剖+线段树维护矩阵的动态 \(dp\) 了。

发现取/不取我们可以用 \(inf\)\(-inf\) 代替,转化为最大权独立集来做。

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
#define lson (rt<<1)
#define rson (rt<<1|1)
using namespace std;
const int maxn=100000+10;
const int inf=1e10;
int n,m,v[maxn],val[maxn],dp[maxn][2],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,num,ans;
int top[maxn],ed[maxn],siz[maxn],son[maxn],fa[maxn],id[maxn],mp[maxn],tim;
char op[5];

struct Matrix{
    int mat[2][2];
    Matrix(){
        memset(mat,0,sizeof(mat));
    }
};
Matrix operator * (const Matrix &a,const Matrix &b){
    Matrix c;
    for(int i=0;i<2;i++)
        for(int j=0;j<2;j++)
            for(int k=0;k<2;k++)
                c.mat[i][j]=max(c.mat[i][j],a.mat[i][k]+b.mat[k][j]);
    return c;
}
Matrix a[maxn],sum[maxn<<2];

inline void read(int &x){
    x=0;bool f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    if(!f) x=-x;
}

void print(int x){
    if(x<0){putchar('-');x=-x;}
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

inline void add(int x,int y){
    to[++tot]=y;
    nxt[tot]=head[x];
    head[x]=tot;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    int maxson=-1;
    for(int i=head[x],y;i;i=nxt[i]){
        y=to[i];
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(siz[y]>maxson){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    mp[tim]=x;
    top[x]=topf;
    ed[topf]=x;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=nxt[i]){
        y=to[i];
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

void treedp(int x){
    dp[x][0]=0;dp[x][1]=val[x];
    for(int i=head[x],y;i;i=nxt[i]){
        y=to[i];
        if(y==fa[x]) continue;
        treedp(y);
        dp[x][0]+=max(dp[y][0],dp[y][1]);
        dp[x][1]+=dp[y][0];
    }
}

inline void pushup(int rt){
    sum[rt]=sum[lson]*sum[rson];
}

void build(int l,int r,int rt){
    if(l == r){
        int x=mp[l],b[2]={0,val[x]};
        for(int i=head[x],y;i;i=nxt[i]){
            y=to[i];
            if(y==fa[x]||y==son[x]) continue;
            b[0]+=max(dp[y][0],dp[y][1]);
            b[1]+=dp[y][0];
        }
        sum[rt].mat[0][0]=sum[rt].mat[0][1]=b[0];
        sum[rt].mat[1][0]=b[1];a[x]=sum[rt];
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,lson);
    build(mid+1,r,rson);
    pushup(rt);
}

void update(int x,int l,int r,int rt){
    if(l == r){
        sum[rt]=a[mp[l]];
        return ;
    }
    int mid=(l+r)>>1;
    if(x <= mid) update(x,l,mid,lson);
    else update(x,mid+1,r,rson);
    pushup(rt);
}

Matrix query(int L,int R,int l,int r,int rt){
    if(L <= l && r <= R){
        return sum[rt];
    }
    int mid=(l+r)>>1;
    if(L > mid) return query(L,R,mid+1,r,rson);
    if(R <= mid) return query(L,R,l,mid,lson);
    return query(L,R,l,mid,lson)*query(L,R,mid+1,r,rson);
}

void modify(int x,int y){
    Matrix u,v;
    a[x].mat[1][0]+=y-val[x];val[x]=y;
    while(x){
        u=query(id[top[x]],id[ed[top[x]]],1,n,1);
        update(id[x],1,n,1);
        v=query(id[top[x]],id[ed[top[x]]],1,n,1);
        x=fa[top[x]];
        if(x){
            a[x].mat[0][0]+=max(v.mat[0][0],v.mat[1][0])-max(u.mat[0][0],u.mat[1][0]);
            a[x].mat[0][1]=a[x].mat[0][0];
            a[x].mat[1][0]+=v.mat[0][0]-u.mat[0][0];
        }
    }
}

signed main()
{
    read(n),read(m);
    scanf("%s",op+1);
    int x,c,d,y;
    for(int i=1;i<=n;i++){
        read(val[i]);
        v[i]=val[i];num+=val[i];
    }
    for(int i=1;i<n;i++){
        read(x),read(y);
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    treedp(1);build(1,n,1);
    Matrix u;
    for(int i=1;i<=m;i++){
        read(x),read(c),read(y),read(d);
        if(c==0&&d==0&&(x==fa[y]||y==fa[x])){
            printf("-1\n");
            continue;
        }
        ans=num;
        if(c==0) ans+=inf-val[x];
        if(d==0) ans+=inf-val[y];
        modify(x,(c==0)?inf:-inf);
        modify(y,(d==0)?inf:-inf);
        u=query(id[1],id[ed[1]],1,n,1);
        ans-=max(u.mat[0][0],u.mat[1][0]);
        modify(x,v[x]);modify(y,v[y]);
        print(ans);putchar('\n');
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/owencodeisking/p/10270207.html