洛谷 P5666 树的重心

考虑统计每个节点作为重心的次数（即在所有删边方案中，它成为自己所在连通块重心的方案数），最终答案就是各节点编号乘以对应次数之和。
先找到整棵树的一个重心，把它当作根。那么对于树上一个非根结点 \(i\)，只有当删去的边不在它的子树内时，它才有可能成为重心：根是重心，所以 \(i\) 的子树大小不超过 \(n/2\)；若删去的边在 \(i\) 的子树内，\(i\) 上方部分的大小至少为 \(n/2\)，超过了 \(i\) 所在连通块大小的一半。
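设以 \(i\) 为根的子树大小为 \(s_i\)，其重儿子的子树大小为 \(g_i\)；删掉一条边 \((u, v)\)（\(v\) 为儿子）后，不含 \(i\) 的那一部分大小为 \(S\)，那么 \(i\) 能够成为重心，当且仅当 \(n - 2s_i \le S \le n - 2g_i\)（由 \(2(n - S - s_i) \le n - S\) 与 \(2g_i \le n - S\) 得到）。若 \((u, v)\) 不在根到 \(i\) 的路径上，则 \(S = s_v\)；若在路径上，则 \(S = n - s_v\)。DFS 时在树状数组中维护这些 \(S\) 值即可做区间计数；为了排除 \((u, v)\) 在 \(i\) 子树内的情况，需要再用线段树合并统计子树内的 \(s_v\) 并减去。根结点需要单独处理：设根的最大、次大儿子子树大小分别为 \(a \ge b\)，若删去的边在最大儿子的子树内，根仍是重心当且仅当 \(S \le n - 2b\)，否则当且仅当 \(S \le n - 2a\)。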

#include <cstdio>
#include <cstring>

// Read the next non-negative decimal integer from `in` (stdin by default),
// skipping any non-digit characters before it. Returns 0 if the end of input
// is reached before any digit is found. No overflow checking is done.
inline int read(std::FILE *in = stdin){
	int res = 0;
	// Kept as int (not char) so that EOF stays distinguishable from data;
	// otherwise the skip loop below would spin forever at end of input.
	int ch = std::getc(in);
	while(ch != EOF && (ch < '0' || ch > '9'))
		ch = std::getc(in);
	while(ch >= '0' && ch <= '9')
		res = res * 10 + (ch - '0'), ch = std::getc(in);
	return res;
}

typedef long long int ll;  // wide accumulator type (at least 64 bits)

// Capacity of every static per-vertex array (300019 entries); assumed to
// exceed the largest vertex count in the input.
const int MAXN = 300000 + 19;

// Fenwick (binary indexed) tree of int counts over positions 1..size.
// `size` must be set (and be < MAXN) before inserting or querying.
class Tarr{
	private:
		int tr[MAXN];
		
		// Sum of positions 1..x; x <= 0 yields 0.
		int q(int x) const{
			int res = 0;
			for(; x > 0; x -= x & -x)
				res += tr[x];
			return res;
		}
		
	public:
		int size;  // number of valid positions
		
		// Zero every position of the backing array.
		void clear(void){
			std::memset(tr, 0, sizeof tr);
		}
		
		// Add k at position x. For x <= 0 the update step x & -x would be 0
		// and the loop would never terminate, so such positions are ignored.
		void insert(int x, int k){
			if(x <= 0)
				return;
			for(; x <= size; x += x & -x)
				tr[x] += k;
		}
		
		// Sum of positions [l, r], clamped to [1, size]; an empty range gives 0
		// (instead of a meaningless negative difference).
		int query(int l, int r){
			if(r > size)
				r = size;
			if(l < 1)
				l = 1;
			if(l > r)
				return 0;
			return q(r) - q(l - 1);
		}
};
	
// Dynamically allocated segment trees over the value range [L, R], one per
// index p (rooted at root[p]), supporting point add, range sum and a
// destructive merge. L and R must be set before use. Node 0 is the shared
// empty sentinel and is never written by insert/merge.
class MergeableSegment{
	private:
		int root[MAXN], ind;  // ind: number of pool nodes allocated so far
		
		struct Node{
			int ls, rs;  // child node ids (0 = empty)
			int val;     // sum of the values stored in this node's range
		}tr[MAXN << 5];
		
		void push_up(int node){
			tr[node].val = tr[tr[node].ls].val + tr[tr[node].rs].val;
		}
		
		// Add val at position x in the tree rooted at `node` (covering [l, r]),
		// allocating the node from the pool if it does not exist yet.
		void insert(int &node, int l, int r, int x, int val){
			if(!node)
				node = ++ind;
			if(l == r){
				tr[node].val += val;
				return;
			}
			int mid = (l + r) >> 1;
			if(x <= mid)
				insert(tr[node].ls, l, mid, x, val);
			else
				insert(tr[node].rs, mid + 1, r, x, val);
			push_up(node);
		}
		
		// Merge tree b into tree a over [l, r] and return the merged root.
		// Nodes of a are updated in place and subtrees of b may be linked in,
		// so tree b must not be used independently afterwards.
		int merge(int a, int b, int l, int r){
			if(!a || !b)
				return a + b;
			if(l == r){
				tr[a].val += tr[b].val;
				return a;
			}
			int mid = (l + r) >> 1;
			tr[a].ls = merge(tr[a].ls, tr[b].ls, l, mid);
			tr[a].rs = merge(tr[a].rs, tr[b].rs, mid + 1, r);
			push_up(a);
			return a;
		}
		
		// Sum of the values at positions [ql, qr] in the tree rooted at `node`.
		int query(int node, int l, int r, int ql, int qr){
			if(!node)
				return 0;
			if(ql <= l && r <= qr)
				return tr[node].val;
			int mid = (l + r) >> 1;
			int res = 0;
			if(ql <= mid)
				res += query(tr[node].ls, l, mid, ql, qr);
			if(qr > mid)
				res += query(tr[node].rs, mid + 1, r, ql, qr);
			return res;
		}
		
	public:
		int L, R;  // value range covered by every tree
		
		// Reset every tree. Only nodes 0..ind can have been written since the
		// previous reset, so only that prefix of the pool is zeroed instead of
		// all MAXN << 5 nodes (~115 MB) on each call. Relies on `ind` starting
		// at 0, which holds for a static-storage (zero-initialized) instance.
		void clear(void){
			std::memset(root, 0, sizeof root);
			std::memset(tr, 0, sizeof(Node) * (ind + 1));
			ind = 0;
		}
		
		// Add val at position x in tree p.
		void insert(int p, int x, int val){
			insert(root[p], L, R, x, val);
		}
		
		// Merge tree b into tree a; tree b must not be used afterwards.
		void merge(int a, int b){
			root[a] = merge(root[a], root[b], L, R);
		}
		
		// Sum of the values at positions [l, r] in tree p.
		int query(int p, int l, int r){
			return query(root[p], L, R, l, r);
		}
};

namespace centroid{	
	struct Edge{
		int to, next;
	}edge[MAXN << 1];
	
	int head[MAXN], cnt;
	
	inline void add_edge(int from, int to){
		edge[++cnt].to = to;
		edge[cnt].next = head[from];
		head[from] = cnt;
	}
	
	int root, size[MAXN], gsize[MAXN];
	
	int n;
	ll ans;
	
	Tarr mt1;
	MergeableSegment mt2;
	
	void dfs1(int node, int f){
		size[node] = 1;
		bool flag = true;
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				dfs1(edge[i].to, node);
				size[node] += size[edge[i].to];
				if(flag && size[edge[i].to] > n / 2)
					flag = false;
			}
		if(flag && n - size[node] <= n / 2)
			root = node;
	}
	
	void dfs2(int node, int f){
		size[node] = 1, gsize[node] = 0;
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				dfs2(edge[i].to, node);
				size[node] += size[edge[i].to];
				if(size[edge[i].to] > gsize[node])
					gsize[node] = size[edge[i].to];
			}
	}
	
	void dfs3(int node, int f){
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f){
				mt1.insert(size[edge[i].to], -1);
				mt1.insert(n - size[edge[i].to], 1);
				dfs3(edge[i].to, node);
				mt1.insert(size[edge[i].to], 1);
				mt1.insert(n - size[edge[i].to], -1);
				mt2.merge(node, edge[i].to);
			}
		if(node != root){
			int cnt = mt1.query(n - 2 * size[node], n - 2 * gsize[node])
			 + mt2.query(node, n - 2 * size[node], n - 2 * gsize[node]);
			ans += (ll)node * cnt;
			mt2.insert(node, size[node], -1);
		}
	}
	
	void dfs4(int node, int f){
		mt1.insert(size[node], -1);
		for(int i = head[node]; i; i = edge[i].next)
			if(edge[i].to != f)
				dfs4(edge[i].to, node);
	}
	
	int first, second;
	
	int main(){
		n = read();
		std::memset(head, 0, sizeof head), cnt = 0;
		for(int i = 2; i <= n; ++i){
			int u = read(), v = read();
			add_edge(u, v), add_edge(v, u);
		}
		dfs1(1, 0);
		dfs2(root, 0);
		ans = 0ll;
		mt1.clear();
		mt1.size = n;
		mt2.clear();
		mt2.L = 1, mt2.R = n;
		for(int i = 1; i <= n; ++i)
			if(i != root)
				mt1.insert(size[i], 1);
		dfs3(root, 0);
		first = 0, second = 0;
		for(int i = head[root]; i; i = edge[i].next){
			if(size[edge[i].to] > size[first])
				second = first, first = edge[i].to;
			else if(size[edge[i].to] > size[second])
				second = edge[i].to;
		}
		dfs4(first, root);
		ans += (ll)root * mt1.query(1, n - 2 * size[first]);
		mt1.clear();
		dfs4(first, root);
		ans -= (ll)root * mt1.query(1, n - 2 * size[second]);
		std::printf("%lld\n", ans);
		return 0;
	}
}

// Program entry: the first number in the input is the number of test cases;
// each case is solved and printed by centroid::main().
int main(){
	int cases = read();
	while(cases-- > 0)
		centroid::main();
	return 0;
}

转载自 https://www.cnblogs.com/feiko/p/13171755.html