版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/82150662
题目大意:
给你一棵节点数为n的无向树, 每条边上有一个栅栏, 等概率的出现
的高度, 有m个僵尸, 出生在
点, 可以闯过低于
高度的栅栏。 一个点是安全的, 当且仅当它不会被任一僵尸到达。 求树中至少有一个安全点的概率。T组数据。
题目思路:
将概率转换为求方案数, 至少一个转换为求一个都没有
现在求整棵树不安全的方案数
考虑树形dp, 设f[i][j]表示子树i内所有点都是不安全的, 子树外皆有可能, 能达到点i的最大僵尸为j的方案数
初始值: i点的最大僵尸至少为出生在该点的最大僵尸
即 设点i出生的最大僵尸为k(没僵尸则为1), 则对于所有j>=k, f[i][j] = 1
考虑子树合并, f[u][a]与f[v][b] v是u的一个孩子
if a == b //说明a可肯定能跨过(u,v)这条边
f[u][a] += f[u][a] * f[v][b] * (a 能跨过(u,v))
if a < b && a 一定不在子树v内 && b 一定在子树v内
f[u][a] += f[u][a] * f[v][b] * (b 不能跨过(u,v))
if a > b && a 一定不在子树v内 && b 一定在子树v内
f[u][a] += f[u][a] * f[v][b] * (a 不能跨过(u, v))
后两种情况可以分别前缀和求即可。
时间复杂度O(nm)
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <bitset>
#define pi pair<int, int>
#define fi first
#define se second
#define mp make_pair
#define ll long long
const int N = (int)2020;
const int mo = 998244353;
using namespace std;
int gi(){
char c = getchar(); int ret = 0;
while (!isdigit(c)) c = getchar();
while (isdigit(c)){
ret = ret * 10 + c - '0';
c = getchar();
}
return ret;
}
ll pw(ll x, int k){
ll ret = 1;
for (; k; k >>= 1, x = x * x % mo)
if (k & 1) ret = ret * x % mo;
return ret;
}
int n, m; pi A[N];
int cnt, lst[N], nxt[N * 2], to[N * 2], L[N * 2], R[N * 2];
bitset<N> in[N]; ll f[N][N], tmp[N], ans, all;
void add(int u, int v, int a, int b){
nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; L[cnt] = a; R[cnt] = b;
nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; L[cnt] = a; R[cnt] = b;
}
void dfs(int u, int pre){
for (int i = 1; i <= m; i ++)
if (A[i].se == u) in[u][i] = 1;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == pre) continue;
dfs(v, u);
in[u] |= in[v];
}
}
void dp(int u, int pre){
int pos = 1; ll sum;
for (int i = 1; i <= m; i ++)
if (A[i].se == u) pos = i;
for (int i = pos; i <= m; i ++) f[u][i] = 1;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == pre) continue;
dp(v, u);
memcpy(tmp, f[u], sizeof(f[u]));
memset(f[u], 0, sizeof(f[u]));
sum = 0;
for (int a = 1; a <= m; a ++){
int k = max(min(A[a].fi - 1, R[j]) - L[j] + 1, 0);
(f[u][a] += tmp[a] * f[v][a] % mo * k % mo) %= mo;
if (in[v][a]) (sum += f[v][a]) %= mo;
else (f[u][a] += tmp[a] * sum % mo * (R[j] - L[j] + 1 - k) % mo) %= mo;
}
sum = 0;
for (int a = m; a >= 1; a --){
int k = max(min(A[a].fi - 1, R[j]) - L[j] + 1, 0);
if (in[v][a]) (sum += f[v][a] * (R[j] - L[j] + 1 - k) % mo) %= mo;
else (f[u][a] += tmp[a] * sum % mo) %= mo;
}
}
}
int main()
{
int T = gi();
while (T --){
cnt = ans = 0; all = 1;
memset(lst, 0, sizeof(lst));
memset(in, 0, sizeof(in));
memset(f, 0, sizeof(f));
n = gi(); m = gi();
for (int i = 1; i < n; i ++){
int u, v, a, b;
u = gi(), v = gi(), a = gi(), b = gi();
add(u, v, a, b); all = all * (b - a + 1) % mo;
}
for (int i = 1; i <= m; i ++)
A[i].se = gi(), A[i].fi = gi();
sort(A + 1, A + m + 1);
dfs(1, 0);
dp(1, 0);
for (int i = 1; i <= m; i ++)
ans = (ans + f[1][i]) % mo;
ans = (all - ans) % mo;
ans = ans * pw(all, mo - 2) % mo;
if (ans < 0) ans += mo;
printf("%lld\n", ans);
}
return 0;
}