【LOJ3044】「ZJOI2019」Minimax 搜索

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/88978027

【题目链接】

【思路要点】

  • 首先考虑如何求出 P r e i Pre_i 表示稳定度在 i i 以内的集合的个数,若求得 P r e i Pre_i ,则有 A n s i = P r e i P r e i 1 Ans_i=Pre_i-Pre_{i-1}
  • 注意到各叶子节点权值不同,其权值最后作为根节点权值的叶子结点是唯一的,记为 k e y key
  • 考虑根节点到 k e y key 的一条链,其中每一个点的权值均为 k e y key 的权值,不难发现若我们不选择将其它点的权值改为 k e y key 的权值,根节点权值变动的充要条件为该链上任意一点的权值产生变化。而在希望根节点权值变动的情况下,选择将其它点的权值改为 k e y key 的权值是不优的。
  • 考虑计算使得该链上各点权值均不变的集合数,再用总的集合数减去之。
  • 对于链上深度为奇数的点 x x ,我们希望将 x x 其余子树的权值改得尽可能大,使得存在某个子树的权值超过 k e y key 的权值,那么,我们会尽可能地将其余子树内能够修改的值改大。将 k e y \leq key 的值记为 0 0 k e y + 1 \geq key+1 的值记为 1 1 ,我们仅仅关心子树内使得根节点权值为 0 / 1 0/1 的叶子结点集合数,这样的集合数可以通过简单 d p dp 得到。对于链上深度为偶数的点可以同样处理,不再赘述。
  • 至此,我们得到了一个 O ( N × ( R L ) ) O(N\times (R-L)) 的算法,可参考下文程序中 224 224 行以后的内容。
  • 注意到各个叶子结点的 d p dp 值至多会有一次变动,因此我们同样可以将其看做对叶子结点 d p dp 值的修改,剩余的问题就是一个简单的动态 d p dp 问题。
  • 将每一棵不在根节点到 k e y key 路径上的子树进行轻重链剖分,可以发现,重链上相邻的两个点的 d p dp 值呈一次函数关系,可以用线段树维护之。
  • 修改一个叶子节点的 d p dp 值至多影响 O ( L o g N ) O(LogN) 条轻边,暴力修改即可。
  • 时间复杂度 O ( N L o g 2 N ) O(NLog^2N)

【代码】

