考虑统计每个节点作为重心的次数：答案等于每个节点的编号乘以它成为重心的次数，再对所有节点求和。
先找到树的重心，把它当作根。那么对于树上一个非根结点 \(i\)，只有当删去的边不在它的子树内时，它才有可能成为重心：因为根是原树的重心，\(i\) 的子树大小不超过 \(n/2\)；若删去的边在 \(i\) 的子树内，\(i\) 向上的连通块大小至少为 \(n/2\)，严格大于剩余树大小的一半。
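设 \(i\) 的子树大小为 \(s_i\)，其重儿子的子树大小为 \(g_i\)；删掉一条边 \((u, v)\) 后，被分走（不含 \(i\)）的那一部分大小为 \(S\)。剩余 \(n - S\) 个点中，\(i\) 上方的部分（大小 \(n - S - s_i\)）与重儿子子树（大小 \(g_i\)）都不能超过 \((n - S)/2\)，因此 \(i\) 能够成为重心，当且仅当 \(n - 2s_i \le S \le n - 2g_i\)。区间计数可以用树状数组。注意对于 \(i\) 到根路径上的边 \((u, v)\)（\(v\) 为较深的一端），被分走的是含根的那一部分，大小为 \(n - s_v\) 而不是 \(s_v\)，需要在 DFS 进入、退出子树时修改树状数组。为了排除 \((u, v)\) 在 \(i\) 子树内的情况，需要再用线段树合并统计子树内的 \(s_v\) 并减去。根结点则按其最大、次大儿子的子树大小单独讨论。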
#include <cstdio>
#include <cstring>
// Reads the next non-negative decimal integer from stdin.
// Any characters that are not digits are skipped, so a leading '-' is
// ignored. Returns 0 if end of input is reached before a digit is found.
inline int read(void){
	int res = 0;
	// int, not char: getchar() reports end of input as EOF, which lies outside
	// the unsigned-char range. Stored in a char it is lost, or, where char is
	// unsigned, becomes 255, so the skip loop would never stop at end of input.
	int ch = std::getchar();
	while(ch != EOF && (ch < '0' || ch > '9'))
		ch = std::getchar();
	while(ch >= '0' && ch <= '9')
		res = res * 10 + (ch - '0'), ch = std::getchar();
	return res;
}
// 64-bit integer type, used for the answer accumulator.
using ll = long long;
// Capacity of the per-node arrays (3e5 + 19); node indices run up to n.
constexpr int MAXN = 300000 + 19;
// Fenwick (binary indexed) tree over positions 1..size, supporting point add
// and range-sum query. The caller sets `size` (which must be < MAXN) before
// inserting.
class Tarr{
private:
	int tr[MAXN];
	// Sum of positions 1..x; 0 for x <= 0.
	// Arguments past `size` are clamped: insert() never propagates beyond
	// `size`, so walking down from a larger x would miss entries
	// (e.g. size 5, x 6 reads tr[6] + tr[4] and skips tr[5]).
	int q(int x){
		if(x > size)
			x = size;
		int res = 0;
		for(; x > 0; x -= x & -x)
			res += tr[x];
		return res;
	}
public:
	int size;
	// Zeroes every position.
	void clear(void){
		std::memset(tr, 0, sizeof tr);
	}
	// Adds k at position x. Positions < 1 are ignored: for x == 0 the step
	// x & -x is 0 and the update loop would never terminate.
	void insert(int x, int k){
		if(x < 1)
			return;
		for(; x <= size; x += x & -x)
			tr[x] += k;
	}
	// Sum over positions [l, r], clamped to [1, size]; 0 for an empty range.
	int query(int l, int r){
		if(l < 1)
			l = 1;
		if(r > size)
			r = size;
		if(l > r)
			return 0;
		return q(r) - q(l - 1);
	}
};
// Pool-allocated dynamic segment trees over the value range [L, R], one tree
// per index p (rooted at root[p]). Supports point add, range sum, and merge:
// tree b is absorbed into tree a, and its nodes are reused rather than copied.
class MergeableSegment{
private:
	int root[MAXN];
	int ind = 0; // number of pool nodes handed out since the last clear()
	struct Node{
		int ls, rs; // child indices into tr (0 = empty subtree)
		int val;    // sum of the values stored in this node's range
	}tr[MAXN << 5];
	// Recomputes a node's sum from its children (tr[0] is the all-zero sentinel).
	void push_up(int node){
		tr[node].val = tr[tr[node].ls].val + tr[tr[node].rs].val;
	}
	// Adds val at position x in the tree rooted at `node` (covering [l, r]),
	// allocating missing nodes from the pool on the way down.
	void insert(int &node, int l, int r, int x, const int &val){
		if(!node)
			node = ++ind;
		if(l == r){
			tr[node].val += val;
			return;
		}
		int mid = (l + r) >> 1;
		if(x <= mid)
			insert(tr[node].ls, l, mid, x, val);
		else
			insert(tr[node].rs, mid + 1, r, x, val);
		push_up(node);
	}
	// Merges tree b into tree a (both covering [l, r]) and returns the root of
	// the result. Where only one side exists, that subtree is reused as is.
	int merge(int a, int b, int l, int r){
		if(!a || !b)
			return a + b;
		if(l == r){
			tr[a].val += tr[b].val;
			return a;
		}
		int mid = (l + r) >> 1;
		tr[a].ls = merge(tr[a].ls, tr[b].ls, l, mid);
		tr[a].rs = merge(tr[a].rs, tr[b].rs, mid + 1, r);
		push_up(a);
		return a;
	}
	// Sum of the values at positions in [ql, qr] intersected with [l, r],
	// in the tree rooted at `node`.
	int query(int node, int l, int r, int ql, int qr){
		if(!node)
			return 0;
		if(ql <= l && r <= qr)
			return tr[node].val;
		int mid = (l + r) >> 1;
		int res = 0;
		if(ql <= mid)
			res += query(tr[node].ls, l, mid, ql, qr);
		if(qr > mid)
			res += query(tr[node].rs, mid + 1, r, ql, qr);
		return res;
	}
public:
	int L, R;
	// Empties every tree.
	// Nodes are handed out bump-style (++ind) from a pool that starts zeroed,
	// so only tr[0..ind] can have been touched since the previous clear().
	// Only that prefix is zeroed, not the whole pool (~115 MB).
	void clear(void){
		std::memset(root, 0, sizeof root);
		std::memset(tr, 0, sizeof(Node) * (ind + 1));
		ind = 0;
	}
	// Adds val at position x in tree p.
	void insert(int p, int x, const int &val){
		insert(root[p], L, R, x, val);
	}
	// Merges tree b into tree a; tree b must not be used afterwards.
	void merge(int a, int b){
		root[a] = merge(root[a], root[b], L, R);
	}
	// Sum of the values at positions in [l, r] in tree p.
	int query(int p, int l, int r){
		return query(root[p], L, R, l, r);
	}
};
// Solver for one test case: for every edge of the tree, the indices of the
// centroid(s) of both components left after deleting that edge are summed,
// and the total is printed.
//
// The tree is rooted at one of its centroids. A non-root node i (subtree size
// s_i, largest child subtree g_i) is a centroid of its component iff the
// deleted edge is not inside i's subtree and the split-off part (the side not
// containing i) has size S with n - 2*s_i <= S <= n - 2*g_i. The root is
// handled separately using its two largest children.
namespace centroid{
	struct Edge{
		int to, next;
	}edge[MAXN << 1];
	int head[MAXN], cnt;
	// Appends the directed edge from -> to to the adjacency list of `from`.
	inline void add_edge(int from, int to){
		edge[++cnt].to = to;
		edge[cnt].next = head[from];
		head[from] = cnt;
	}
	int root, size[MAXN], gsize[MAXN];
	int n;
	ll ans;
	// mt1: Fenwick tree keyed by "size of the part split off when deleting
	//      edge (parent(v), v)", with one entry per non-root node v.
	// mt2: per-node mergeable segment trees holding -1 at size[v] for each
	//      strict descendant v, used to discount edges inside a subtree.
	Tarr mt1;
	MergeableSegment mt2;
	// Computes subtree sizes with the tree rooted at `node` (parent f). Stores
	// in `root` a node whose largest remaining component after its removal has
	// at most n / 2 nodes, i.e. a centroid of the whole tree.
	void dfs1(int node, int f){
		size[node] = 1;
		bool flag = true;
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				dfs1(edge[i].to, node);
				size[node] += size[edge[i].to];
				if(flag && size[edge[i].to] > n / 2)
					flag = false;
			}
		if(flag && n - size[node] <= n / 2)
			root = node;
	}
	// Recomputes size[] for the tree rooted at the centroid. Sets gsize[node]
	// to the size of node's largest child subtree (0 for leaves).
	void dfs2(int node, int f){
		size[node] = 1, gsize[node] = 0;
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				dfs2(edge[i].to, node);
				size[node] += size[edge[i].to];
				if(size[edge[i].to] > gsize[node])
					gsize[node] = size[edge[i].to];
			}
	}
	// For every non-root node, counts the edges whose deletion makes it a
	// centroid, and adds node * count to ans.
	void dfs3(int node, int f){
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				// While inside the child's subtree, deleting edge (node, child)
				// splits off the root side, of size n - size[child] rather
				// than size[child].
				mt1.insert(size[edge[i].to], -1);
				mt1.insert(n - size[edge[i].to], 1);
				dfs3(edge[i].to, node);
				mt1.insert(size[edge[i].to], 1);
				mt1.insert(n - size[edge[i].to], -1);
				mt2.merge(node, edge[i].to);
			}
		if(node != root){
			// mt1 counts every edge with a matching split-off size; mt2 holds
			// -1 for each strict descendant, cancelling edges in the subtree.
			int hits = mt1.query(n - 2 * size[node], n - 2 * gsize[node])
				+ mt2.query(node, n - 2 * size[node], n - 2 * gsize[node]);
			ans += (ll)node * hits;
			// Mark this node's own edge as "inside the subtree" for every
			// ancestor (the tree is merged upward by the parent).
			mt2.insert(node, size[node], -1);
		}
	}
	// Adds -1 to mt1 at size[v] for every node v in the subtree of `node`,
	// whose parent is f.
	void dfs4(int node, int f){
		mt1.insert(size[node], -1);
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f)
				dfs4(edge[i].to, node);
	}
	int first, second;
	// Reads one test case (n, then n - 1 edges as "u v" pairs) from stdin and
	// prints the answer.
	int main(){
		n = read();
		std::memset(head, 0, sizeof head), cnt = 0;
		for(int i = 2; i <= n; ++i){
			int u = read(), v = read();
			add_edge(u, v), add_edge(v, u);
		}
		dfs1(1, 0);
		dfs2(root, 0);
		ans = 0ll;
		mt1.clear();
		mt1.size = n;
		mt2.clear();
		mt2.L = 1, mt2.R = n;
		for(int i = 1; i <= n; ++i)
			if(i != root)
				mt1.insert(size[i], 1);
		dfs3(root, 0);
		// first / second: the root's children with the largest and second
		// largest subtrees. Each stays 0 when absent; size[0] is never
		// written and stays 0.
		first = 0, second = 0;
		for(int i = head[root]; i; i = edge[i].next){
			if(size[edge[i].to] > size[first])
				second = first, first = edge[i].to;
			else if(size[edge[i].to] > size[second])
				second = edge[i].to;
		}
		// Root, edge outside first's subtree: the root stays a centroid iff
		// S <= n - 2 * size[first]. mt1 holds every non-root node here, so
		// first's subtree is removed before counting.
		// The `if(first)` guard covers n == 1 (no children): dfs4(0, ...)
		// would insert at Fenwick position size[0] == 0.
		if(first)
			dfs4(first, root);
		ans += (ll)root * mt1.query(1, n - 2 * size[first]);
		// Root, edge inside first's subtree: the root stays a centroid iff
		// S <= n - 2 * size[second]. After clearing, mt1 holds -1 per node of
		// first's subtree, hence the subtraction.
		mt1.clear();
		if(first)
			dfs4(first, root);
		ans -= (ll)root * mt1.query(1, n - 2 * size[second]);
		std::printf("%lld\n", ans);
		return 0;
	}
}
// Program entry point. The first integer is the number of test cases;
// centroid::main() solves each one and prints its answer.
int main(){
	const int tests = read();
	for(int done = 0; done < tests; ++done)
		centroid::main();
	return 0;
}