[bzoj4543] [POI2014]Hotel加强版

Description

有一个树形结构的宾馆,n个房间,n-1条无向边,每条边的长度相同,任意两个房间可以相互到达。吉丽要给他的三个妹子各开(一个)房(间)。三个妹子住的房间要互不相同(否则要打起来了),为了让吉丽满意,你需要让三个房间两两距离相同。
有多少种方案能让吉丽满意?

Input

第一行一个数n。
接下来n-1行,每行两个数x,y,表示x和y之间有一条边相连。

Output

让吉丽满意的方案数。

Sample Input

7
1 2
5 7
2 5
2 3
5 6
4 5

Sample Output

5

Solution

先考虑暴力怎么做。

\(f[x][d]\)表示\(x\)的子树里距离\(x\)\(d\)的点的个数,\(g[x][a]\)表示\(x\)子树内距离\(lca\)\(d\)\(lca\)距离\(x\)\(d-a\)的点对个数。

那么,转移就是:
\[ f[x][d]+=f[v][d-1],\\ g[x][d]+=g[v][d+1]+f[x][d]\cdot f[x][d-1] \]
更新答案就是:
\[ ans+=f[x][a-1]\cdot g[v][a]+g[v][a]\cdot f[x][a+1] \]
转移顺序注意下,注意这里的\(f[x]\)\(g[x]\)都是不包括当前子树的。

那个关于\(g\)的转移的后半部分是以\(x\)\(lca\)的点数。

然后长链剖分一波,就\(O(n)\)了。

#include<bits/stdc++.h>
using namespace std;

#define int long long 

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 1e6+10;

int n,head[maxn],tot,mxdep[maxn],hs[maxn],ans;
int space[maxn<<2],*f[maxn],*g[maxn],*t=space;
struct edge{int to,nxt;}e[maxn<<1];

void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}

void dfs(int x,int fa) {
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=fa) {
            dfs(e[i].to,x);
            mxdep[x]=max(mxdep[x],mxdep[e[i].to]);
            if(mxdep[e[i].to]>=mxdep[hs[x]]) hs[x]=e[i].to;
        }
    mxdep[x]++;
}

void solve(int x,int fa) {
    if(hs[x]) f[hs[x]]=f[x]+1,g[hs[x]]=g[x]-1,solve(hs[x],x);
    f[x][0]=1,ans+=g[x][0];
    for(int i=head[x];i;i=e[i].nxt) {
        int v=e[i].to;if(v==hs[x]||v==fa) continue;
        f[v]=t,t+=mxdep[v]*2+3,g[v]=t,t+=mxdep[v]*2+3;
        solve(e[i].to,x);
        for(int j=0;j<mxdep[v];j++) {
            if(j) ans+=f[x][j-1]*g[v][j];
            ans+=f[v][j]*g[x][j+1];
        }
        for(int j=0;j<mxdep[v];j++) g[x][j+1]+=f[x][j+1]*f[v][j];
        for(int j=0;j<mxdep[v];j++) {
            if(j) g[x][j-1]+=g[v][j];
            f[x][j+1]+=f[v][j];
        }
    }
}

signed main() {
    read(n);
    for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
    dfs(1,0);
    f[1]=t,t+=mxdep[1]*2+3,g[1]=t,t+=mxdep[1]*2+3;
    solve(1,0),write(ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hbyer/p/10319884.html