HYSBZ 1036 树的统计Count(树链剖分)题解

思路:

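树链剖分模板题:先用两次 dfs 求出每个结点的父结点、深度、子树大小和重儿子,并沿重链按 dfs 序编号,使同一条重链上的结点编号连续;再用线段树维护区间和与区间最大值。修改即线段树单点更新;查询路径时,每次让链顶较深的一端跳到其链顶的父结点,并累计这一段区间的结果,直到两点位于同一条重链,最后再查询两点之间的区间。(作者原话:不知道说什么...我连模板都不会用)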

代码:

#include<map>
#include<ctime>
#include<cmath>
#include<stack>
#include<queue>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;

// Capacity of every per-vertex array.
const int maxn = 60000+5;
// Large sentinel value; its negation serves as "minus infinity" for max queries.
const int INF = 0x3f3f3f3f;

int n;      // number of vertices in the current test case
int m;      // not referenced by the code in this file
int tol;    // number of slots of node[] handed out so far
int cnt;    // last position number assigned by dfs2

// One directed edge u -> v; `next` is the index of the following edge in
// u's adjacency list, or -1 at the end of the list.
struct Node{
    int u;
    int v;
    int next;
};
Node node[maxn << 1];   // two directed slots for every undirected edge

int head[maxn];     // index of the first edge of each vertex, -1 if none
int fa[maxn];       // parent vertex
int siz[maxn];      // subtree size
int dep[maxn];      // depth
int son[maxn];      // heavy child (0 when the vertex has no child)
int id[maxn];       // position in dfs2 visiting order
int top[maxn];      // topmost vertex of the heavy chain the vertex lies on

// Appends the directed edge u -> v to the edge pool and makes it the new
// front of u's adjacency list. Callers add both directions for an
// undirected edge.
void addnode(int u,int v){
    const int slot = tol++;         // take the next free slot in node[]
    node[slot].next = head[u];      // old front becomes the successor
    node[slot].v = v;
    node[slot].u = u;
    head[u] = slot;
}
// First traversal: fills fa, dep, siz and son for u and everything below it.
//   u      - vertex being visited
//   father - vertex we arrived from (skipped when walking u's edges)
//   depth  - depth to record for u
void dfs1(int u,int father,int depth){
    fa[u] = father;
    dep[u] = depth;
    siz[u] = 1;
    son[u] = 0;     // 0 means "no heavy child yet"; relies on siz[0] being 0
    for(int e = head[u]; e != -1; e = node[e].next){
        const int child = node[e].v;
        // Edges are stored in both directions, so the way back up must be skipped.
        if(child != father){
            dfs1(child,u,depth + 1);
            siz[u] += siz[child];
            // Keep the child with the strictly largest subtree as heavy child.
            if(siz[son[u]] < siz[child]){
                son[u] = child;
            }
        }
    }
}
// Second traversal: assigns id[] in visiting order and top[] for u and
// everything below it.
//   u  - vertex being visited
//   tp - topmost vertex of the heavy chain u belongs to
void dfs2(int u,int tp){
    id[u] = ++cnt;
    top[u] = tp;
    const int heavy = son[u];
    if(heavy == 0){
        return;     // no child at all: the chain ends here
    }
    // Visit the heavy child first so the whole chain gets consecutive ids.
    dfs2(heavy,tp);
    for(int e = head[u]; e != -1; e = node[e].next){
        const int child = node[e].v;
        // Every other child starts a new chain with itself on top.
        if(child != fa[u] && child != heavy){
            dfs2(child,child);
        }
    }
}

/*
Segment tree.
sum[rt] / mx[rt] hold the sum / maximum of the values covered by tree node rt;
the children of rt are rt << 1 and rt << 1 | 1, and the root is node 1.
*/
int sum[maxn<<2],mx[maxn<<2];

// Recomputes the aggregates of tree node rt from its two children.
void push_up(int rt){
    const int lc = rt << 1;
    const int rc = lc | 1;
    mx[rt] = max(mx[lc],mx[rc]);
    sum[rt] = sum[lc] + sum[rc];
}
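/* Unused: recursive construction from a val[] array (val is not declared in
   this file). The tree is filled through update() instead.
void build(int l,int r,int rt){
    sum[rt] = 0;
    mx[rt] = -INF;  // note: -INF is the neutral element for max
    if(l == r){
        sum[rt] = mx[rt] = val[l];
        return;
    }
    int m = (l + r) >> 1;
    build(l,m,rt << 1);
    build(m + 1,r,rt << 1 | 1);
    push_up(rt);
}*/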
// Point assignment: sets position pos to val and refreshes every tree node
// on the path from that leaf up to rt.
//   pos  - position to overwrite
//   val  - new value
//   l, r - range covered by tree node rt
void update(int pos,int val,int l,int r,int rt){
    if(l != r){
        const int mid = (l + r) >> 1;
        if(pos > mid){
            update(pos,val,mid + 1,r,rt << 1 | 1);
        }
        else{
            update(pos,val,l,mid,rt << 1);
        }
        push_up(rt);
        return;
    }
    // Leaf: a single value is both its own sum and its own maximum.
    mx[rt] = val;
    sum[rt] = val;
}
// Returns the sum of the values at positions L..R.
//   L, R - queried range
//   l, r - range covered by tree node rt
int queryS(int L,int R,int l,int r,int rt){
    if(l >= L && r <= R){
        return sum[rt];     // node lies completely inside the query
    }
    const int mid = (l + r) >> 1;
    // A half that does not intersect the query contributes nothing.
    const int leftPart = (mid >= L) ? queryS(L,R,l,mid,rt << 1) : 0;
    const int rightPart = (mid < R) ? queryS(L,R,mid + 1,r,rt << 1 | 1) : 0;
    return leftPart + rightPart;
}
// Returns the maximum of the values at positions L..R.
//   L, R - queried range
//   l, r - range covered by tree node rt
int queryM(int L,int R,int l,int r,int rt){
    if(l >= L && r <= R){
        return mx[rt];      // node lies completely inside the query
    }
    const int mid = (l + r) >> 1;
    int best = -INF;        // stays -INF only if neither half is visited
    if(mid >= L){
        const int leftBest = queryM(L,R,l,mid,rt << 1);
        if(leftBest > best) best = leftBest;
    }
    if(mid < R){
        const int rightBest = queryM(L,R,mid + 1,r,rt << 1 | 1);
        if(rightBest > best) best = rightBest;
    }
    return best;
}

//////////////////////////////////////////
// Returns the sum of the values on the tree path between u and v
// (both endpoints included).
int FindS(int u,int v){
    int total = 0;
    // While the endpoints are on different chains, consume the chain whose
    // top is deeper and continue from the parent of that top.
    for(; top[u] != top[v]; u = fa[top[u]]){
        if(dep[top[v]] > dep[top[u]]){
            swap(u,v);
        }
        total += queryS(id[top[u]],id[u],1,n,1);
    }
    // Same chain now: the shallower endpoint has the smaller id. When u == v
    // both bounds coincide and the query covers that single position.
    const bool uIsDeeper = dep[u] > dep[v];
    const int lo = uIsDeeper ? id[v] : id[u];
    const int hi = uIsDeeper ? id[u] : id[v];
    return total + queryS(lo,hi,1,n,1);
}
// Returns the maximum of the values on the tree path between u and v
// (both endpoints included).
int FindM(int u,int v){
    int best = -INF;
    // While the endpoints are on different chains, consume the chain whose
    // top is deeper and continue from the parent of that top.
    for(; top[u] != top[v]; u = fa[top[u]]){
        if(dep[top[v]] > dep[top[u]]){
            swap(u,v);
        }
        best = max(best,queryM(id[top[u]],id[u],1,n,1));
    }
    // Same chain now: order the endpoints so that u is the shallower one,
    // which also has the smaller id. u == v needs no special handling.
    if(dep[v] < dep[u]){
        swap(u,v);
    }
    return max(best,queryM(id[u],id[v],1,n,1));
}
///////////////////////////////////////////////////

// Resets all global state so that a new test case starts from scratch.
void init() {
    // Adjacency lists: every list empty (-1), edge pool rewound.
    memset(head,-1,sizeof(head));
    memset(node,0,sizeof(node));
    tol = 0;

    // Decomposition data written by dfs1 / dfs2.
    cnt = 0;
    memset(fa,0,sizeof(fa));
    memset(dep,0,sizeof(dep));
    memset(siz,0,sizeof(siz));
    memset(son,0,sizeof(son));
    memset(top,0,sizeof(top));
    memset(id,0,sizeof(id));

    // Segment tree aggregates.
    memset(sum,0,sizeof(sum));
    memset(mx,0,sizeof(mx));
}
int main(){
    int q;
    while(~scanf("%d",&n)){
        int u,v;
        init();
        for(int i = 1;i < n;i++){
            scanf("%d%d",&u,&v);
            addnode(u,v);
            addnode(v,u);
        }
        dfs1(1,0,1);
        dfs2(1,1);
        for(int i = 1;i <= n;i++){
            int w;
            scanf("%d",&w);
            update(id[i],w,1,n,1);
        }
        scanf("%d",&q);
        char s[20];
        while(q--){
            scanf("%s%d%d",s,&u,&v);
            if(s[0] == 'C'){
                update(id[u],v,1,n,1);
            }
            else if(s[1] == 'S'){
                printf("%d\n",FindS(u,v));
            }
            else{
                printf("%d\n",FindM(u,v));
            }
        }
    }
    return 0;
}
J

猜你喜欢

转载自blog.csdn.net/qq_14938523/article/details/81146124
今日推荐