题目大意:给出了一棵树,每棵树有一个权值,定义dist(x,y)为x到y的简单路径经过的结点的个数。定义gcd(x,y) 为x到y的简单路径上所经过的所有点的点权的最大公因数。现在问满足gcd(x,y) 大于 1的条件下,dist(x,y)的最大值是多少。
数据量和数据值均为2e5,因此每个数的素因子不超过7种。
先谈树形dp:路径上gcd大于1表示存在一个素数可以除尽这条路径上所有的整数。可以对路径进行分解,最长路径一定是以某个点为父结点的一条路径,因此我们可以对每个点,求出它沿各棵子树往下最深的距离,然后合并路径,分别求得经过每个点的最长路径,再在其中求最大解(这个点卡了我很久,没有想到如何合并解,套用之前学到的树分治的思想可能会出现两条到负结点的路径来源于同一棵子树,显然这种路径合并后是不合法的,一直纠结如何除去这种不合法的解。而显然通过按路径通过的父结点对答案进行分类,对每一个父节点的情况求不同子树往下的最长链再合并可以解决这种情况。。。因为正解一定被囊括其中,因此能从中得到正解,从这个思想可以看到刚刚纠结的问题其实是子树的解的问题。。而不是该点的解的问题)。
那么如何求解经过每个点往下的最长路径,树形dp还需要求解出所有状态的最优解。我们该用什么来记录最优解呢?通过分析可知gcd大于1表示路径上所有值都能被某个素数整除,那么我可以枚举当前父结点的素因子,一路往下跑,如果能整除就往下,然后更新这个素因子可以跑到的最大深度,那么两个子树的路径如何合并呢?显然如果是通过同一个素因子得出的 就可以合并,我们只要每次跑一棵子树时维护所有因子可以下达的最大深度,在跑之前可以先用这个更新一下最长路径的答案,跑完之后可以更新每个素因子可以下达的最大深度。因此只需要两个容器,一个用来得出每个点的素因子(去重),另一个通过下标来hash每个点素因子可以下达的最大高度。合并的话暴力枚举一下状态就好(两个素因子是否相等),跑起来还是很快的。
(由于我没写代码,不好意思贴别人的,就先放着吧后面来补,不过树分治的我写了);
树分治的思想与树形dp的思想完全一样,区别在于树形dp维护状态最优解时像dp那样用到了前面的解,而树分治是直接递归到底求出每点的解的(这里树分治倒不用像树dp那样枚举各个素因子求最长路径,不过应该也可以,我没试过(可以一试,不过那样的话不如直接写树形dp),点分治的做法可以直接往下求出最深出的gcd值,然后合并的时候再求一遍gcd 即可)
代码与树形dp跑的时间相近,都是460-480ms
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5+10;
int n;
int val[maxn],num[maxn],root,siz,f[maxn],deep[maxn];
bool done[maxn];
vector<int> g[maxn];
map<int,int> mp,tp;
int res = 0;
int gcd(int x,int y){
return y==0?x:gcd(y,x%y);
}
void getroot(int u,int fa){
num[u]=1;
f[u] = 0;
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
if(v==fa||done[v]) continue ;
getroot(v,u);
num[u]+=num[v];
f[u]=max(f[u],num[v]);
}
f[u]=max(f[u],siz-num[u]);
if(f[u]<f[root]) root = u;
}
void cal(int u,int fa,int n){
if(n==1) return ;
tp[n]=max(tp[n],deep[u]);
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
if(v==fa||done[v]) continue;
deep[v]=deep[u]+1;
cal(v,u,gcd(n,val[v]));
}
}
int solve(int u){
mp.clear();deep[u]=0;done[u]=true;
map<int,int> :: iterator it1,it2;
if(val[u] == 1) return 0;
mp[val[u]]=1;res=max(res,1);
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];tp.clear();
if(done[v]) continue;
deep[v] = 2;
cal(v,u,gcd(val[v],val[u]));
for(it1 = mp.begin();it1 != mp.end(); it1++)
for(it2 = tp.begin();it2 != tp.end(); it2++)
if(gcd(it1->first,it2->first)>1){
res = max(res,it1->second+it2->second-1);
}
for(it2 = tp.begin();it2 != tp.end(); it2++)
mp[it2->first]=max(mp[it2->first],it2->second);
}
}
void divide(int u,int fa){
solve(u);
for(int i = 0; i < g[u].size(); i++){
if(g[u][i]==fa||done[g[u][i]]) continue;
siz = num[g[u][i]];root = 0;
getroot(g[u][i],u);
divide(root,-1);
}
}
int main(){
scanf("%d",&n);
for(int i = 1; i <= n; i++)
scanf("%d",&val[i]);
for(int i = 1; i < n; i++){
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
f[0]=0x3f3f3f3f;
root = 0;siz = n;
getroot(1,-1);
divide(root,-1);
printf("%d\n",res);
}