//Program Till Line 223
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
const int P = 998244353;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); } 
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
struct func {int k, b; };
func operator + (func x, func y) { // fx (fy (x))
	return (func) {1ll * x.k * y.k % P, (1ll * x.k * y.b + x.b) % P};
}
struct SegmentTree {
	struct Node {
		int lc, rc;
		func val;
	} a[MAXN * 2];
	int n, size, root;
	void update(int root) {
		a[root].val = a[a[root].lc].val + a[a[root].rc].val;
	}
	void build(int &root, int l, int r) {
		root = ++size;
		a[root].val = (func) {1, 0};
		if (l == r) return;
		int mid = (l + r) / 2;
		build(a[root].lc, l, mid);
		build(a[root].rc, mid + 1, r);
	}
	void init(int x) {
		n = x;
		root = size = 0;
		build(root, 1, n);
	}
	void modify(int root, int l, int r, int pos, func val) {
		if (l == r) {
			a[root].val = val;
			return;
		}
		int mid = (l + r) / 2;
		if (mid >= pos) modify(a[root].lc, l, mid, pos, val);
		else modify(a[root].rc, mid + 1, r, pos, val);
		update(root);
	}
	void modify(int pos, func val) {
		modify(root, 1, n, pos, val);
	}
	func query(int root, int l, int r, int ql, int qr) {
		if (l == ql && r == qr) return a[root].val;
		int mid = (l + r) / 2;
		if (mid >= qr) return query(a[root].lc, l, mid, ql, qr);
		else if (mid + 1 <= ql) return query(a[root].rc, mid + 1, r, ql, qr);
		else return query(a[root].lc, l, mid, ql, mid) + query(a[root].rc, mid + 1, r, mid + 1, qr);
	}
	int query(int l, int r) {
		assert(l <= r);
		return query(root, 1, n, l, r).b;
	}
} ST;
int power(int x, int y) {
	if (y == 0) return 1;
	int tmp = power(x, y / 2);
	if (y % 2 == 0) return 1ll * tmp * tmp % P;
	else return 1ll * tmp * tmp % P * x % P;
}
vector <int> a[MAXN]; bool leaf[MAXN];
int n, l, r, key[MAXN], sets[MAXN], value[MAXN], depth[MAXN];
bool type[MAXN]; int size[MAXN], son[MAXN];
//type-true: times together, type-false: the other way. 
int timer, dfn[MAXN], up[MAXN], down[MAXN], father[MAXN];
int dp[MAXN], pano, ans[MAXN]; pair <int, int> prod[MAXN];
pair <int, int> operator + (pair <int, int> a, int b) {
	if (b) a.first = 1ll * a.first * b % P;
	else a.second++;
	return a;
}
pair <int, int> operator - (pair <int, int> a, int b) {
	if (b) a.first = 1ll * a.first * power(b, P - 2) % P;
	else a.second--;
	return a;
}
int val(pair <int, int> a) {
	if (a.second) return 0;
	else return a.first;
}
vector <int> mod[MAXN];
void update(int &x, int y) {
	x += y;
	if (x >= P) x -= P;
}
void dfs(int pos, int fa, int from, bool t, bool Max) {
	father[pos] = fa, dfn[pos] = ++timer;
	type[pos] = t, prod[pos] = make_pair(1, 0), up[pos] = from;
	if (leaf[pos]) {
		if (Max) {
			dp[pos] = 2 * (value[pos] <= value[1]);
			if (value[pos] <= value[1]) mod[value[1] - value[pos] + 1].push_back(pos);
		} else {
			dp[pos] = 2 * (value[pos] >= value[1]);
			if (value[pos] >= value[1]) mod[value[pos] - value[1] + 1].push_back(pos);
		}
		ST.modify(dfn[pos], (func) {0, dp[pos]});
	}
	if (son[pos]) {
		dfs(son[pos], pos, from, t ^ true, Max);
		down[pos] = down[son[pos]];
	} else down[pos] = pos;
	for (auto x : a[pos])
		if (x != fa && x != son[pos]) {
			dfs(x, pos, x, t ^ true, Max);
			if (t) prod[pos] = prod[pos] + dp[x]; 
			else prod[pos] = prod[pos] + ((sets[x] - dp[x] + P) % P);
		}
	if (son[pos]) {
		if (t) {
			dp[pos] = 1ll * val(prod[pos]) * dp[son[pos]] % P;
			ST.modify(dfn[pos], (func) {val(prod[pos]), 0});
		} else {
			dp[pos] = 1ll * val(prod[pos]) * (sets[son[pos]] - dp[son[pos]] + P) % P;
			dp[pos] = (sets[pos] - dp[pos] + P) % P;
			ST.modify(dfn[pos], (func) {val(prod[pos]), (sets[pos] - 1ll * val(prod[pos]) * sets[son[pos]] % P + P) % P});
		}
	}
}
void getdfn(int pos, int fa) {
	if (key[pos]) {
		getdfn(key[pos], pos);
		for (auto x : a[pos])
			if (x != fa && x != key[pos]) {
				dfs(x, pos, x, false, depth[pos] & 1);
				pano = 1ll * pano * dp[x] % P, father[x] = 0;
			}
	}
}
void work(int pos, int fa) {
	leaf[pos] = true;
	key[pos] = son[pos] = 0;
	sets[pos] = size[pos] = 1;
	depth[pos] = depth[fa] + 1;
	if (depth[pos] & 1) value[pos] = 1;
	else value[pos] = n;
	for (auto x : a[pos])
		if (x != fa) {
			leaf[pos] = false;
			work(x, pos);
			size[pos] += size[x];
			sets[pos] = 1ll * sets[pos] * sets[x] % P;
			if (size[x] > size[son[pos]]) son[pos] = x;
			if (depth[pos] & 1) {
				if (value[x] > value[pos]) {
					value[pos] = value[x];
					key[pos] = x;
				}
			} else {
				if (value[x] < value[pos]) {
					value[pos] = value[x];
					key[pos] = x;
				}
			}
		}
	if (leaf[pos]) {
		sets[pos] = 2;
		value[pos] = pos;
	}
}
void modify(int pos) {
	ST.modify(dfn[pos], (func) {0, dp[pos] - 1});
	int tmp = ST.query(dfn[up[pos]], dfn[down[pos]]);
	pos = up[pos];
	while (father[pos]) {
		int f = father[pos];
		if (type[f]) prod[f] = (prod[f] + tmp) - dp[pos];
		else prod[f] = (prod[f] + ((sets[pos] - tmp + P) % P)) - ((sets[pos] - dp[pos] + P) % P);
		dp[pos] = tmp; pos = f;
		assert(son[pos]);
		if (type[pos]) ST.modify(dfn[pos], (func) {val(prod[pos]), 0});
		else ST.modify(dfn[pos], (func) {val(prod[pos]), (sets[pos] - 1ll * val(prod[pos]) * sets[son[pos]] % P + P) % P});
		tmp = ST.query(dfn[up[pos]], dfn[down[pos]]);
		pos = up[pos];
	}
	pano = 1ll * pano * tmp % P * power(dp[pos], P - 2) % P;
	dp[pos] = tmp;
}
int main() {
	read(n), read(l), read(r);
	for (int i = 1; i <= n - 1; i++) {
		int x, y; read(x), read(y);
		a[x].push_back(y);
		a[y].push_back(x);
	}
	work(1, 0);
	ST.init(n);
	pano = 1, getdfn(1, 0);
	for (int i = 1; i <= n - 1; i++) {
		for (auto x : mod[i])
			modify(x);
		ans[i] = (sets[1] - pano + P) % P;
	}
	ans[n] = sets[1] - 1;
	for (int i = n; i >= 1; i--)
		update(ans[i], P - ans[i - 1]);
	for (int i = l; i <= r; i++)
		printf("%d ", ans[i]);
	return 0;
}
//Brute Force
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5 + 5;
const int P = 998244353;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); } 
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
vector <int> a[MAXN];
int n, l, r, two[MAXN], size[MAXN], value[MAXN], depth[MAXN];
void update(int &x, int y) {
	x += y;
	if (x >= P) x -= P;
}
int getMax(int pos, int fa, int delta, int goal) { // value of root smaller than or equal with goal
	if (a[pos].size() == 1) {
		int ans = 0;
		if (value[pos] <= goal) ans++;
		if (value[pos] + delta <= goal) ans++; 
		return ans;
	}
	if (depth[pos] & 1) {
		int ans = 1;
		for (auto x : a[pos])
			if (x != fa) ans = 1ll * ans * getMax(x, pos, delta, goal) % P;
		return ans;
	} else {
		int ans = 1;
		for (auto x : a[pos])
			if (x != fa) ans = 1ll * ans * (two[size[x]] - getMax(x, pos, delta, goal) + P) % P;
		return (two[size[pos]] - ans + P) % P;
	}
}
int getMin(int pos, int fa, int delta, int goal) { // value of root larger than or equal with goal
	if (a[pos].size() == 1) {
		int ans = 0;
		if (value[pos] >= goal) ans++;
		if (value[pos] - delta >= goal) ans++; 
		return ans;
	}
	if (depth[pos] & 1) {
		int ans = 1;
		for (auto x : a[pos])
			if (x != fa) ans = 1ll * ans * (two[size[x]] - getMin(x, pos, delta, goal) + P) % P;
		return (two[size[pos]] - ans + P) % P;
	} else {
		int ans = 1;
		for (auto x : a[pos])
			if (x != fa) ans = 1ll * ans * getMin(x, pos, delta, goal) % P;
		return ans;
	}
}
int getans(int pos, int fa, int delta) {
	int ans = 1;
	for (auto x : a[pos])
		if (x != fa) {
			if (value[x] == value[pos]) ans = 1ll * ans * getans(x, pos, delta) % P;
			else if (depth[pos] & 1) ans = 1ll * ans * getMax(x, pos, delta, value[pos]) % P;
			else ans = 1ll * ans * getMin(x, pos, delta, value[pos]) % P;
		}
	return ans;
}
int getans(int delta) {
	if (delta == 0) return 0;
	if (delta == n) return two[size[1]] - 1;
	return (two[size[1]] - getans(1, 0, delta) + P) % P;
}
void dfs(int pos, int fa) {
	depth[pos] = depth[fa] + 1;
	if (depth[pos] & 1) value[pos] = 1;
	else value[pos] = n;
	bool leaf = true;
	for (auto x : a[pos])
		if (x != fa) {
			leaf = false;
			dfs(x, pos);
			size[pos] += size[x];
			if (depth[pos] & 1) chkmax(value[pos], value[x]);
			else chkmin(value[pos], value[x]);
		}
	if (leaf) {
		size[pos] = 1;
		value[pos] = pos;
	}
}
int main() {
	read(n), read(l), read(r);
	for (int i = 1; i <= n - 1; i++) {
		int x, y; read(x), read(y);
		a[x].push_back(y);
		a[y].push_back(x);
	}
	dfs(1, 0);
	two[0] = 1;
	for (int i = 1; i <= n; i++)
		two[i] = 2ll * two[i - 1] % P;
	static int ans[MAXN];
	for (int i = l - 1; i <= r; i++)
		ans[i] = getans(i);
	for (int i = r; i >= l; i--)
		update(ans[i], P - ans[i - 1]);
	for (int i = l; i <= r; i++)
		printf("%d ", ans[i]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39972971/article/details/88978027
今日推荐