// 51nod Algorithm Marathon 31C — "Colored Tree" (「51nod 算法马拉松31C」彩色树)

#include <cstdio>
#include <cstring>
// `register` is only an optimisation hint; it is deprecated in C++11 and
// ill-formed since C++17, so R expands to nothing. It is kept so existing
// declarations written as `R int x` still compile.
#define R
// Larger of two values. Every argument is parenthesised so operands such as
// `b | c` bind correctly. Note: each argument may still be evaluated twice.
#define Max(a, b) ((a) > (b) ? (a) : (b))
/*
 * Read the next non-negative decimal integer from stdin.
 * Characters that are not digits are skipped (so a '-' sign is ignored), then
 * consecutive digits are accumulated; the first non-digit after the number is
 * consumed. Returns 0 if end of input is reached before any digit, instead of
 * spinning forever on EOF.
 */
int F()
{
    int ch = getchar(); /* int, not char, so EOF stays distinguishable */
    while (ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();
    if (ch == EOF)
        return 0;
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}
const int Mod = 1000000007;  // answer modulus (1e9 + 7)
const int MaxN = 100010;     // capacity of every per-node array (1e5 + 10)

int n;          // number of nodes
int a[MaxN];    // a[u]: colour of node u
int Ans;        // accumulated answer, reduced modulo Mod

// Tree adjacency list: Point[u] heads u's list; Next/To are the links.
// Each undirected edge occupies two slots; q counts the slots used.
int Point[MaxN], Next[MaxN << 1], To[MaxN << 1], q;

// Filled by DFS: subtree size, parent, and DFS-number interval [l[u], r[u]].
// Index is the last DFS number handed out.
int Sum[MaxN], fa[MaxN], l[MaxN], r[MaxN], Index;

// Per-colour node lists built by Insert: point[c] heads the list for colour c,
// next links entries, to[] holds the node; cnt counts entries used.
int point[MaxN], next[MaxN], to[MaxN], cnt;
// Push node v onto the front of linked list u (point/next/to).
// DFS calls this as Insert(colour, node).
void Insert(R int u, R int v)
{
    ++cnt;
    to[cnt] = v;
    next[cnt] = point[u];
    point[u] = cnt;
}
// Record the undirected edge u-v as two directed entries, u->v first, each
// pushed onto the front of its source node's adjacency list.
void Add(R int u, R int v)
{
    const int ends[2][2] = {{u, v}, {v, u}};
    for (int s = 0; s < 2; ++s) {
        const int from = ends[s][0];
        const int dest = ends[s][1];
        ++q;
        To[q] = dest;
        Next[q] = Point[from];
        Point[from] = q;
    }
}
void DFS(R int u, R int From)
{
    l[u] = r[u] = ++Index;
    Sum[u] = 1; fa[u] = From;
    Insert(a[u], u);
    R int tmp = cnt;
    for(R int j = Point[u]; j; j = Next[j]) 
        if(To[j] != From)
        {
            DFS(To[j], u);
            r[u] = Max(r[u], r[To[j]]);
            Sum[u] += Sum[To[j]];
        }
}
// Work stack used while processing one colour: holds already-handled nodes of
// that colour whose subtrees have not yet been subtracted; tot is its height
// (entries live in S[1..tot]).
int S[MaxN];
int tot;
int main()
{
    n = F();
    for(R int i = 1; i <= n; i++) a[i] = F();
    for(R int i = 1; i < n; i++) Add(F(), F());
    DFS(1, 0);
    for(R int i = 1; i <= n; i++)
        if(point[i])
        {
            R int res = n, tmp = 1ll * n * (n - 1) % Mod; tot = 0;
            for(R int j = point[i]; j; j = next[j])
            {
                for(R int k = Point[to[j]]; k; k = Next[k])
                {
                    if(To[k] == fa[to[j]]) continue;
                    R int t = Sum[To[k]];
                    while(tot && r[S[tot]] <= r[To[k]] && l[S[tot]] >= l[To[k]]) t -= Sum[S[tot--]];
                    (tmp -= 1ll * t * (t - 1) % Mod) %= Mod;
                }
                S[++tot] = to[j];
            }
            R int t = n;
            while(tot) t -= Sum[S[tot--]];
            (tmp -= 1ll * t * (t - 1) % Mod) %= Mod;
            (Ans += tmp) %= Mod;
        }
    (Ans += Mod) %= Mod;
    for(R int i = 1; i < n; i++) Ans = 1ll * Ans * i % Mod;
    printf("%d", Ans);
    return 0;
}

// Reposted from (转载自) blog.csdn.net/steaunk/article/details/79068788