BZOJ3451Normal——点分治+FFT

题目描述
某天WJMZBMR学习了一个神奇的算法:树的点分治!
这个算法的核心是这样的:
消耗时间=0
Solve(树 a)
消耗时间 += a 的 大小
如果 a 中 只有 1 个点
退出
否则在a中选一个点x,在a中删除点x
那么a变成了几个小一点的树,对每个小树递归调用Solve
我们注意到的这个算法的时间复杂度跟选择的点x是密切相关的。
如果x是树的重心,那么时间复杂度就是O(nlogn)
但是由于WJMZBMR比较傻逼,他决定随机在a中选择一个点作为x!
Sevenkplus告诉他这样做的最坏复杂度是O(n^2)
但是WJMZBMR就是不信>_<。。。
于是Sevenkplus花了几分钟写了一个程序证明了这一点。。。你也试试看吧^_^
现在给你一颗树,你能告诉WJMZBMR他的傻逼算法需要的期望消耗时间吗?(消耗时间按在Solve里面的那个为标准)

输入格式
第一行一个整数n,表示树的大小
接下来n-1行每行两个数a,b,表示a和b之间有一条边
注意点是从0开始标号的

输出格式
一行一个浮点数表示答案
四舍五入到小数点后4位
如果害怕精度跪建议用long double或者extended

样例输入
3
0 1
1 2
样例输出
5.6667
提示
n<=30000


答案要求的是点分树上所有点的子树size和(所有的点都是随机的)的期望,根据期望的线性性质,答案等于每个点的期望被算的次数之和。我们考虑点x,他怎样才会对y产生1点贡献呢?只有先选了x,才能保证x对y做出一点贡献。所以做出贡献的概率为 1 d i s ( i , j ) + 1 (因为每个点选到的概率均等,而得先选一个点)。所以最终的答案为 i = 1 n j = 1 n 1 d i s ( i , j ) + 1
那么怎么求这个答案呢?
我们需要知道树上路径长度为x的路径有多少条,我们都知道x为定值时,直接一个简单的点分治即可。但是现在x是所有值,所以我们考虑依旧进行点分治,然后统计每个点到当前根的距离,那么 s u m ( x ) = i = 0 x n u m ( i ) n u m ( x i ) n u m ( i ) 表示到根距离为 i 的点的数量。这是一个十分显然的卷积的形式,我们直接卷积求出sum即可。
#include<bits/stdc++.h>
#define db double
#define MAXN 131072
#define MD 998244353
#define ll long long
using namespace std;
int read(){
    char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
    while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
const db pi=acos(-1.0);
struct comple{
    double x,y;
    comple (double xx=0,double yy=0){x=xx,y=yy;}
    comple operator+(const comple a){return comple(x+a.x,y+a.y);}
    comple operator-(const comple a){return comple(x-a.x,y-a.y);}
    comple operator*(const comple a){return comple(x*a.x-y*a.y,y*a.x+x*a.y);}
}a[MAXN],b[MAXN],w[2][MAXN];
int n,cnt,root,sum,limit=1,l,m,head[MAXN<<1],nxt[MAXN<<1],go[MAXN<<1],buck[MAXN<<1],f[MAXN],siz[MAXN],vis[MAXN],d[MAXN],dep[MAXN],r[MAXN<<1];
ll ans;
int pows(ll a,int b){
    ll base=1;
    while(b){
        if(b&1) base=base*a%MD;
        a=a*a%MD;b/=2;
    }
    return base;
}
void add(int x,int y){
    go[cnt]=y;nxt[cnt]=head[x];head[x]=cnt;cnt++;
    go[cnt]=x;nxt[cnt]=head[y];head[y]=cnt;cnt++;
}
void pre(){
    comple Wn(cos(2*pi/limit),sin(2*pi/limit));
    w[0][0]=w[1][0]=comple(1,0);
    for(int i=1;i<limit;i++) w[1][i]=w[1][i-1]*Wn;
    for(int i=1;i<limit;i++) w[0][i]=w[1][limit-i];
}
void FFT(comple *A,int type){
    for(int i=0;i<limit;i++) if(i<r[i]) swap(A[i],A[r[i]]);
    for(int mid=1;mid<limit;mid<<=1){
        for(int R=mid<<1,j=0;j<limit;j+=R){
            for(int k=0;k<mid;k++){
                comple x=A[j+k],y=w[type==1][limit/(mid<<1)*k]*A[j+k+mid];
                A[j+k]=x+y;A[j+k+mid]=x-y;
            }
        }
    }
}
void getroot(int x,int fa){
    f[x]=0;siz[x]=1;
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(to==fa||vis[to]) continue;
        getroot(to,x);f[x]=max(f[x],siz[to]);
        siz[x]+=siz[to]; 
    }
    f[x]=max(f[x],sum-siz[x]);
    if(f[x]<f[root]) root=x;
}
void getdep(int x,int fa){
    dep[++dep[0]]=d[x];a[d[x]].x++;m=max(m,d[x]);
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(vis[to]||to==fa) continue;
        d[to]=d[x]+1;getdep(to,x);
    }
}
void calc(int x,int w,int type){
    d[x]=w;dep[0]=0;m=0;limit=1;l=0;
    getdep(x,0);m<<=1;
    while(limit<=m) limit<<=1,l++;
    for(int i=1;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    pre();FFT(a,1);
    for(int i=0;i<limit;i++) a[i]=a[i]*a[i];
    FFT(a,-1);
    for(int i=0;i<=m;i++) buck[i]+=type*(int)(a[i].x/limit+0.5);
    for(int i=0;i<limit;i++) a[i].x=a[i].y=0;
}
void solve(int x){
    vis[x]=1;calc(x,0,1);
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(vis[to]) continue;
        calc(to,1,-1);
        sum=siz[to];root=0;
        getroot(to,0);
        solve(root);
    }
}
int main()
{
    n=read();memset(head,-1,sizeof(head));f[0]=2e9;sum=n;
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);
    }
    getroot(1,0);
    solve(root);
    for(int i=0;i<n;i++) ans=(ans+1ll*pows(i+1,MD-2)*buck[i])%MD;
    printf("%d\n",ans);
    return 0;
}
//这里我没有写小数,我用对998244353求逆元的形式。

猜你喜欢

转载自blog.csdn.net/stevensonson/article/details/81456534