保卫王国 题解(与csdn一致)

【NOIP2018】保卫王国 解题报告(来自我的csdn博客)

题解

这道题是一道动态dp的题。我们可以这样来考虑这道题:

  1. 想到树,想到树链剖分。树链剖分的dfs1过程中,我们进行动态规划(树上)。我们可以求出不改变树的情况下的最小花费。状态f的定义:
    \[ f[i][0]表示该节点不驻兵时其子树的最小花费\\f[i][1]表示该节点驻兵时他的子树的最小花费。 \]

    状态转移方程:

\[ \begin{cases} f [ u ] [ 0 ] = \sum f [ v ] [ 0 ] ( v \in u 的子树 ) \\ f [ u ] [ 1 ] = \sum \min ( f [ v ] [ 0 ] , f [ v ] [ 1 ] ) ( v \in u 的子树 ) \end{cases} \]

  1. 树链剖分一定在重链和轻儿子上面处理。但是,这个状态转移方程和每个儿子都有关系,于是,我们可以简化:
    \[ \begin{cases} \begin{array} { c } { f [ u ] [ 0 ] = f [ \operatorname { son } [ u ] ] [ 0 ] + g [ u ] [ 0 ] } \\ { f [ u ] [ 1 ] = \min ( f [ \operatorname { son } [ u ] ] [ 1 ] , f [ \operatorname { son } [ u ] ] [ 0 ] ) + g [ u ] [ 0 ] } \end{array} \\ \end{cases} \\ 因为我们知道f,所以我们要求g,变形得\\ \begin{cases} g [ u ] [ 0 ] = f [ u ] [ 0 ] - f [ \operatorname { son } [ u ] ] [ 0 ] \\ g [ u ] [ 1 ] = f [ u ] [ 1 ] - \min ( f [ \operatorname { son } [ u ] ] [ 1 ] , f [ \operatorname { son } [ u ] ] [ 0 ] ) \end{cases} \]

  2. 这道题的核心在于动态dp。

    普通的矩阵乘法是
    \[ \left[ \begin{array} { l l } { x 1 } & { x 2 } \\ { y 1 } & { y 2 } \end{array} \right] \times \left[ \begin{array} { c c } { x 3 } & { x 4 } \\ { y 3 } & { y 4 } \end{array} \right] = \left[ \begin{array} { c c } { x 1 \times x 3 + x 2 \times y 3 } & { x 1 \times x 4 + x 2 \times y 4 } \\ { y 1 \times x 3 + y 2 \times y 3 } & { y 1 \times x 4 + y 2 \times y 4 } \end{array} \right] \]
    我们转换为这样的运算:定义新运算#(程序中是*)(为什么?往后读)
    \[ \left[ \begin{array} { l l } { x 1 } & { x 2 } \\ { y 1 } & { y 2 } \end{array} \right] \# \left[ \begin{array} { c c } { x 3 } & { x 4 } \\ { y 3 } & { y 4 } \end{array} \right] = \left[ \begin{array} { c c } { \min ( x 1 + x 3 , x 2 + y 3 ) } & { \min ( x 1 + x 4 , x 2 + y 4 ) } \\ { \min ( y 1 + x 3 , y 2 + y 3 ) } & { \min ( y 1 + x 4 , y 2 + y 4 ) } \end{array} \right] \]
    那么这个状态转移方程可以通过定义得到递推和通项:
    \[ \left[ \begin{array} { l } { f [ u ] [ 0 ] } \\ { f [ u ] [ 1 ] } \end{array} \right] =\left[ \begin{array} { c c } { \infty } & { g [ u ] [ 0 ] } \\ { g [ u ] [ 1 ] } & { g [ u ] [ 1 ] } \end{array} \right] \# \left[ \begin{array} { l } { f [ \operatorname { son } [ u ] ] [ 0 ] } \\ { f [ \operatorname { son } [ u ] ] [ 1 ] } \end{array} \right] \\ \left[ \begin{array} { c c } { f [ u ] [ 0 ] } \\ { f [ u ] [ 1 ] } \end{array} \right] = \left[ \begin{array} { c c } { \infty } & { g [ u ] [ 0 ] } \\ { g [ u ] [ 1 ] } & { g [ u ] [ 1 ] } \end{array} \right] \# \left[ \begin{array} { c c } { \infty } & { g [ \operatorname { son } [ u ] ] [ 0 ] } \\ { g [ \operatorname { son } [ u ] ] [ 1 ] } & { g [ \operatorname { son } [ u ] ] [ 1 ] } \end{array} \right] \# \cdots \# \left[ \begin{array} { c c } { \infty } & { g [ v ] [ 0 ] } \\ { g [ v ] [ 1 ] } & { g [ v ] [ 1 ] } \end{array} \right] \# \left[ \begin{array} { l } { 0 } \\ { 0 } \end{array} \right] \]
    显然,u到son[u]一直到v(重链末端)在同一条重链上,因此seg编号是连续的。

    连续的区间操作,想到什么?线段树!

  3. 那么我们面临最后一个问题:怎么修改?

    我们可以先修改当前节点,再做树链剖分,更改链头的父亲节点。具体细节可以看程序。

    如果必须驻兵,那么节点权-inf

    如果不能驻兵,那么节点权+inf

    最后在修改回去即可。

    考虑
    \[ \left[ \begin{array} { l } { f [ 1 ] [ 0 ] } \\ { f [ 1 ] [ 1 ] } \end{array} \right] = \left[ \begin{array} { l l } { x 1 } & { x 2 } \\ { y 1 } & { y 2 } \end{array} \right] \# \left[ \begin{array} { l } { 0 } \\ { 0 } \end{array} \right] = \left[ \begin{array} { c } { \min ( x 1 , x 2 ) } \\ { \min ( y 1 , y 2 ) } \end{array} \right] \]
    最后输出结果是
    \[ min(f[1][0],f[1][1]) \]

代码#1

#include<bits/stdc++.h>
using namespace std;
 
typedef long long ll;
const int N=100005;
const ll inf=1e13;    // 这里inf不能太大 谨防爆ll 
 
struct matrix{        //矩阵 
    ll a[2][2];
    //下面是#运算,我重载了乘号 
    inline friend matrix operator *(const matrix& a,const matrix& b) {  
        matrix c;
        c.a[0][0]=min(a.a[0][0]+b.a[0][0],a.a[0][1]+b.a[1][0]); //常数优化,一个个写 
        c.a[0][1]=min(a.a[0][0]+b.a[0][1],a.a[0][1]+b.a[1][1]);
        c.a[1][0]=min(a.a[1][0]+b.a[0][0],a.a[1][1]+b.a[1][0]);
        c.a[1][1]=min(a.a[1][0]+b.a[0][1],a.a[1][1]+b.a[1][1]);
        return c;
    }
} T[N*4],val[N]; //T是线段树,val是修改后的矩阵,方便modify运算
 
int n,m,head[N],to[N*2],nxt[N*2],tot;  //邻接表 
int anc[N],dep[N],son[N],siz[N];       //树的基本特征 
int seg[N],rev[N],scnt,top[N],tail[N]; //rev:反差表,tail:链尾 
ll dp[N][2],p[N];                      //一定一定一定开ll 
 
inline void addedge(int x,int y) {     //领接表 
    nxt[++tot]=head[x],head[x]=tot,to[tot]=y;
    nxt[++tot]=head[y],head[y]=tot,to[tot]=x;
}
 
void build(int k,int l,int r) {        //线段树建树 
    if (l==r) {
        int u=rev[l];
        T[k].a[0][0]=1e18;             //建新节点 
        T[k].a[0][1]=dp[u][0]-dp[son[u]][1];
        T[k].a[1][0]=T[k].a[1][1]=dp[u][1]-min(dp[son[u]][0],dp[son[u]][1]);
        val[l]=T[k];
        return;
    }
    int mid=l+r>>1;
    build(k*2,l,mid);
    build(k*2+1,mid+1,r);
    T[k]=T[k*2]*T[k*2+1];              //push_up 
}
 
void modify(int k,int l,int r,int x) { //单点修改 
    if (l>=x && r<=x) return (void)(T[k]=val[l]);
    int mid=l+r>>1;
    if (x<=mid) modify(k*2,l,mid,x);
    else modify(k*2+1,mid+1,r,x);
    T[k]=T[k*2]*T[k*2+1];              //push_up 
}
 
matrix query(int k,int l,int r,int x,int y) { //区间查询 
    if (x<=l && r<=y) return T[k];
    int mid=l+r>>1;
    if (y<=mid) return query(k*2,l,mid,x,y);  //这样写更快 
    if (x>mid) return query(k*2+1,mid+1,r,x,y);
    return query(k*2,l,mid,x,y)*query(k*2+1,mid+1,r,x,y);
}
 
void dfs1(int u,int fa) { //树剖dfs1,顺便求树上dp 
    anc[u]=fa,dep[u]=dep[fa]+1,son[u]=0,siz[u]=1;
    dp[u][0]=0,dp[u][1]=p[u];
    for (int i=head[u];i;i=nxt[i]) {
        int v=to[i]; if (v==fa) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if (siz[v]>siz[son[u]]) son[u]=v;
        dp[u][0]+=dp[v][1];
        dp[u][1]+=min(dp[v][0],dp[v][1]); 
    }
}
 
void dfs2(int u,int tp) { //树剖dfs2 
    top[u]=tp,seg[u]=++scnt,rev[scnt]=u,tail[tp]=scnt;
    if (son[u]!=0) dfs2(son[u],tp);
    for (int i=head[u];i;i=nxt[i]) {
        int v=to[i];
        if (v==anc[u] || v==son[u]) continue;
        dfs2(v,v);
    }
}
 
void solve(int u,ll x) {     //这里x是delta 
    p[u]+=x;                 //p更改 
    val[seg[u]].a[1][0]+=x;  //对val进行相应更改 
    val[seg[u]].a[1][1]+=x;  //勿忘seg 
    while (u) {              //树剖惯用写法 
        matrix old=query(1,1,n,seg[top[u]],tail[top[u]]); //询问旧 
        modify(1,1,n,seg[u]);                             //更改 
        matrix nw=query(1,1,n,seg[top[u]],tail[top[u]]);  //询问新 
        u=anc[top[u]];                                    //u往上跳 
        if (!u) break;                                    //到0就结束 
        ll oldf,newf;
        oldf=min(old.a[1][1],old.a[1][0]);                //传递给祖先 
        newf=min(nw.a[1][1],nw.a[1][0]);
        val[seg[u]].a[0][1]+=newf-oldf;
        oldf=min(oldf,min(old.a[0][0],old.a[0][1]));
        newf=min(newf,min(nw.a[0][0],nw.a[0][1]));
        val[seg[u]].a[1][0]+=newf-oldf;
        val[seg[u]].a[1][1]+=newf-oldf;
    }
}
 
int main() {
    freopen("defense.in","r",stdin);
    freopen("defense.out","w",stdout);
    char sos[8];                                  //sos 
    scanf("%d%d%s",&n,&m,sos);
    for (int i=1;i<=n;++i) scanf("%lld",&p[i]);
    for (int i=1,u,v;i<n;++i) {
        scanf("%d%d",&u,&v); addedge(u,v);
    }
    dfs1(1,0); 
    dfs2(1,1); 
    build(1,1,n);  //勿忘建树 惨痛教训 
    for (int i=0;i<m;++i) {  //进行询问 
        int a,x,b,y; scanf("%d%d%d%d",&a,&x,&b,&y);
        if (dep[a]<dep[b]) swap(a,b),swap(x,y);
        if (anc[a]==b && y==0 && x==0) {  //如果无法达成 
            printf("-1\n");
            continue;
        }
        solve(a,x?-inf:inf);              //如果一定驻兵,设成-inf 
        solve(b,y?-inf:inf);              //如果不能驻兵,设成 inf 
        matrix temp=query(1,1,n,1,tail[1]); //先询问 
        ll ans=min(min(temp.a[0][0],temp.a[0][1]),min(temp.a[1][1],temp.a[1][0]));
        //四者去最小值 
        if (x) ans+=inf; //因为-了inf 
        if (y) ans+=inf;
        printf("%lld\n",ans);
        solve(a,x?inf:-inf);  //改回去 
        solve(b,y?inf:-inf);
    }
    return 0;
}

那么这份代码的运行结果怎么样呢?我们来看。

在这里插入图片描述

9340ms,不是很理想。那么有没有优化的方法呢?

优化1

树链剖分的线段树修改是不是只在单条重链上面的?是的!所以我们使用传统线段树会产生巨大浪费(不仅是空间上面的,也是时间上面的)。

那么,我们想到一个极(gǒu)妙(pì)的招:有几条重链就建几棵线段树。那么同学们可能会问:啊?!那不会爆空间吗?!

其实上述问题是存在解决方法的。解决方法是主席树的一种简化版本。核心程序如下,同学们可以自己理解一下。

注:指针和面向对象不熟的同学可以自己翻一翻

struct segnode{
    segnode *lc,*rc;
    int l,r;
    matrix g;
    inline void pushup() {g=lc->g*rc->g;}
    void modify(int x) {
        if (x>r || x<l) return;
        if (l==r && r==x) return (void)(g=val[rev[x]]);
        int mid=(l+r)>>1;
        if (x<=mid) lc->modify(x); 
        else rc->modify(x);
        pushup();
    }
    matrix query(int x,int y) {
        if (x<=l && r<=y) return g;
        int mid=l+r>>1;
        if (y<=mid) return lc->query(x,y);
        if (x>mid) return rc->query(x,y);
        return lc->query(x,y)*rc->query(x,y);
    }
} T[N<<1],*pool=T,*tv[N];
//T是线段树数组,存储了线段树的所有信息
//pool指向哪一个位置没有用过(build用)
//tv存储的是每棵树的根节点

segnode* build(int l,int r){
    segnode* k=++pool;
    k->lc=k->rc=NULL;
    k->l=l,k->r=r;
    if (l==r) {
        int u=rev[l];
        k->g.a[0][0]=1e18;
        k->g.a[0][1]=dp[u][0]-dp[son[u]][1];
        k->g.a[1][1]=k->g.a[1][0]=dp[u][1]-min(dp[son[u]][0],dp[son[u]][1]);
        val[u]=k->g;
        return k;
    }
    int mid=l+r>>1;
    k->lc=build(l,mid);
    k->rc=build(mid+1,r);
    k->pushup();
    return k;
}

这个程序用到了指针的思想,为了 \(z \bar u angb\bar i\) 我把它改成了全指针(虽然慢了点)

struct segnode{
    segnode *lc,*rc;
    int l,r;
    matrix g;
    inline void pushup() {g=lc->g*rc->g;}
    segnode(int L,int R){
        lc=rc=NULL;
        l=L,r=R;
        if (l==r) {
           int u=rev[l];
           g.a[0][0]=1e18;
           g.a[0][1]=dp[u][0]-dp[son[u]][1];
           g.a[1][1]=g.a[1][0]=dp[u][1]-min(dp[son[u]][0],dp[son[u]][1]);
           val[u]=g;
        }else{
           int mid=l+r>>1;
           lc=new segnode(l,mid);
           rc=new segnode(mid+1,r);
           pushup();
        }
    }
    void modify(int x) {
        if (l==r && r==x) return (void)(g=val[rev[x]]);
        int mid=(l+r)>>1;
        if (x<=mid) lc->modify(x); 
        else rc->modify(x);
        pushup();
    }
    matrix query(int x,int y) {
        if (x<=l && r<=y) return g;
        int mid=l+r>>1;
        if (y<=mid) return lc->query(x,y);
        if (x>mid) return rc->query(x,y);
        return lc->query(x,y)*rc->query(x,y);
    }
} *tv[N];

优化代码1

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
const int N=100005;
const ll inf=1e13;    // 这里inf不能太大 谨防爆ll 

struct matrix{        //矩阵 
    ll a[2][2];
    下面是#运算,我重载了乘号 
    inline friend matrix operator *(const matrix& a,const matrix& b) {  
        matrix c;
        c.a[0][0]=min(a.a[0][0]+b.a[0][0],a.a[0][1]+b.a[1][0]); //常数优化,一个个写 
        c.a[0][1]=min(a.a[0][0]+b.a[0][1],a.a[0][1]+b.a[1][1]);
        c.a[1][0]=min(a.a[1][0]+b.a[0][0],a.a[1][1]+b.a[1][0]);
        c.a[1][1]=min(a.a[1][0]+b.a[0][1],a.a[1][1]+b.a[1][1]);
        return c;
    }
} val[N];

int n,m,head[N],to[N*2],nxt[N*2],tot;  //邻接表
int anc[N],dep[N],son[N],siz[N];       //树的基本特征 
int seg[N],rev[N],scnt,top[N],tail[N]; //rev:反差表,tail:链尾
ll dp[N][2],p[N];                      //一定一定一定开ll 

struct segnode{
    segnode *lc,*rc;
    int l,r;
    matrix g;
    inline void pushup() {g=lc->g*rc->g;}
    segnode(int L,int R){
        lc=rc=NULL;
        l=L,r=R;
        if (l==r) {
           int u=rev[l];
           g.a[0][0]=1e18;
           g.a[0][1]=dp[u][0]-dp[son[u]][1];
           g.a[1][1]=g.a[1][0]=dp[u][1]-min(dp[son[u]][0],dp[son[u]][1]);
           val[u]=g;
        }else{
           int mid=l+r>>1;
           lc=new segnode(l,mid);
           rc=new segnode(mid+1,r);
           pushup();
        }
    }
    void modify(int x) {
        if (l==r && r==x) return (void)(g=val[rev[x]]);
        int mid=(l+r)>>1;
        if (x<=mid) lc->modify(x); 
        else rc->modify(x);
        pushup();
    }
    matrix query(int x,int y) {
        if (x<=l && r<=y) return g;
        int mid=l+r>>1;
        if (y<=mid) return lc->query(x,y);
        if (x>mid) return rc->query(x,y);
        return lc->query(x,y)*rc->query(x,y);
    }
} *tv[N];

inline void addedge(int x,int y) {     //领接表 
    nxt[++tot]=head[x],head[x]=tot,to[tot]=y;
    nxt[++tot]=head[y],head[y]=tot,to[tot]=x;
}

void dfs1(int u,int fa) { //树剖dfs1,顺便求树上dp 
    anc[u]=fa,dep[u]=dep[fa]+1,son[u]=0,siz[u]=1;
    dp[u][0]=0,dp[u][1]=p[u];
    for (int i=head[u];i;i=nxt[i]) {
        int v=to[i]; if (v==fa) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if (siz[v]>siz[son[u]]) son[u]=v;
        dp[u][0]+=dp[v][1];
        dp[u][1]+=min(dp[v][0],dp[v][1]); 
    }
}

void dfs2(int u,int tp) { //树剖dfs2 
    top[u]=tp,seg[u]=++scnt,rev[scnt]=u;
    if (son[u]!=0) dfs2(son[u],tp);
    else tv[tp]=new segnode(seg[tp],scnt);
    for (int i=head[u];i;i=nxt[i]) {
        int v=to[i];
        if (v==anc[u] || v==son[u]) continue;
        dfs2(v,v);
    }
}

void solve(int u,ll x) {     //这里x是delta 
    p[u]+=x;                 //p更改 
    val[u].a[1][0]+=x;  //对val进行相应更改 
    val[u].a[1][1]+=x;  //勿忘seg 
    while (u) {              //树剖惯用写法 
        segnode* k=tv[top[u]];
        matrix old=k->query(k->l,k->r); //询问旧 
        k->modify(seg[u]);                             //更改 
        matrix nw=k->query(k->l,k->r);  //询问新 
        u=anc[top[u]];                                    //u往上跳 
        if (!u) break;                                    //到0就结束 
        ll oldf,newf;
        oldf=min(old.a[1][1],old.a[1][0]);                //传递给祖先 
        newf=min(nw.a[1][1],nw.a[1][0]);
        val[u].a[0][1]+=newf-oldf;
        oldf=min(oldf,min(old.a[0][0],old.a[0][1]));
        newf=min(newf,min(nw.a[0][0],nw.a[0][1]));
        val[u].a[1][0]+=newf-oldf;
        val[u].a[1][1]+=newf-oldf;
    }
}

int main() {
    //freopen("defense.in","r",stdin);
    //freopen("defense.out","w",stdout);
    char sos[8];                                  //sos 
    scanf("%d%d%s",&n,&m,sos);
    for (int i=1;i<=n;++i) scanf("%lld",&p[i]);
    for (int i=1,u,v;i<n;++i) {
        scanf("%d%d",&u,&v); addedge(u,v);
    }
    dfs1(1,0); 
    dfs2(1,1); 
    for (int i=0;i<m;++i) {
        int a,x,b,y; scanf("%d%d%d%d",&a,&x,&b,&y);
        if (dep[a]<dep[b]) swap(a,b),swap(x,y);
        if (anc[a]==b && y==0 && x==0) {
            printf("-1\n");
            continue;
        }
        solve(a,x?-inf:inf);            
        solve(b,y?-inf:inf);             
        matrix temp=tv[1]->query(tv[1]->l,tv[1]->r);
        ll ans=min(min(temp.a[0][0],temp.a[0][1]),min(temp.a[1][1],temp.a[1][0]));
        if (x) ans+=inf;
        if (y) ans+=inf;
        printf("%lld\n",ans);
        solve(a,x?inf:-inf);
        solve(b,y?inf:-inf);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/JerryZheng2005/p/12181359.html