BZOJ1036:[ZJOI2008]树的统计Count(树链剖分)

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

题解:

树链剖分模板题。

#include<bits/stdc++.h>
using namespace std;
const int maxn=40010;
struct e {
    int u,v,next;
}edge[maxn*2];


//树链剖分部分
int head[maxn],tot;
int top[maxn];//top[v]表示v所在的重链的顶端结点
int fa[maxn];//父亲节点
int deep[maxn];//深度
int num[maxn];//num[v]表示以v为根的子树的节点数
int p[maxn];//p[v]表示v在线段树中的位置
int fp[maxn];//和p数组相反
int son[maxn];//重儿子
int pos;
void init () {
    tot=0;
    memset(head,-1,sizeof(head));
    pos=0;
    memset(son,-1,sizeof(son));
}
void addedge (int u,int v) {
    edge[tot].u=u;
    edge[tot].v=v;
    edge[tot].next=head[u];
    head[u]=tot++;
}
void dfs (int u,int pre,int d) {
    deep[u]=d;
    fa[u]=pre;
    num[u]=1;
    for (int i=head[u];i!=-1;i=edge[i].next) {
        int v=edge[i].v;
        if (v!=pre) {
            dfs(v,u,d+1);
            num[u]+=num[v];
            if (son[u]==-1||num[v]>num[son[u]])
                son[u]=v;
        }
    }
}
void getpos (int u,int sp) {
    top[u]=sp;
    p[u]=pos++;
    fp[p[u]]=u;
    if (son[u]==-1) return;
    getpos(son[u],sp);
    for (int i=head[u];i!=-1;i=edge[i].next) {
        int v=edge[i].v;
        if (v!=son[u]&&v!=fa[u]) getpos(v,v);
    }
}


//线段树部分
struct node {
    int l,r,sum,Max;
}segTree[maxn*3];
void push_up (int i) {
    segTree[i].sum=segTree[i<<1].sum+segTree[i<<1|1].sum;
    segTree[i].Max=max(segTree[i<<1].Max,segTree[i<<1|1].Max);
}
int s[maxn];
void build (int i,int l,int r) {
    segTree[i].l=l;
    segTree[i].r=r;
    if (l==r) {
        segTree[i].sum=segTree[i].Max=s[fp[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    push_up(i);
}
void update (int i,int k,int val) {
    if (segTree[i].l==k&&segTree[i].r==k) {
        segTree[i].sum=segTree[i].Max=val;
        return;
    }
    int mid=(segTree[i].l+segTree[i].r)>>1;
    if (k<=mid)
        update(i<<1,k,val);
    else
        update(i<<1|1,k,val);
    push_up(i);
}
int queryMax (int i,int l,int r) {
    if (segTree[i].l==l&&segTree[i].r==r)
        return segTree[i].Max;
    int mid=(segTree[i].l+segTree[i].r)>>1;
    if (r<=mid)
        return queryMax(i<<1,l,r);
    else if (l>mid)
        return queryMax(i<<1|1,l,r);
    else
        return max(queryMax(i<<1,l,mid),queryMax(i<<1|1,mid+1,r));
}
int querySum (int i,int l,int r) {
    if (segTree[i].l==l&&segTree[i].r==r)
        return segTree[i].sum;
    int mid=(segTree[i].l+segTree[i].r)>>1;
    if (r<=mid)
        return querySum(i<<1,l,r);
    else if (l>mid)
        return querySum(i<<1|1,l,r);
    else
        return querySum(i<<1,l,mid)+querySum(i<<1|1,mid+1,r);
}
int findMax (int u,int v) {
    //查询u->v路径上节点的最大值
    int f1=top[u];
    int f2=top[v];
    int tmp=-1e9;
    while (f1!=f2) {
        if (deep[f1]<deep[f2]) {
            swap(f1,f2);
            swap(u,v);
        }
        tmp=max(tmp,queryMax(1,p[f1],p[u]));
        u=fa[f1];
        f1=top[u];
    }
    if (deep[u]>deep[v])
        swap(u,v);
    return max(tmp,queryMax(1,p[u],p[v]));
}
int findSum (int u,int v) {
    //查询u->v路径上结点的权值和
    int f1=top[u];
    int f2=top[v];
    int tmp=0;
    while (f1!=f2) {
        if (deep[f1]<deep[f2]) {
            swap(f1,f2);
            swap(u,v);
        }
        tmp+=querySum(1,p[f1],p[u]);
        u=fa[f1];
        f1=top[u];
    }
    if (deep[u]>deep[v])
        swap(u,v);
    return tmp+querySum(1,p[u],p[v]);
}
int main () {
    int n,q;
    char op[20];
    int u,v;
    while (scanf("%d",&n)==1) {
        init();
        for (int i=1;i<n;i++) {
            scanf("%d%d",&u,&v);
            addedge(u,v);
            addedge(v,u);
        }
        for (int i=1;i<=n;i++) scanf("%d",&s[i]);
        dfs(1,0,0);
        getpos(1,1);
        build(1,0,pos-1);
        scanf("%d",&q);
        while (q--) {
            scanf("%s%d%d",op,&u,&v);
            if (op[0]=='C')
                update(1,p[u],v);
            else if (strcmp(op,"QMAX")==0)
                printf("%d\n",findMax(u,v));
            else
                printf("%d\n",findSum(u,v));
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zhanglichen/p/12805530.html
今日推荐