4.6考试T2

题目描述

桑尼、露娜和斯塔在玩点对游戏,这个游戏在一棵节点数为\(n\)的树上进行。桑尼、露娜和斯塔三人轮流从树上所有未被占有的节点中选取一点,归为己有,轮流顺序为桑尼、露娜、斯塔、桑尼、露娜……。该选取过程直到树上所有点都被选取后结束。

选完点后便可计算每人的得分。点对游戏中有\(m\)个幸运数,在某人占据的节点中,每有一对点的距离为某个幸运数,就得到一分。(树上两点之间的距离定义为两点之间的简单路径的边数)。

你的任务是,假设桑尼、露娜和斯塔每次选取时,都是从未被占有的节点中等概率选取一点,计算每人的期望得分。

题解

首先可以知道每个人可以选择的点数是固定的,可以很容易求得。

那么考虑,假设当前的人可以选\(k\)个点,那么他有\(C_n^k\)种选择方式。

假设一对点的距离满足题目要求,那么他会做出\(C_{n-2}^{k-2}\)次贡献,所以最后的期望就是\(cnt*\frac{C_{n-2}^{k-2}}{C_n^k}\),其中\(cnt\)表示满足题目要求的点对数。

那么有多少符合条件的点对数呢,我么可以通过点分治求出。

#include <iostream>
#include <cstdio>
using namespace std;
const int N = 5e4 + 5;
int n, m, is[N], tot, head[N], mx, cnt, maxx[N], siz[N], rt, vis[N], sum, c1[N], c2[N], mxdep[N], val[15];
struct node{int to, nex;}a[N << 1];
inline int read()
{
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
	return x * f;
}
void add(int x, int y) {a[++ tot].to = y; a[tot].nex = head[x]; head[x] = tot;}
void dfs(int x, int fa, int dis, int now)
{
	if(is[dis] && x >= now) cnt ++; if(dis >= mx) return;
	for(int i = head[x]; i; i = a[i].nex)
	{
		int y = a[i].to;
		if(y == fa) continue;
		dfs(y, x, dis + 1, now);
	}
}
void work()
{
	for(int i = 1; i <= n; i ++) dfs(i, 0, 0, i);
	double k; k = (n / 3) + (n % 3 != 0);
	printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
	k = (n / 3) + (n % 3 == 2); printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
	k = (n / 3); printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
}
void get_root(int x, int fa)
{
	siz[x] = 1; maxx[x] = 0;
	for(int i = head[x]; i; i = a[i].nex)
	{
		int y = a[i].to;
		if(y == fa || vis[y]) continue;
		get_root(y, x); siz[x] += siz[y];
		maxx[x] = max(maxx[x], siz[y]);
	}
	maxx[x] = max(maxx[x], sum - siz[x]);
	if(maxx[x] < maxx[rt]) rt = x;
}
void get_ans(int x, int fa, int dis)
{
	c2[dis] ++; mxdep[x] = dis;
	for(int i = 1; i <= m; i ++) if(dis <= val[i]) cnt += c1[val[i] - dis];
	for(int i = head[x]; i; i = a[i].nex)
	{
		int y = a[i].to;
		if(y == fa || vis[y]) continue;
		get_ans(y, x, dis + 1);
		mxdep[x] = max(mxdep[x], mxdep[y]);
	}
}
void dfs(int x, int fa)
{
	vis[x] = 1; c1[0] = 1; int em = 0;
	for(int i = head[x]; i; i = a[i].nex)
	{
		int y = a[i].to;
		if(y == fa || vis[y]) continue;
		get_ans(y, x, 1); em = max(em, mxdep[y]);
		for(int j = 0; j <= mxdep[y]; j ++) c1[j] += c2[j], c2[j] = 0;
	}
	for(int j = 0; j <= em; j ++) c1[j] = c2[j] = 0;
	for(int i = head[x]; i; i = a[i].nex)
	{
		int y = a[i].to;
		if(y == fa || vis[y]) continue;
		sum = siz[y]; rt = 0; get_root(y, 0); dfs(rt, 0);
	}
}
void work2()
{
	maxx[0] = 0x3f3f3f3f; sum = n; get_root(1, 0); dfs(rt, 0);
	double k; k = (n / 3) + (n % 3 != 0);
	printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
	k = (n / 3) + (n % 3 == 2); printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
	k = (n / 3); printf("%.2f\n", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
}
int main()
{
	n = read(); m = read();
	for(int i = 1, x; i <= m; i ++) {x = read(); is[x] = 1; mx = max(mx, x); val[i] = x;}
	for(int i = 1, x, y; i <= n - 1; i ++)
	{
		x = read(); y = read();
		add(x, y); add(y, x);
	}
	if(mx <= 100 || n <= 1000) work(); else work2();
	return 0;
}

猜你喜欢

转载自www.cnblogs.com/Sunny-r/p/12641929.html
T2
今日推荐