LCA-倍增模板
$O(n\log n)$ 预处理，$O(\log n)$ 查询
// LCA by binary lifting (倍增) -- template.
// dfs(cur, fath) visits the subtree rooted at cur, whose parent is fath
// (call as dfs(root, 0); node 0 is a sentinel with dep[0] = 0), and fills
//   dep[u]   : depth of u (root = 1)
//   fa[u][i] : 2^i-th ancestor of u (0 once the jump goes above the root)
void dfs(int cur,int fath)
{
    if(st[cur]) return ;
    st[cur] = 1;
    dep[cur] = dep[fath]+1;
    fa[cur][0] = fath;
    // The 2^i-th ancestor is the 2^(i-1)-th ancestor of the 2^(i-1)-th ancestor.
    for(int i = 1; i <= lg[dep[cur]]; i++)
        fa[cur][i] = fa[fa[cur][i-1]][i-1];
    for(int i = h[cur];i != -1; i = e[i].ne)
    {
        int v = e[i].to;
        // cur (not fath) is v's parent; skip the edge leading back up.
        if(v != fath)
            dfs(v,cur);
        /* Variant that also builds edge-weight prefix sums (dis[v] = distance root -> v):
        if(v != fath)
        {
            dis[v] = dis[cur] + e[i].w;
            dfs(v,cur);
        }*/
    }
}
// Returns the lowest common ancestor of a and b.
// Requires dep[], fa[][] and lg[] to have been filled by dfs() from the root.
int lca(int a,int b)
{
    // Make b the deeper of the two nodes.
    if(dep[b] < dep[a]) swap(a,b);
    // Lift b to a's depth; each jump is the largest power of two not exceeding the gap.
    for(; dep[a] != dep[b]; )
        b = fa[b][lg[dep[b]-dep[a]]];
    if(b == a) return a;
    // Lift both while their 2^k-th ancestors differ; they stop just below the LCA.
    int k = lg[dep[a]];
    while(k >= 0)
    {
        if(fa[b][k] != fa[a][k])
        {
            a = fa[a][k];
            b = fa[b][k];
        }
        k--;
    }
    return fa[a][0];
}
int main()
{
    // Build lg[i] = floor(log2(i)) before calling dfs(), which reads lg[dep[cur]].
    // lg[1] stays 0 -- presumably lg is a zero-initialised global (not shown here);
    // n must already hold the node count at this point.
    for(int i = 2; i <= n; i++)
        lg[i] = lg[i>>1]+1;
    ...
}
树上前缀和
设 $sum_i$ 表示结点 $i$ 到根节点的权值总和。
然后:
- 若是点权，$x,y$ 路径上的和为 $sum_x+sum_y-sum_{lca}-sum_{fa_{lca}}$（$lca$ 自身的点权要保留一次，所以减去的是 $sum_{lca}+sum_{fa_{lca}}$ 而不是 $2\,sum_{lca}$）。
- 若是边权，$x,y$ 路径上的和为 $sum_x+sum_y-2\,sum_{lca}$。
$lca$ 的求法参见「最近公共祖先」。
例题
loj - 0134- Dis 模板题
ACcode:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<stack>
#include<string>
#include<utility>
#include<cmath>
#include<vector>
#include<functional>//使用 greater<int>();
using namespace std;
typedef long long ll;
typedef pair<int,int> pll;
const int INF = 0x3f3f3f3f;
const int N = 10100;              // node capacity (1e4 + 100)

int n, m;                         // node count, query count
int h[N];                         // h[u]: index of u's first edge in e[], -1 if none
int cnt = 0;                      // number of edges stored so far
int dis[N];                       // dis[u]: total edge weight on the path root -> u
int dep[N];                       // depth, root = 1 (node 0 is a sentinel with depth 0)
int fa[N][25];                    // fa[u][i]: 2^i-th ancestor of u (0 above the root)
int st[N];                        // visited flag used by dfs()
int lg[N];                        // lg[i] = floor(log2(i))

