【CodeForces 990G】GCD Counting (点分治)

题目链接

Luogu & CodeForces

题目描述

You are given a tree consisting of n vertices. A number is written on each vertex; the number on vertex i is equal to aiai.

Let's denote the function g(x,y) as the greatest common divisor of the numbers written on the vertices belonging to the simple path from vertex x to vertex y (including these two vertices).

For every integer from 1 to 2⋅10^5you have to count the number of pairs (x,y) (1≤x≤y≤n) such that g(x,y) is equal to this number.

题目翻译

现在有一棵树,每个点都有一个权值,现在求对于每个数,有多少条路径上所有的点的GCD的值和它相等?

解题思路

感觉点分治可做,所以想了一个点分治。

保存一个map,分治树的时候先用根节点到当前节点的的GCD和map中保存的值都进行一次GCD,然后把答案存到ans里,再把当前节点的GCD值存到map里,说白了就是到每个节点都和之前到的节点做一次GCD操作,然后加进去,还能保证刚好只算一次。

但是这样稍微会有一点T,最后T在第95个点。

附上代码:

 1 #include<algorithm>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 #include<queue>
 6 #include<map>
 7 #define LL long long
 8 const int maxn=550000;
 9 const int INF=99999999; 
10 using namespace std;
11 int v[maxn],head[maxn],ne[maxn],to[maxn];
12 int size[maxn],ma[maxn],maxx=-INF;
13 bool vis[maxn];
14 int n,cnt=0,root,tot;
15 map<int,LL>temp;
16 map<int,LL>ans;
17 map<int,LL>::iterator iter1;
18 map<int,LL>::iterator iter2;
19 inline int GCD(register int a,register int b){
20     return a==0?b:GCD(b%a,a);
21 }
22 inline void read(register int &x){
23     x=0; register char ch=getchar();
24     while(ch<'0'||ch>'9')ch=getchar();
25     while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
26 }
27 inline void add(int f,int t){
28     ne[++cnt]=head[f],head[f]=cnt,to[cnt]=t;
29 }
30 inline void get_root(int now,int fa){
31     ma[now]=0,size[now]=1;
32     for(register int i=head[now];i;i=ne[i]){
33         if(to[i]==fa||vis[to[i]])continue;
34         get_root(to[i],now);
35         size[now]+=size[to[i]];
36         ma[now]=max(ma[now],size[to[i]]);
37     }
38     ma[now]=max(ma[now],tot-size[now]);
39     if(ma[now]<ma[root])root=now;
40 }
41 inline void get_dep(int now,int fa,int gcd,int j){
42 //    cout<<"get "<<now<<endl;
43     for(iter1=temp.begin();iter1!=temp.end();iter1++){
44         ans[GCD(gcd,iter1->first)]+=(LL)j*(iter1->second);
45     }
46     temp[gcd]++;
47     ans[gcd]+=(LL)j;
48     for(register int i=head[now];i;i=ne[i]){
49         if(to[i]==fa||vis[to[i]])continue;
50         get_dep(to[i],now,GCD(gcd,v[to[i]]),j);
51 //        cout<<to[i]<<endl;
52     }
53 }
54 inline void solve(int now){
55 //    cout<<"solve "<<now<<endl;
56     temp.clear();
57     get_dep(now,0,v[now],1),vis[now]=1;
58     for(register int i=head[now];i;i=ne[i]){
59         if(vis[to[i]])continue;
60         temp.clear();
61         get_dep(to[i],0,GCD(v[now],v[to[i]]),-1);
62         tot=size[to[i]],root=0;
63         get_root(to[i],0);
64         solve(root);
65     }
66 }
67 int main(){
68     read(n);
69     for(register int i=1;i<=n;i++){
70         read(v[i]);
71         maxx=max(maxx,v[i]);
72     }
73     register int f,t;
74     for(register int i=1;i<n;i++){
75         read(f),read(t);
76         add(f,t),add(t,f);
77     }
78     size[0]=ma[0]=INF,tot=n;
79     get_root(1,0);
80     solve(root);
81     for(register int i=1;i<=maxx;i++){
82         if(ans[i]){
83             printf("%d %lld\n",i,ans[i]);
84         }
85     }
86 }

