题目链接
题意:给你一个n个点m条边的无向图,求所有的能从s到c再到t的三元组个数,其中每个点在一条路径上至多经过一次。n,m1e5量级。
题解:
首先介绍一下圆方树。
还记得zyb大佬凭借圆方树在APIO拿AU并在SD二轮进队,近年来圆方树也成为了一个热门算法,于是还是很有必要学的。
圆方树的作用是把一个图变成一个树,并且能正确地记录一些信息。
把一个图变成一棵树的方法是我们首先利用tarjan,找到所有的点双连通分量。点双连通分量是值去掉图上任何一个点都不会改变连通性的连通分量。我们对于每一个点双连通分量,我们建一个新的点,表示整个点双,这种点成为方点,并且可能会在方点记录整个点双的信息,原图上的点成为圆点。特殊地,把两个点互相连通的也视作一个点双。对于这些点,我们的连边方式是把每一个圆点与他所在的每一个点双代表的方点连边。当然圆方树上的方点所维护的信息可能不包含在圆方树上处于这个方点父节点的位置的圆点的信息。
用网上别人的一张图,可能比较有助于理解。
这样对一个图建出圆方树之后我们就可以解决一些图上简单路径的问题了。似乎还经常用来解决一些仙人掌上的问题。圆方树题目经常需要在LCA处分圆点和方点分类讨论。
圆方树的一个性质是所有方点相连的点一定都是圆点,所有圆点相连的点一定都是方点。
另一个性质是无论以哪个点为根,圆方树的形态不变。
下面开始写这个题的做法。
我们会发现,如果在一个点双内选择两个点,那么任意地在点双里再选一个点作为中间点,一定能形成一个合法的三元组。如果两个不在同一个点双里的点的路径上经过了某个点双,那么把这个点双上的任意的一个点作为中间点形成的三元组都是合法的。那么我们考虑每一个点作为这个中间点对答案的贡献。
我们发现答案相当于对于每一个
,求有多少
是合法的。
我们对原图建出圆方树,树上圆点的点权是
,方点的点权是度数,也就是点双的大小,这里算是用到一点容斥的思想。这样之后
和
之间的点数是就变成了圆方树上两点之间的权值和。我们设
为
的子树内无序圆点对的路径权值和之和,
为
到
的子树内所有圆点的距离和之和,
为以
为根的子树内圆点的个数,
表示
的权值。我们需要有顺序的转移来保证我们每个量转移前后的正确性。我们有以下转移式:
由于图可能不连通,并且我们求的是无序点对(x,y)的答案,所以最终的结果是
然后就做完了。复杂度是 。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,hed[400010],cnt,ct,hed2[500010],cnt2,z,dfn[400010];
int low[400010],sta[400010],tp;
long long ans,f[400010],sum[400010],g[400010],sz[400010];
struct node
{
int to,next;
}a[400010],aa[400010];
inline void add(int from,int to)
{
a[++cnt].to=to;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void add2(int from,int to)
{
aa[++cnt2].to=to;
aa[cnt2].next=hed2[from];
hed2[from]=cnt2;
}
inline void tarjan(int x)
{
low[x]=dfn[x]=++z;
sta[++tp]=x;
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(!dfn[y])
{
tarjan(y);
low[x]=min(low[x],low[y]);
if(dfn[x]<=low[y])
{
++ct;
do
{
add2(ct,sta[tp]);
add2(sta[tp],ct);
--tp;
sz[ct]++;
}while(sta[tp+1]!=y);
add2(ct,x);
add2(x,ct);
++sz[ct];
}
}
else
low[x]=min(low[x],dfn[y]);
}
}
inline void dfs(int x,int fa)
{
for(int i=hed2[x];i;i=aa[i].next)
{
int y=aa[i].to;
if(y!=fa)
dfs(y,x);
}
if(x<=n)
{
sum[x]=1;
g[x]=sz[x];
}
for(int i=hed2[x];i;i=aa[i].next)
{
int y=aa[i].to;
if(y!=fa)
{
f[x]+=f[y]+g[x]*sum[y]+g[y]*sum[x];
g[x]+=g[y]+sum[y]*sz[x];
sum[x]+=sum[y];
}
}
}
int main()
{
scanf("%d%d",&n,&m);
ct=n;
for(int i=1;i<=m;++i)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
for(int i=1;i<=n;++i)
sz[i]=-1;
for(int i=1;i<=n;++i)
{
if(!dfn[i])
{
tarjan(i);
dfs(i,0);
ans+=f[i]*2;
}
}
printf("%lld\n",ans);
return 0;
}