树链剖分模板题(处理边权)
注意把边权归到点上,在查询的时候注意不要lca点的权值,另外线段树注意如果当两个点在同一条重链且都在lca点时,直接返回。
树状数组最后计算链时不用归到重儿子上,直接减就好。
segment_tree:
#include <cstdio> #include <cstring> #define Lson l,mid,rt<<1 #define Rson mid+1,r,rt<<1|1 using namespace std; const int M = 1e5+7; typedef long long ll; int cnt,tot,head[M],p[M]; int n,q,pos,pp; int sz[M],top[M],son[M],id[M],rnk[M],f[M],dep[M]; ll a[M]; struct edge { int v,next; ll w; }e[M<<1]; struct Tree { ll sum; }tree[M<<2]; void init(){ tot=cnt=0;memset(head,-1,sizeof(head)); } void add(int u,int v,ll w){ e[++cnt].v=v;e[cnt].next=head[u];e[cnt].w=w; head[u]=cnt; } void fsd(int u,int fa){ for(int i=head[u];~i;i=e[i].next){ int v=e[i].v;ll w=e[i].w; if(v==fa) continue; a[v]=w;p[(i-1)/2+1]=v; fsd(v,u); } return ; } void dfs(int u,int fa,int d){ sz[u]=1;f[u]=fa;son[u]=-1;dep[u]=d; for(int i=head[u];~i;i=e[i].next){ int v=e[i].v; if(v==fa) continue; dfs(v,u,d+1); sz[u]+=sz[v]; if(son[u]==-1||sz[v]>sz[son[u]]) son[u]=v; } return ; } void dfs1(int u,int t){ id[u]=++tot; rnk[tot]=u; top[u]=t; if(son[u]==-1) return ; dfs1(son[u],t); for(int i=head[u];~i;i=e[i].next){ int v=e[i].v; if(v==f[u]||v==son[u]) continue; dfs1(v,v); } return ; } void Pushup(int rt){ tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum; } void build(int l,int r,int rt){ tree[rt].sum=0; if(l==r){ tree[rt].sum=a[rnk[l]]; return ; } int mid=(l+r)>>1; build(Lson); build(Rson); Pushup(rt); } void update(int l,int r,int rt,ll v){ if(l==r){ tree[rt].sum=v; return ; } int mid=(l+r)>>1; if(pp<=mid) update(Lson,v); else update(Rson,v); Pushup(rt); } ll query(int L,int R,int l,int r,int rt){ if(L<=l&&r<=R){ return tree[rt].sum; } ll ans=0; int mid=(l+r)>>1; if(L<=mid) ans+=query(L,R,Lson); if(R>mid) ans+=query(L,R,Rson); return ans; } ll sum(int x,int y){ int fx=top[x],fy=top[y];ll res=0; while(fx!=fy){ if(dep[fx]>dep[fy]){ res+=query(id[fx],id[x],1,n,1); x=f[fx],fx=top[x]; } else{ res+=query(id[fy],id[y],1,n,1); y=f[fy],fy=top[y]; } } if(x==y) return res;//!!!! if(dep[x]<dep[y]) res+=query(id[son[x]],id[y],1,n,1); else res+=query(id[son[y]],id[x],1,n,1); return res; } int main(){ init(); scanf("%d%d%d",&n,&q,&pos); for(int i=1;i<n;i++){ int u,v;ll w; scanf("%d%d%lld",&u,&v,&w); add(u,v,w);add(v,u,w); } fsd(1,-1);a[1]=0ll; dfs(1,-1,1); dfs1(1,1); build(1,n,1); while(q--){ int op; scanf("%d",&op); if(op==0){ int to; scanf("%d",&to); printf("%lld\n",sum(pos,to)); pos=to; } else{ int i;ll w; scanf("%d%lld",&i,&w);pp=id[p[i]]; update(1,n,1,w); } } return 0; }
bit:
#include <cstdio> #include <cstring> #include <iostream> using namespace std; const int M = 1e5+7; typedef long long ll; int n,q,pos; int cnt,head[M],p[M],tot; int sz[M],son[M],dep[M],f[M],top[M],rnk[M],id[M]; ll c[M],a[M]; struct edge{ int v,next; ll w; }e[M<<1]; void init(){ tot=cnt=0;memset(head,-1,sizeof(head));memset(c,0,sizeof(c)); } void add(int u,int v,ll w){ e[++cnt].v=v;e[cnt].next=head[u];e[cnt].w=w; head[u]=cnt; } void fsd(int u,int fa){ for(int i=head[u];~i;i=e[i].next){ int v=e[i].v;ll w=e[i].w; if(v==fa) continue; a[v]=w;p[(i-1)/2+1]=v; fsd(v,u); } return ; } void dfs(int u,int fa,int d){ sz[u]=1;son[u]=-1;f[u]=fa;dep[u]=d; for(int i=head[u];~i;i=e[i].next){ int v=e[i].v; if(v==fa) continue; dfs(v,u,d+1); sz[u]+=sz[v]; if(son[u]==-1||sz[v]>sz[son[u]]) son[u]=v; } return ; } void dfs1(int u,int t){ top[u]=t; id[u]=++tot; rnk[tot]=u; if(son[u]==-1) return ; dfs1(son[u],t); for(int i=head[u];~i;i=e[i].next){ int v=e[i].v; if(v==f[u]||v==son[u]) continue; dfs1(v,v); } return ; } void update(int x,ll v){ for(;x<=n;x+=x&(-x)){ c[x]+=v; } } ll query(int x){ ll res=0; for(;x;x-=x&(-x)){ res+=c[x]; } return res; } ll sum(int x,int y){ int fx=top[x],fy=top[y];ll res=0; while(fx!=fy){ if(dep[fx]>dep[fy]){ res+=query(id[x])-query(id[fx]-1); x=f[fx],fx=top[x]; } else{ res+=query(id[y])-query(id[fy]-1); y=f[fy],fy=top[y]; } } if(dep[x]<dep[y]) res+=query(id[y])-query(id[x]); else res+=query(id[x])-query(id[y]); return res; } int main(){ init(); scanf("%d%d%d",&n,&q,&pos); for(int i=1;i<n;i++){ int u,v;ll w; scanf("%d%d%lld",&u,&v,&w); add(u,v,w);add(v,u,w); } fsd(1,-1);a[1]=0ll; dfs(1,-1,1); dfs1(1,1); for(int i=1;i<=n;i++){ update(id[i],a[i]); } while(q--){ int op; scanf("%d",&op); if(op==0){ int to; scanf("%d",&to); printf("%lld\n",sum(pos,to)); pos=to; } else{ int i;ll w; scanf("%d%lld",&i,&w); ll k=query(id[p[i]])-query(id[p[i]]-1); update(id[p[i]],-k); update(id[p[i]],w); } } return 0; }