点分治 聪聪可可

传送门

这道题看起来是大规模解决树上路径的问题……那就是点分治啦。

既然我们要求的是树上长度为3的倍数的路径有多少条,那么我们不妨对每条路径的长度取模,这样的话我们实际上就获得了一堆长度为0,1,2的路径。因为点分治的性质,它每次只统计当前子树内经过重心的长度为3的倍数的路径,所以我们在每次统计之后 还要减去子树内的答案。

这样可以很容易的得出每棵子树内的答案是dis[0] * dis[0] + dis[1] * dis[2] * 2,后者很好理解,至于前者,也不难想,就是每两条长度为0或者是3的倍数的路径都可以被合并成长度为3的倍数的路径。这个不需要*2,因为计算的时候每两个点在一次路径计算中会被重复记算。

其余的就是点分治的常规操作。每次找到重心之后,首先求解本棵子树之内的所有结果。注意这里我们不需要去遍历每一棵子树,直接把她视为一个整体,把根结点(当前的重心)的父亲赋成0直接向下dfs。(然而在里面计算的时候还是要遍历子树的哈哈)然后每次这么更新答案就可以了。

看一下代码。

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('\n')
using namespace std;
typedef long long ll;
const int M = 100005;
const int N = 10000005;
 
int read()
{
    int ans = 0,op = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-') op = -1;
        ch = getchar();
    }
    while(ch >='0' && ch <= '9')
    {
        ans *= 10;
        ans += ch - '0';
        ch = getchar();
    }
    return ans * op;
}

struct node
{
    int to,next,v;
}e[M];

int head[M],dis[M],ecnt,n,m,x,y,w,ans,road[4],sum,root,size[M],maxs[M];
bool vis[M];

int gcd(int x,int y)
{
    return !y ? x : gcd(y,x%y);
}

void add(int x,int y,int z)
{
    e[++ecnt].to = y;
    e[ecnt].v = z;
    e[ecnt].next = head[x];
    head[x] = ecnt;
}

void getroot(int x,int fa)
{
    size[x] = 1,maxs[x] = 0;
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(t == fa || vis[t]) continue;
        getroot(t,x);
        size[x] += size[t];
        maxs[x] = max(maxs[x],size[t]);
    }
    maxs[x] = max(maxs[x],sum - size[x]);
    if(maxs[x] < maxs[root]) root = x;
}

void getdis(int x,int fa)
{
    road[dis[x]]++;
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(t == fa || vis[t]) continue;
        dis[t] = (dis[x] + e[i].v) % 3;
        getdis(t,x);
    }
}
int calc(int x,int leng)
{
    int cur = 0;rep(i,0,3) road[i] = 0;//这里相当于是开桶记录
    dis[x] = leng,getdis(x,0);
    cur += (road[1] * road[2]) << 1;
    cur += road[0] * road[0];
    return cur;
}

void solve(int x)
{
    vis[x] = 1,ans += calc(x,0);
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(vis[t]) continue;
        ans -= calc(t,e[i].v);
        sum = size[t],maxs[root = 0] = n;
        getroot(t,0),solve(root);
    }
}

int main()
{
    n = read();
    rep(i,1,n-1) x = read(),y = read(),w = read() % 3,add(x,y,w),add(y,x,w);
    sum = maxs[root] = n,getroot(1,0);
    solve(root); 
    int d = gcd(ans,n*n);
    printf("%d/%d\n",ans/d,n*n/d);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/captain1/p/9649269.html