牛客练习赛 67:F. 牛妹的苹果树(st表 + 树的直径的合并)

在这里插入图片描述


单看一个询问,就是求仅考虑区间内的点的直径。

树的直径具有可合并的性质,同一棵树上两个区间的直径的两个端点分别为 ( a , b ) (a,b) (a,b), ( c , d ) (c,d) (c,d),那么合并两个区间后的新的直径的端点一定在 a , b , c , d {a,b,c,d} a,b,c,d 中,通过枚举端点计算它们的距离,取最大值可以得到两个区间合并的直径。 其正确性证明和两遍 d f s dfs dfs 求树的直径的证明过程类似。

对于这题,考虑用 s t st st 表预处理区间的直径,对于询问用类似的合并方法合并两个 st 表即可。

查询 l c a lca lca 要用欧拉序和 s t st st 表,不然会TLE。

注意:内置的 log ⁡ 2 ( ) \log2() log2() 函数常数非常大,改成预处理每个数的 log ⁡ 2 \log_2 log2 值可以降低常数。
时空复杂度均为 O ( n log ⁡ 2 n ) O(n\log_2n) O(nlog2n)


代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 3e5 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
int n, q;
int fir[maxn], st[maxn << 1][25], cnt, bin[maxn], lg[2 * maxn];
pii pot[maxn][25];
ll dep[maxn];
struct node {
    
    
	int head[maxn], nxt[maxn << 1], cnt, to[maxn << 1], w[maxn << 1];
	void init() {
    
    
		memset(head,-1,sizeof head);
		cnt = 0;
	}
	void add(int u,int v,int wi) {
    
    
		to[cnt] = v;
		w[cnt] = wi;
		nxt[cnt] = head[u];
		head[u] = cnt++;
	}
} g;
void prework(int u,int fa) {
    
    
	fir[u] = ++cnt; st[cnt][0] = u; 
	pot[u][0] = pii(u,u);
	for (int i = g.head[u]; i + 1; i = g.nxt[i]) {
    
    
		if (g.to[i] == fa) continue;
		dep[g.to[i]] = dep[u] + g.w[i];
		prework(g.to[i],u);
		st[++cnt][0] = u;
	}
}
int calc(int u,int v) {
    
    
	return dep[u] < dep[v] ? u : v;
}
int getlca(int u,int v) {
    
    
	if (fir[u] > fir[v]) swap(u,v);
	int p = lg[fir[v] - fir[u] + 1];
	return calc(st[fir[u]][p],st[fir[v] - bin[p] + 1][p]);
}
ll getdis(ll u,ll v) {
    
    
	int lca = getlca(u,v);
	return dep[u] + dep[v] - dep[lca] - dep[lca];
}
void init() {
    
    
	for (int i = 0; i <= 22; i++)
		bin[i] = 1 << i;
	lg[1] = 0;
	for (int i = 2; i <= cnt; i++)
		lg[i] = lg[i >> 1] + 1;
	for (int i = 1; bin[i] <= cnt; i++)
		for (int j = 1; j + bin[i] - 1 <= cnt; j++)
			st[j][i] = calc(st[j][i - 1],st[j + bin[i - 1]][i - 1]);
	for (int i = 1; bin[i] <= n; i++) 
		for (int j = 1; j + bin[i] - 1 <= n; j++) {
    
    
			pii x = pot[j][i - 1], y = pot[j + bin[i - 1]][i - 1];
			pot[j][i] = pii(x.fir,y.fir);
			if (getdis(x.fir,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
				pot[j][i] = pii(x.fir,y.sec);
			if (getdis(x.sec,y.fir) > getdis(pot[j][i].fir,pot[j][i].sec))
				pot[j][i] = pii(x.sec,y.fir);
			if (getdis(x.sec,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
				pot[j][i] = pii(x.sec,y.sec);
			if (getdis(x.fir,x.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
				pot[j][i] = pii(x.fir,x.sec);
			if (getdis(y.fir,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
				pot[j][i] = pii(y.fir,y.sec);
		}
}
ll query(int l,int r) {
    
    
	int p = lg[r - l + 1];
	pii x = pot[l][p], y = pot[r - bin[p] + 1][p];
	ll ans = getdis(x.fir,y.fir);
	ans = max(ans,getdis(x.fir,y.sec));
	ans = max(ans,getdis(x.sec,y.fir));
	ans = max(ans,getdis(x.sec,y.sec));
	ans = max(ans,getdis(x.fir,x.sec));
	ans = max(ans,getdis(y.fir,y.sec));
	return ans;
}
int main() {
    
    
	g.init();
	scanf("%d%d",&n,&q);
	for (int i = 1; i < n; i++) {
    
    
		int u, v, w; scanf("%d%d%d",&u,&v,&w);
		g.add(u,v,w);
		g.add(v,u,w);
	}
	prework(1,0);
	init();
	while (q--) {
    
    
		int l, r; scanf("%d%d",&l,&r);
		printf("%lld\n",query(l,r));
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/108032937