// Weighted adjacency-list edge: target vertex, weight, index of the next edge
// leaving the same source vertex.
struct node
{
    int to, w, ne;
} e[N << 1];                      // each undirected edge is stored twice
void init(){
for(int i = 2; i <= 1e4+10;i++)
lg[i] = lg[i/2]+1;
memset(h,-1,sizeof(h));
}
void add(int u,int v,int w)
{
e[cnt].to = v;
e[cnt].w = w;
e[cnt].ne = h[u];
h[u] = cnt++;
}
void dfs(int cur,int fath)
{
if(st[cur]) return ;
st[cur] = 1;
dep[cur] = dep[fath] + 1;
fa[cur][0] = fath;
for(int i = 1; i <= lg[dep[cur]];i++)
fa[cur][i] = fa[fa[cur][i-1]][i-1];
for(int i = h[cur];i != -1;i=e[i].ne)
{
int v = e[i].to;
if(v != fath)
{
dis[v] = dis[cur] + e[i].w;
dfs(v,cur);
}
}
}
int LCA(int a,int b)
{
if(dep[a] > dep[b]) swap(a,b);
while(dep[a] != dep[b])
b = fa[b][lg[dep[b]-dep[a]]];
if(a==b) return a;
for(int k = lg[dep[a]];k >= 0;k --)
if(fa[a][k] != fa[b][k])
a = fa[a][k],b = fa[b][k];
return fa[a][0];
}
int main()
{
init();
cin >> n >> m;
int u,v,w;
int num = n-1;
while(num--)
{
cin >> u >> v >> w;
add(u,v,w);
add(v,u,w);
}
dfs(1,0);
int s,t;
while(m--)
{
cin >> s >> t;
int lca = LCA(s,t);
int res = dis[s]+dis[t] - 2*dis[lca];
cout << res << endl;
}
return 0;
}
/*
author:nttttt;
add oil!
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<stack>
#include<string>
#include<utility>
#include<cmath>
#include<vector>
#include<functional>//使用 greater<int>();
using namespace std;
typedef long long ll;
typedef pair<int,int> pll;
const int INF = 0x3f3f3f3f;
const int N = 300100;             // node capacity (3e5 + 100)

int n, m, k;                      // node count, query count, exponent of the current query
int h[N];                         // h[u]: index of u's first edge in e[], -1 if none
int cnt = 0;                      // number of edges stored so far
ll dep[N];                        // depth, root = 1 (node 0 is a sentinel with depth 0)
ll fa[N][30];                     // fa[u][i]: 2^i-th ancestor of u (0 above the root)
ll st[N];                         // visited flag used by dfs()
// ans[u][j]: sum of dep[a]^j (mod mod) over the strict ancestors a of u.
// NOTE(review): ans alone is ~122 MB and fa ~72 MB as long long; both would fit
// in int (values < mod, ids < N) -- check against the memory limit.
ll ans[N][51];
ll node[N];                       // scratch for dfs(): node[j] = depth^j mod mod, node[0] = 1
int lg[N];                        // lg[i] = floor(log2(i))
int mod = 998244353;

// Unweighted adjacency-list edge: target vertex and index of the next edge
// leaving the same source vertex.
struct Node
{
    int to, ne;
} e[N << 1];                      // each undirected edge is stored twice
void init(){
for(int i = 2; i <= 3e5;i++)
lg[i] = lg[i/2]+1;
memset(h,-1,sizeof(h));
node[0] = 1;
}
// Pushes directed edge u -> v onto the front of u's adjacency list.
void add(int u,int v)
{
    const int id = cnt++;
    e[id].to = v;
    e[id].ne = h[u];
    h[u] = id;
}
void dfs(int cur,int fath)
{
if(st[cur]) return ;
st[cur] = 1;
dep[cur] = dep[fath]+1;
fa[cur][0] = fath;
for(int i = 1; i <= lg[dep[cur]];i++)
fa[cur][i] = fa[fa[cur][i-1]][i-1];
for(int i = h[cur];i != -1;i=e[i].ne)
{
int v = e[i].to;
if(v != fath)
{
for(int j = 1; j <= 50; j++) node[j] = node[j-1]*dep[cur]%mod;
for(int j = 1; j <= 50; j++) ans[v][j] = (node[j] + ans[cur][j])%mod;
dfs(v,cur);
}
}
}
// Lowest common ancestor of a and b using the binary-lifting table fa[][]
// together with dep[] and lg[] (all filled by dfs(1, 0)).
int LCA(int a,int b)
{
    if(dep[b] < dep[a])
    {
        int tmp = a;
        a = b;
        b = tmp;
    }
    // Climb from the deeper node b until both sit at the same depth.
    while(dep[b] != dep[a])
        b = fa[b][lg[dep[b]-dep[a]]];
    if(b == a) return a;
    // Move both up in decreasing powers of two, stopping below the meeting point.
    int k = lg[dep[a]];
    for(; k >= 0; --k)
    {
        if(fa[b][k] == fa[a][k]) continue;
        a = fa[a][k];
        b = fa[b][k];
    }
    return fa[a][0];
}
int main()
{
init();
scanf("%d",&n);
int u,v;
int num = n-1;
while(num--)
{
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
int s,t;
scanf("%d",&m);
while(m--)
{
scanf("%d%d%d",&s,&t,&k);
int lca = LCA(s,t);
ll res;
res = (ans[s][k]+ans[t][k] - ans[lca][k] - ans[fa[lca][0]][k])%mod;
printf("%lld\n",(res%mod+mod)%mod);
}
return 0;
}