XSY原创题 友好国度

题目大意

给定一棵树,每个点有点权,求有多少组点对满足两点简单路径上的所有点点权的$gcd=1$。

$n,val_i\leq 10^5$

题解

考虑设$G_i$表示简单路径上所有点点权均为$i$的倍数的点对数。

那么最终答案显然就是$\sum G_i \mu(i)$。

由于求$gcd$,那么点权某一个质因子次数大于$2$是没有意义的,所以$val$最多有$6$个质因子

$(2\times 3\times 5\times 7\times 11\times 13=30030)$。

那么一个数的约数不超过$64$个,那么开$n$个$vector$,$vector_i$存点权是$i$倍数的点的集合。

若求$G_i$,只需要求全部由$vector_i$中的点组成的路径数即可。

那么将这若干个点取出来,在原树构成若干个连通块,那么每一个连通块内任意两点组成路径均能贡献,则答案$=\frac{n(n-1)}{2}$。

最后再求与莫比乌斯函数的点积之和即可。

#include<bits/stdc++.h>
#define debug(x) cerr<<#x<<" = "<<x
#define sp <<"  "
#define el <<endl
#define LL long long
#define M 100020
using namespace std;
namespace IO{
	const int BS=(1<<20)+5; char Buffer[BS],*HD,*TL;
	char Getchar(){if(HD==TL){TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);} return (HD==TL)?EOF:*HD++;}
	int read(){
		int nm=0,fh=1; char cw=Getchar();
		for(;!isdigit(cw);cw=Getchar()) if(cw=='-') fh=-fh;
		for(;isdigit(cw);cw=Getchar()) nm=nm*10+(cw-'0');
		return nm*fh;
	}
}using namespace IO;
int sz[M],n,m,fs[M],nt[M<<1],to[M<<1],val[M],tmp; vector<int>G[M];
int p[M],tot,mu[M],vis[M]; bool isp[M]; LL ans;
#define link(a,b) nt[tmp]=fs[a],fs[a]=tmp,to[tmp++]=b
void init(){
	memset(isp,true,sizeof(isp)),mu[1]=1,isp[1]=false;
	for(int i=2;i<M;i++){
		if(isp[i]) p[++tot]=i,mu[i]=-1;
		for(int j=1;p[j]*i<M&&j<=tot;j++){
			isp[p[j]*i]=false;if(i%p[j]==0) break;mu[p[j]*i]=-mu[i];
		}
	}
}
void dfs(int x,int last,int num){
	vis[x]=num,sz[x]=1;
	for(int i=fs[x];i!=-1;i=nt[i]){
		if(to[i]==last||val[to[i]]%num) continue;
		dfs(to[i],x,num),sz[x]+=sz[to[i]];
	}
}
void solve(int num,LL res=0){
	for(int k=0,TT=G[num].size();k<TT;k++){
		int x=G[num][k];
		if(vis[x]==num) continue;
		dfs(x,0,num),res+=((LL)sz[x]*(LL)(sz[x]-1))>>1;
	} if(res) ans+=mu[num]*res;
}
int main(){
	init(),n=read(),memset(fs,-1,sizeof(fs));
	for(int i=1;i<n;i++){int x=read(),y=read();link(x,y),link(y,x);}
	for(int i=1;i<=n;i++){
		val[i]=read(); 
		for(int j=1;p[j]*p[j]<=p[i];j++) while(val[i]%(p[j]*p[j])==0) val[i]/=p[j];
		for(int j=1;j*j<=val[i];j++){
			if(val[i]%j) continue; G[j].push_back(i);
			if(j*j<val[i]) G[val[i]/j].push_back(i);
		}
	}
	for(int i=1;i<M;i++) solve(i); printf("%lld\n",ans); return 0;
}

猜你喜欢

转载自www.cnblogs.com/OYJason/p/9908738.html