洛谷P4719 【模板】动态dp 矩阵乘法+树链剖分+线段树

题目描述

给定一棵 n 个点的树,点带点权。
m 次操作,每次操作给定 x , y ,表示修改点 x 的权值为 y
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

题目分析

假如没有修改操作,这题怎么做呢?设 a x x 的点权, f ( x , 0 / 1 ) 表示 x 这个点不选/选的情况下,其子树中的最大权独立集权值大小。那么就有:
f ( x , 0 ) = m a x ( f ( s o n , 0 ) , f ( s o n , 1 ) )
f ( x , 1 ) = a x + f ( s o n , 0 )

现在有了修改操作,就很头疼。一想到动态修改,就想到以线段树为代表的数据结构,一想到在树上,就想到树链剖分。既然是动态DP,那么有一个比较“动感”的东西可以处理DP——
“矩阵乘法”!
为什么说它比较动感呢,是因为将一个DP写成矩乘的形式后,矩乘又是有结合率的,所以可以先算前一半再算后一半再合起来什么的。
考虑将矩乘中的乘法改成加法,加法改成取max操作。

matrix operator * (matrix a,matrix b) {
    matrix c;
    c.t[0][0]=max(a.t[0][0]+b.t[0][0],a.t[0][1]+b.t[1][0]);
    c.t[1][0]=max(a.t[1][0]+b.t[0][0],a.t[1][1]+b.t[1][0]);
    c.t[0][1]=max(a.t[0][0]+b.t[0][1],a.t[0][1]+b.t[1][1]);
    c.t[1][1]=max(a.t[1][0]+b.t[0][1],a.t[1][1]+b.t[1][1]);
    return c;
}

g ( x , 0 / 1 ) 表示对于一个点 x ,在不选/选它的情况下,其轻儿子对其造成的贡献。又因为重儿子的dfs序是 x 的dfs序+1,设 x 的dfs序为1,则有:

[ f ( i , 0 ) f ( i , 1 ) ] = [ g ( i , 0 ) g ( i , 0 ) g ( i , 1 ) 0 ] [ f ( i + 1 , 0 ) f ( i + 1 , 1 ) ]

然后每一次修改操作,我们修改若干条重链,每条重链修改一个点的转移矩阵即可。
一条重链顶端的答案,直接查询这条链对应的dfs序区间的转移矩阵全部相乘的结果。这样也可以获得1的答案。

代码

#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
    int q=0,w=1;char ch=' ';
    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch=='-') w=-1,ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q*w;
}
typedef long long LL;
const int N=100005;
int n,m,tot,tim;
int h[N],ne[N<<1],to[N<<1],dep[N],fa[N],sz[N];
int pos[N],repos[N],ed[N],top[N],son[N];
LL a[N],f[N][2];
void add(int x,int y) {to[++tot]=y,ne[tot]=h[x],h[x]=tot;}
void dfs1(int x,int las) {
    fa[x]=las,dep[x]=dep[las]+1,sz[x]=1;
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las) dfs1(to[i],x),sz[x]+=sz[to[i]];
}
void dfs2(int x,int las) {
    int bj=0,mx=0;
    pos[x]=++tim,repos[tim]=x;
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las&&sz[to[i]]>mx) mx=sz[to[i]],bj=to[i];
    if(!bj) {ed[top[x]]=pos[x];return;}
    son[x]=bj,top[bj]=top[x],dfs2(bj,x);
    for(RI i=h[x];i;i=ne[i])
        if(to[i]!=las&&to[i]!=bj) top[to[i]]=to[i],dfs2(to[i],x);
}
void dp(int x,int las) {
    f[x][1]=a[x];
    for(RI i=h[x];i;i=ne[i]) {
        if(to[i]==las) continue;
        dp(to[i],x);
        f[x][0]+=max(f[to[i]][0],f[to[i]][1]);
        f[x][1]+=f[to[i]][0];
    }
}

struct matrix{LL t[2][2];}tr[N<<2],QvQ[N];
matrix operator * (matrix a,matrix b) {
    matrix c;
    c.t[0][0]=max(a.t[0][0]+b.t[0][0],a.t[0][1]+b.t[1][0]);
    c.t[1][0]=max(a.t[1][0]+b.t[0][0],a.t[1][1]+b.t[1][0]);
    c.t[0][1]=max(a.t[0][0]+b.t[0][1],a.t[0][1]+b.t[1][1]);
    c.t[1][1]=max(a.t[1][0]+b.t[0][1],a.t[1][1]+b.t[1][1]);
    return c;
}
void build(int s,int t,int i) {
    if(s==t) {
        int x=repos[s];LL g0=0,g1=0;
        for(RI j=h[x];j;j=ne[j])
            if(to[j]!=fa[x]&&to[j]!=son[x])
                g0+=max(f[to[j]][0],f[to[j]][1]),g1+=f[to[j]][0];
        tr[i].t[0][0]=tr[i].t[0][1]=g0,tr[i].t[1][0]=g1+a[x];
        QvQ[s]=tr[i];
        return;
    }
    int mid=(s+t)>>1;
    build(s,mid,i<<1),build(mid+1,t,(i<<1)|1);
    tr[i]=tr[i<<1]*tr[(i<<1)|1];
}
void chan(int x,int s,int t,int i) {
    if(s==t) {tr[i]=QvQ[s];return;}
    int mid=(s+t)>>1;
    if(x<=mid) chan(x,s,mid,i<<1);
    else chan(x,mid+1,t,(i<<1)|1);
    tr[i]=tr[i<<1]*tr[(i<<1)|1];
}
matrix query(int l,int r,int s,int t,int i) {
    if(l<=s&&t<=r) return tr[i];
    int mid=(s+t)>>1;
    if(r<=mid) return query(l,r,s,mid,i<<1);
    if(mid+1<=l) return query(l,r,mid+1,t,(i<<1)|1);
    return query(l,r,s,mid,i<<1)*query(l,r,mid+1,t,(i<<1)|1);
}
matrix getans(int x) {return query(pos[x],ed[x],1,n,1);}//获得一条链顶端的dp值
void work(int x,LL num) {//修改的主体
    QvQ[pos[x]].t[1][0]+=num-a[x],a[x]=num;
    matrix k1,k2;
    while(x) {//往上跳
        k1=getans(top[x]),chan(pos[x],1,n,1),k2=getans(top[x]);
        x=fa[top[x]];if(!x) break;
        //修改一个点的转移矩阵
        QvQ[pos[x]].t[0][0]+=max(k2.t[0][0],k2.t[1][0])-max(k1.t[0][0],k1.t[1][0]);
        QvQ[pos[x]].t[0][1]=QvQ[pos[x]].t[0][0];
        QvQ[pos[x]].t[1][0]+=k2.t[0][0]-k1.t[0][0];
    }
}
int main()
{
    int x,y;
    n=read(),m=read();
    for(RI i=1;i<=n;++i) a[i]=read();
    for(RI i=1;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
    dfs1(1,0),top[1]=1,dfs2(1,0),dp(1,0);
    build(1,n,1);
    while(m--) {
        x=read(),y=read();
        work(x,y);matrix kl=getans(1);
        printf("%lld\n",max(kl.t[0][0],kl.t[1][0]));
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/litble/article/details/81038415