2020暑期牛客多校训练营第七场(C)A National Pandemic(树链剖分)

A National Pandemic

原题请看这里

题目描述:

国家可以表示为 n n 个节点 n 1 n-1 条边的图。 F ( x ) F(x) 表示节点 x x 的疫情严重性。有以下三种修改/查询:

  1. 疫情在 x x 节点爆发,严重性为 x x ,对于每个节点 y y F ( y ) F(y) 增加 w d i s t ( x , y ) w-dist(x,y) ,其中 d i s t ( x , y ) dist(x,y) 表示节点 x x 到节点 y y 路径上边的数量。
  2. 将节点 x x F ( x ) F(x) 更新为 m i n ( F ( x ) , 0 ) min(F(x), 0)
  3. 询问节点 x x F ( x ) F(x)

输入描述:

有多个测试用例。 输入的第一行包含一个整数 T ( 1 T 5 ) T(1 \leq T \leq 5) ,表示测试用例的数量。
对于每个测试用例,第一行包含两个整数 n m ( 1 n m 5 × 1 0 4 ) n,m(1 \leq n,m \leq 5 \times 10 ^ 4) ,代表城市的数量以及事件和查询的数量。 以下 n 1 n-1 行描述了该国家/地区的所有路径,每条路径均包含两个整数 x y ( 1 x y n ) x,y(1 \leq x,y \leq n) ,代表城市 x x y y 之间的道路。 以下 m m 行描述了所有事件,每个事件均以整数 o p t ( 1 o p t 3 ) \mathit {opt}(1 \leq \mathit {opt} \leq 3) 开始,并且如果 o p t \mathit{opt}

  1. 在同一行中将有两个整数 x w ( 1 x n 0 w 10000 ) x,w(1 \leq x \leq n,0 \leq w \leq 10000) 。 这是指上面描述中的事件1。
  2. 在同一行中将有一个整数 x ( 1 x n ) x(1 \leq x \leq n) 。 这是指事件2。
  3. 在同一行中将有一个整数 x ( 1 x n ) x(1 \leq x \leq n) 。 这是指您需要答复的查询。

输出描述:

每个查询输出一个整数。

样例输入:

1
5 6
1 2
1 3
2 4
2 5
1 1 5
3 4
2 1
1 2 7
3 3
3 1

样例输出:

3
9
6

思路:

首先,我们对每一个操作进行分析:
o p t = 1 : opt=1:
在树上求距离可以用 l c a lca ,于是我们把 d i s t dist 这个函数化开,设 x , y , l c a x,y,lca 的深度分别为 d e p [ x ] , d e p [ y ] , d e p [ l c a ] dep[x],dep[y],dep[lca]:
w d i s t ( x , y ) w-dist(x,y)
= w ( d e p [ x ] + d e p [ y ] 2 d e p [ l c a ] ) =w-(dep[x]+dep[y]-2dep[lca])
= w d e p [ x ] d e p [ y ] + 2 d e p [ l c a ] =w-dep[x]-dep[y]+2dep[lca]
由此我们发现:当 o p t = = 1 opt==1 w d e p [ x ] w-dep[x] 是固定的,而对于每一个节点, d e p [ y ] dep[y] 也是固定的,所以我们设 A + = w d e p [ x ] , B + + A+=w-dep[x],B++
A A 表示所有关于 x x 的结果, B B 表示 d e p [ y ] dep[y] 的次数
所以对于每次操作,我们都可以查询一下之前所有的结果,即:
f ( y ) = A B d e p [ y ] + 2 ( d e p [ l c a ] ) f(y)=A-B*dep[y]+2\sum(dep[lca])
o p t = 2 : opt=2:
对于这个操作,我们需要考虑正负,所以我们需要储存一下 m i n ( 0 , f ( y ) ) min(0,f(y)) .
所以我们开一个数组存一下每次y的消除结果:
f f y + = m i n ( 0 , f ( y ) ) f ( y ) ff_y+=min(0,f(y))-f(y)
之后我们算答案只要加上 f f y ff_y 就行了:
f ( y ) = A B d e p [ y ] + 2 ( d e p [ l c a ] ) + f f y f(y)=A-B*dep[y]+2\sum(dep[lca])+ff_y
o p t = = 3 opt==3
输出结果即可

A C AC C o d e Code :

代码是我队友写的,好丑好丑呜呜呜呜

