杭电2586 戳我
题意
给出一个加权无根树,多组询问每两点之间的距离。
输入格式
第一行T,代表有T组测试
后面代表每一组测试的格式
输入树点的个数n以及输入询问的次数m,再输入n-1条边以及该边的权值以及x,y,代表询问x,y之间的距离
数据范围:T<=10 n<=40000 m<=200
输出格式
输出所有测试询问点对之间的距离
暴力做法
对于每一组询问,x,y,直接dfs一遍求出x到y的距离即可。每一组测试时间复杂度:O(n*m),8e6 明显也能过
代码:
#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false)
#define maxn 40005
struct edge{
int to,next,len;
}G[maxn*2];
int head[maxn],num;
void add(int from,int to,int len)
{
G[++num].next=head[from];
G[num].to=to;
G[num].len=len;
head[from]=num;
}
int dis[maxn];
void init()
{
memset(head,-1,sizeof(head));
num=0;
}
void dfs(int fa,int cur)
{
for(int i=head[cur];i!=-1;i=G[i].next)
{
int v=G[i].to;
int len=G[i].len;
if(v!=fa)
{
dis[v]=dis[cur]+G[i].len;
dfs(cur,v);
}
}
}
int main()
{
IOS;
int T;
cin>>T;
while(T--)
{
init();
int n,m;
cin>>n>>m;
for(int i=0;i<n-1;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z),add(y,x,z);
}
for(int i=0;i<m;i++)
{
int x,y;
cin>>x>>y;
dis[x]=0;
dfs(x,x);
cout<<dis[y]<<"\n";
}
}
return 0;
}
一、LCA求树上距离
解法:设u,v的最近公共祖先为lca,记录每一点到根节点(这里根节点可以随便设一个,从该点开始搜索即可)的距离,用dir数组表示,则两点之间的距离公式如下:dir[u]+dir[v]-2*dir[lca],即u到根节点的距离+v到根节点的距离 - 2倍u,v公共祖先到根节点的距离。
如下图所示,即因为u,v到根节点有公共部分,减去这一部分即为u,v的距离。
下面贴上两种求法,欧拉序+ST表O(nlogn)、tarjan+并查集 O(kn)
Dfs+ST表
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define maxn 40005
#define IOS ios::sync_with_stdio(false)
inline void read(int& x){
char ch = getchar();
x = 0;
for(;ch < '0' || ch > '9';ch = getchar());
for(;ch >= '0' && ch <= '9';ch = getchar()) x = x*10+(ch-'0');
}
struct edge
{
int next,to,dis;
};
edge ff[2*maxn];
int head[2*maxn],num;
void add(int from,int to,int dis)
{
ff[++num].next=head[from];
ff[num].to=to;
ff[num].dis=dis;
head[from]=num;
}
int Eu[2*maxn],id;//记录欧拉序列
int dd[maxn];//每一点的深度
int hId[maxn];//记录每一点对应欧拉序列第一次出现的位置(映射)
int dir[maxn];//记录根节点到每个点的距离
bool vis[maxn];
void dfs(int x,int d)
{
dd[x]=d;
Eu[++id]=x;
hId[x]=id;
vis[x]=true;
for(int i=head[x];i;i=ff[i].next)
{
if(vis[ff[i].to]==false)
{
dir[ff[i].to]=dir[x]+ff[i].dis;
dfs(ff[i].to,d+1);
Eu[++id]=x;
}
}
}
int f[2*maxn][32];
int res[2*maxn][32];
void ST_Init(int n)
{
for(int i=1; i<=n; i++)
{
f[i][0]=dd[Eu[i]];
res[i][0]=Eu[i];
}
int depth=(int)(log2(n));
for(int j=1; j<=depth; j++)
for(int i=1; i<=n-(1<<j)+1; i++)
{
if(f[i][j-1]<f[i+(1<<(j-1))][j-1])
{
f[i][j]=f[i][j-1];
res[i][j]=res[i][j-1];
}
else
{
f[i][j]=f[i+(1<<(j-1))][j-1];
res[i][j]=res[i+(1<<(j-1))][j-1];
}
}
}
int ST_Query(int l,int r)
{
int p=(int)(log2(r-l+1));
if(f[l][p]<f[r-(1<<p)+1][p])
return res[l][p];
else
return res[r-(1<<p)+1][p];
}
int main()
{
//IOS;
int t;
read(t);
while(t--)
{
memset(vis,false,sizeof(vis));
memset(dir,false,sizeof(dir));
memset(head,0,sizeof(head));
num=0,id=0;
int n,m;
read(n),read(m);
for(int i=0; i<n-1; i++)
{
int x,y,dis;
read(x),read(y),read(dis);
add(x,y,dis);
add(y,x,dis);
}
dfs(1,1);
ST_Init(id);
for(int i=0; i<m; i++)
{
int x,y,l,r;
read(x),read(y);
if(hId[x]<hId[y])
l=hId[x],r=hId[y];
else
l=hId[y],r=hId[x];
int lca=ST_Query(l,r);
printf("%d\n",dir[x]+dir[y]-2*dir[lca]);
}
}
return 0;
}
Targan+并查集
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define maxn 40005
struct node
{
int value,order;
};
struct edge
{
int next,to,dis;
};
edge ff[2*maxn];
int head[2*maxn],num;
void add_edge(int from,int to,int dis)
{
ff[++num].next=head[from];
ff[num].to=to;
ff[num].dis=dis;
head[from]=num;
}
vector<node>qq[maxn];
bool vis[maxn];
int yy[maxn],ans[maxn],dir[maxn];
int in_qq[maxn][2];//记录访问的左右端点
inline void read(int& x)
{
char ch = getchar();
x = 0;
for(; ch < '0' || ch > '9'; ch = getchar());
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x*10+(ch-'0');
}
int find(int x)
{
return yy[x]==x?x:yy[x]=find(yy[x]);
}
void dfs(int from,int to)
{
vis[to]=true;
for(int i=head[to]; i; i=ff[i].next)
if(ff[i].to!=from)
{
dir[ff[i].to]=dir[to]+ff[i].dis;
dfs(to,ff[i].to);
}
for(int i=0; i<qq[to].size(); i++)
if(vis[qq[to][i].value]&&ans[qq[to][i].order]==0)
ans[qq[to][i].order]=find(qq[to][i].value);
yy[to]=find(from);
}
int main()
{
int t;
read(t);
while(t--)
{
num=0;
memset(head,0,sizeof(head));
memset(dir,0,sizeof(dir));
memset(vis,false,sizeof(vis));
int n,m;
read(n),read(m);
for(int i=0; i<n-1; i++)
{
int x,y,dis;
read(x),read(y),read(dis);
add_edge(x,y,dis);
add_edge(y,x,dis);
}
for(int i=0; i<m; i++)
{
int x,y;
read(x),read(y);
in_qq[i+1][0]=x,in_qq[i+1][1]=y;
qq[x].push_back((node)
{
y,i+1
});
qq[y].push_back((node)
{
x,i+1
});
}
for(int i=0; i<=n; i++)yy[i]=i;
dfs(1,1);
for(int i=1; i<=m; i++)
{
printf("%d\n",dir[in_qq[i][0]]+dir[in_qq[i][1]]-2*dir[ans[i]]);
}
}
return 0;
}
二、点分治求树上距离
点分治的基本步骤:找到当前树的重心,然后dfs一遍求出所有点到重心的距离,很容易发现u到重心的距离+v到重心的距离就是u,v的距离。但是会发现这些得到的答案不完全正确,因为可能u,v可能在同一子树上:
A为重心,求B,C之间的距离,这样求得的答案就是a+b+2*c。所以还需要继续对每一个子树分治,重复上面的过程,对每次求出两点间的距离取最小值。需要求O(log(n))次重心,每次求重心O(n),复杂度为O(nlog(n)。
#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false)
const int N = 40005, INF = 0x7f7f7f7f;
int n, m,head[N], num, tot;//tot记录当前子树点数
int dis[N], flag[N];
int visited[N],id;
//dis记录子树每一点到根节点的距离,flag用于删除根节点
int size[N], Max[N], root;
int ans[N];
struct node{
int first,second;
};
node query[N];
struct edge
{
int next, to,len;
} G[N * 2];
void add(int from, int to,int len)
{
G[++num].next = head[from];
G[num].to = to;
G[num].len=len;
head[from] = num;
}
void input(void)
{
scanf("%d",&m);
for (int i = 1; i<n; i++)
{
int x, y,z;
scanf("%d%d%d",&x,&y,&z);
add(x, y,z), add(y, x,z);
}
for(int i=0; i<m; i++)
{
ans[i]=INF;
int x,y;
scanf("%d%d",&x,&y);
query[i]=node{x,y};
}
}
void dp(int fa, int cur)//求树的重心
{
size[cur] = 1, Max[cur] = 0;
for (int i = head[cur]; i; i = G[i].next)
{
int v = G[i].to;
if (flag[v] || v == fa) continue;
dp(cur, v);
size[cur] += size[v];
if(Max[cur]<size[v])
Max[cur] = size[v];
}
if(Max[cur]<tot-size[cur])
Max[cur] = tot - size[cur];
if (Max[root] > Max[cur]) root = cur;
}
void dfs(int fa, int cur)
{
visited[id++]=cur;
for (int i = head[cur]; i; i = G[i].next)
{
int v = G[i].to;
if (v == fa || flag[v]) continue;
dis[v]=dis[cur]+G[i].len;
dfs(cur, v);
}
}
void calc(int x, int len)
{
id=0;
dis[x] = len;
dfs(0, x);
/*cout<<"root="<<root<<",";
for(int i=1;i<=n;i++)
cout<<dis[i]<<" ";
cout<<"\n";*/
for(int i=0; i<m; i++)
{
int x=query[i].first,y=query[i].second;
if(dis[x]!=-1&&dis[y]!=-1&&dis[x]+dis[y]<ans[i])
ans[i]=dis[x]+dis[y];
}
for(int i=0; i<id; i++)
{
int id=visited[i];
dis[id]=-1;
}
}
void divide(int x)
{
flag[x] = true;//删去根节点
calc(x, 0);
//cout<<"ans="<<ans<<"\n";
for (int i = head[x]; i; i = G[i].next)
{
int y = G[i].to;
if (flag[y]) continue;
//ans -= calc(y, G[i].len);//点对在同一子树的情况
tot = size[y], root = 0;
dp(0, y);
divide(root);
}
}
void reset(void)
{
num = 0;
memset(head,0,sizeof(head));
memset(flag,0,sizeof(flag));
memset(dis,-1,sizeof(dis));
tot = n;
root = 0, Max[0] = INF;
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
reset();
input();
dp(0, 1);
divide(root);
for(int i=0; i<m; i++)
printf("%d\n",ans[i]);
}
return 0;
}