然后翻了一下网上的题解,发现点分治的方法和我不太一样:

其实,我么没有必要用容斥原理去做,这样会让很多点被多次访问,大大降低了效率。

我们使用两个map,一个表示当前的树,也就是当前根节点的子节点的子树,一个表示当前根节点的子树。

然后每次dfs当前根节点的一个子节点的子树,把GCD值存在第一个map里,然后用第一个map和第二个map用乘法原理计算之后存到ans里。

然后再把第一个map里的值附到第二个map里存起来。

代码

 1 #include<algorithm>
 2 #include<iostream>
 3 #include<cstring>
 4 #include<cstdio>
 5 #include<queue>
 6 #include<map>
 7 #define LL long long
 8 const int maxn=550000;
 9 const int INF=99999999; 
10 using namespace std;
11 int v[maxn],head[maxn],ne[maxn],to[maxn];
12 LL ans[maxn];
13 int size[maxn],ma[maxn],maxx=-INF;
14 bool vis[maxn];
15 int n,cnt=0,root,tot;
16 map<int,LL>temp;
17 map<int,LL>mem;
18 map<int,LL>::iterator iter1;
19 map<int,LL>::iterator iter2;
20 inline int GCD(register int a,register int b){
21     return a==0?b:GCD(b%a,a);
22 }
23 inline void read(register int &x){
24     x=0; register char ch=getchar();
25     while(ch<'0'||ch>'9')ch=getchar();
26     while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
27 }
28 inline void add(int f,int t){
29     ne[++cnt]=head[f],head[f]=cnt,to[cnt]=t;
30 }
31 inline void get_root(int now,int fa){
32     ma[now]=0,size[now]=1;
33     for(register int i=head[now];i;i=ne[i]){
34         if(to[i]==fa||vis[to[i]])continue;
35         get_root(to[i],now);
36         size[now]+=size[to[i]];
37         ma[now]=max(ma[now],size[to[i]]);
38     }
39     ma[now]=max(ma[now],tot-size[now]);
40     if(ma[now]<ma[root])root=now;
41 }
42 inline void get_dep(int now,int fa,int gcd){
43     temp[gcd]++;
44     for(register int i=head[now];i;i=ne[i]){
45         if(to[i]==fa||vis[to[i]])continue;
46         get_dep(to[i],now,GCD(gcd,v[to[i]]));
47     }
48 }
49 inline void get_ans(int now){
50     mem.clear(),mem[v[now]]++,ans[v[now]]++;
51     for(register int i=head[now];i;i=ne[i]){
52         if(vis[to[i]])continue;
53         temp.clear();
54         get_dep(to[i],now,GCD(v[now],v[to[i]]));
55         for(iter1=mem.begin();iter1!=mem.end();iter1++){
56             for(iter2=temp.begin();iter2!=temp.end();iter2++){
57                 ans[GCD(iter1->first,iter2->first)]+=(LL)(iter1->second)*(iter2->second);
58             }
59         }
60         for(iter1=temp.begin();iter1!=temp.end();iter1++){
61             mem[iter1->first]+=iter1->second;
62         }
63     }
64 }
65 inline void solve(int now){
66     vis[now]=1,get_ans(now);
67     for(register int i=head[now];i;i=ne[i]){
68         if(vis[to[i]])continue;
69         tot=size[to[i]],root=0;
70         get_root(to[i],0);
71         solve(root);
72     }
73 }
74 int main(){
75     read(n);
76     for(register int i=1;i<=n;i++){
77         read(v[i]);
78         maxx=max(maxx,v[i]);
79     }
80     register int f,t;
81     for(register int i=1;i<n;i++){
82         read(f),read(t);
83         add(f,t),add(t,f);
84     }
85     size[0]=ma[0]=INF,tot=n;
86     get_root(1,0);
87     solve(root);
88     for(register int i=1;i<=maxx;i++){
89         if(ans[i]){
90             printf("%d %lld\n",i,ans[i]);
91         }
92     }
93 }

猜你喜欢

转载自www.cnblogs.com/Fang-Hao/p/9255839.html