#include<iostream>
#include<cstdio>
#include<string.h>
#define ll long long
#define I1 i<<1
#define I2 i<<1|1
using namespace std;
const int MAXN=1e5+5;
int v[MAXN],h[MAXN],fa[MAXN],id[MAXN],ld[MAXN],rd[MAXN],siz[MAXN],dep[MAXN],son[MAXN],top[MAXN],vis[MAXN],ff[MAXN],n,m,t,x,y,op;
ll cnt,tot,A,B;
struct node{int to,next;}a[MAXN<<1];
struct Tree{int l,r,lz;ll sum;}sgt[MAXN<<2],e[MAXN<<2];
void add(int x,int y){
	a[++cnt].to=y;
	a[cnt].next=h[x];
	h[x]=cnt;
}
void p(int x){
    siz[x]=1;
	son[x]=0;
    for(int i=h[x];~i;i=a[i].next){
        int nxt=a[i].to;
        if(nxt==fa[x])continue;
        fa[nxt]=x;
		dep[nxt]=dep[x]+1;
		p(nxt);
		siz[x]+=siz[nxt];
        if(siz[nxt]>siz[son[x]])
			son[x]=nxt;
    }
}
void dfs(int x,int y){
    top[x]=y;
	ld[x]=++tot;
	id[tot]=x;
    if(son[x]) dfs(son[x],y);
    for(int i=h[x],nxt;~i;i=a[i].next){
        nxt=a[i].to;
        if(nxt^son[x]&&nxt^fa[x])
        	dfs(nxt,nxt);
    }
}
void build(int i,int l,int r){
    sgt[i].l=l;
	sgt[i].r=r;
	sgt[i].lz=0;
    if(l==r){
        sgt[i].sum=v[id[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(I1,l,mid);
    build(I2,mid+1,r);
    sgt[i].sum=sgt[I1].sum+sgt[I2].sum;
}
void down(int x){
    if(sgt[x].lz){
        sgt[x].sum+=sgt[x].lz*(sgt[x].r-sgt[x].l+1);
        sgt[x<<1].lz+=sgt[x].lz;
        sgt[x<<1|1].lz+=sgt[x].lz;
        sgt[x].lz=0;
    }
}
ll q(int x,int l,int r){
    if(sgt[x].l==l&&sgt[x].r==r) return sgt[x].sum+sgt[x].lz*(r-l+1);
    down(x);
	int m=sgt[x].l+sgt[x].r>>1;
    if(r<=m) return q(x<<1,l,r);
    else if(l>m) return q(x<<1|1,l,r);
    else return q(x<<1,l,m)+q(x<<1|1,m+1,r);
}
void ud(int i,int l,int r,ll v){
    if(sgt[i].l==l&&sgt[i].r==r){
		sgt[i].sum+=(r-l+1)*v;
		sgt[i].lz+=v;
		return;
	}
    if(sgt[i].lz) down(i);
    int mid=sgt[i].l+sgt[i].r>>1;
    if(r<=mid)ud(I1,l,r,v);
    else if(l>mid)ud(I2,l,r,v);
    else ud(I1,l,mid,v),ud(I2,mid+1,r,v);
    sgt[i].sum=sgt[I1].sum+sgt[I2].sum;
}
void my(int x,int l,int r,ll val){
    if(sgt[x].l==l&&sgt[x].r==r){
		sgt[x].lz+=val;
		return ;
	}
    sgt[x].sum+=(r-l+1)*val;
    int m=sgt[x].l+sgt[x].r>>1;
    if(r<=m) my(x<<1,l,r,val);
    else if(l>m) my(x<<1|1,l,r,val);
    else my(x<<1,l,m,val),my(x<<1|1,m+1,r,val);
}
void cy(int x,int y,ll val){
    while(top[x]^top[y]){
        if(dep[top[x]]>dep[top[y]]) swap(x,y);
        my(1,ld[top[y]],ld[y],val);
		y=fa[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    my(1,ld[x],ld[y],val);
}
ll q1(int x,int y){
    ll ret=A-B*dep[x]+ff[x];
    while(top[x]^top[y]){
        if(dep[top[x]]>dep[top[y]]) swap(x,y);
        ret+=q(1,ld[top[y]],ld[y]);
		y=fa[top[y]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return ret+q(1,ld[x],ld[y]);
}
int main(){
    for(scanf("%d",&t);t--;cnt=tot=A=B=0){
        scanf("%d%d",&n,&m);
        memset(h,-1,sizeof(h));
        memset(vis,0,sizeof(vis));
        memset(dep,0,sizeof(dep));
        memset(siz,0,sizeof(siz));
        memset(top,0,sizeof(top));
        memset(rd,0,sizeof(rd));
        memset(ld,0,sizeof(ld));
        memset(ff,0,sizeof(ff));
        for(int i=1;i<n;i++){
        	scanf("%d%d",&x,&y);
			add(x,y);
			add(y,x);
		}
        dep[1]=1;
		p(1);
		dfs(1,1);
		build(1,1,tot);
        while(m--){
            scanf("%d",&op);
            if(op==1){
                scanf("%d%d",&x,&y);
                cy(1,x,2ll),
                A+=y-dep[x];B++;
            }
            else{
                scanf("%d",&x);
                ll val=q1(x,1);
                if(op==2)ff[x]+=min(0ll,val)-val;
                if(op==3)printf("%lld\n",val);
            }
        }
    }
}

猜你喜欢

转载自blog.csdn.net/s260127ljy/article/details/107851512