P2634 [国家集训队]聪聪可可 点分治模板题

在这里插入图片描述
这道题,我们用一个count数组来记录当前根节点的子树上,每种距离的个数有多少,然后所有点对和模3等于零的情况,即可变为零节点的个数的平方加上1节点个数乘以2节点的再乘个2就统计完成了
所以solve函数就变成了:

int solve(int rt,int len)//每做一次solve 即要重新得到一次dis数组
{
    
    
	Count[0] = Count[1] = Count[2] = 0;
	dis[rt] = len%3;
	get_dis(rt,0,len);
	return Count[0]*Count[0]+Count[1]*Count[2]*2;//得到模3等于0的点
}

其他部分就都是模板的套路了

AC代码

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int MAXN = 2e5+7;

#define inf 0x3f3f3f3f
int head[MAXN],dis[MAXN],maxson[MAXN],siz[MAXN],vis[MAXN];
int SIZE,cnt,num,maxx,root,ans,Count[3];

struct Edge
{
    
    
	int to,w,next;
}edge[MAXN<<1];

void addedge(int u,int v,int w)
{
    
    
	edge[++cnt].to = v;
	edge[cnt].w = w;
	edge[cnt].next = head[u];
	head[u] = cnt;
}

void get_root(int u,int fa)
{
    
    
	siz[u] = 1;
	maxson[u] = 0;
	for(int i = head[u];i;i = edge[i].next){
    
    
		int v = edge[i].to;
		if(vis[v] || fa == v) continue;
		get_root(v,u);
		siz[u] += siz[v];
		maxson[u] = max(maxson[u],siz[v]);
	}
	maxson[u] = max(maxson[u],SIZE-siz[u]);
	if(maxx > maxson[u]) root = u,maxx = maxson[u];
}

void get_dis(int u,int fa,int d)
{
    
    
	dis[++num] = d%3;//直接模3
	Count[dis[num]]++;
	for(int i = head[u];i;i = edge[i].next){
    
    
		int v = edge[i].to;
		if(vis[v] || fa == v) continue;
		get_dis(v,u,d+edge[i].w);
	}
	return ;
}

int solve(int rt,int len)//每做一次solve 即要重新得到一次dis数组
{
    
    
	Count[0] = Count[1] = Count[2] = 0;
	dis[rt] = len%3;
	get_dis(rt,0,len);
	return Count[0]*Count[0]+Count[1]*Count[2]*2;//得到模3等于0的点
}

void Divide(int rt)
{
    
    
	ans = ans + solve(rt,0);//先加一遍根节点的答案
	vis[rt] = 1;
	for(int i = head[rt];i;i = edge[i].next){
    
    
		int v = edge[i].to;
		if(vis[v]) continue;
		ans = ans - solve(v,edge[i].w);
		SIZE = siz[v];
		root = 0;
		maxx = inf;
		get_root(v,rt);
		Divide(root);
	}
}

int gcd(int a,int b)
{
    
    
	if(a < b) swap(a,b);
	int r;
	while(a%b){
    
    
		r = a%b;
		a = b;
		b = r;
	}
	return b;
}

int main()
{
    
    
	int n;
	scanf("%d",&n);
	for(int i = 1;i < n;i ++){
    
    
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		addedge(u,v,w);
		addedge(v,u,w);
	}
	memset(vis,0,sizeof(vis));
	maxx = inf;
	SIZE = n;
	root = 0;
	//printf("1.---\n");
	get_root(1,0);
	//printf("2.---\n");
	Divide(root);
	//printf("---ans = %d\n",ans);
	int sum = n*n;
	printf("%d/%d\n",ans/gcd(ans,sum),sum/gcd(ans,sum));
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_45672411/article/details/108509525