SDOI2017 切树游戏(FWT+树链剖分+dp)

题目链接

题目大意

维护一棵树,支持:
1.动态修改某个点权值;
2.查询有多少个连通子树的异或值为 $p$(答案对 $10007$ 取模)。

题解

这题感觉比较套路:显然可以列出一个dp方程,并且会发现转移恰好是FWT异或卷积的形式。具体地,记 $f[i]$ 为 $i$ 的dp数组(第 $x$ 项表示以 $i$ 为最浅点、异或和为 $x$ 的连通子树个数)做FWT变换后的结果,那么 $f[i]=b[val[i]]\cdot\prod_{v\in son[i]}\left(f[v]+b[0]\right)$,其中 $b[i]$ 表示只含数字 $i$ 一项的数组的FWT变换,乘法均为逐点相乘;加上 $b[0]$ 对应儿子 $v$ 所在子树一个点都不选的情况。
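于是动态dp的做法就很显然了:每个点的 $f$ 只乘入轻儿子的贡献,每条重链维护一棵线段树。设一条重链自上而下为 $v_0,v_1,\dots,v_k$,它对答案的贡献是所有区间 $[i,j]$ 上 $f[v_i]\cdots f[v_j]$ 之和,而链顶的完整dp值恰为所有前缀乘积之和。因此线段树上需要维护四个值:区间乘积、前缀乘积之和、后缀乘积之和、所有子区间乘积之和(均在FWT点值下逐点计算)。修改时沿重链向上跳 $O(\log n)$ 次,每次做一次 $O(m\log n)$ 的线段树单点修改,所以每次修改的复杂度为 $O(m\log^2 n)$。查询时只需对答案数组在 $p$ 处做一次单点逆FWT,代价为 $O(m)$。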

#include <bits/stdc++.h>
namespace IOStream {
	// Capacity of each static I/O buffer.
	const int MAXR = 10000000;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	// Returns the next input character, or 0 once the input is exhausted.
	// Without ONLINE_JUDGE it reads through getchar(); with it, characters come from
	// _READ_, which is refilled by fread() each time the read position is 0.
	inline char readc() {
	#ifndef ONLINE_JUDGE
		const int ch = getchar();
		// Report EOF as 0, the same end marker the buffered path returns.
		return ch == EOF ? 0 : static_cast<char>(ch);
	#else
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	#endif
	}
	// Parses an optionally negative decimal integer into x, skipping any leading
	// characters that are neither digits nor '-'. If the input ends before a number
	// starts, x is left as 0 instead of looping forever.
	template<typename T> inline void read(T &x) {
		x = 0;
		int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-') {
			if (c == 0) return;  // input exhausted
		}
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	// Reads several integers in order.
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	// Reads one whitespace-delimited token into s (NUL-terminated) and returns its
	// length; returns 0 with s empty if the input ends first. s must be large enough.
	inline int reads(char *s) {
		int len = 0, c;
		// isspace requires a value representable as unsigned char.
		while ((c = readc()) != 0 && isspace(static_cast<unsigned char>(c)));
		while (c != 0 && !isspace(static_cast<unsigned char>(c))) {
			s[len++] = static_cast<char>(c);
			c = readc();
		}
		s[len] = 0;
		return len;
	}
	// Writes the buffered output to stdout and empties the buffer.
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	// Appends one character to the output buffer, flushing when it is full.
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	// Appends a NUL-terminated string to the output buffer.
	inline void prints(const char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	// Writes x in decimal followed by the separator c. Digits are extracted from the
	// value without negating it, so the most negative value of T prints correctly.
	template<typename T> inline void print(T x, char c = '\n') {
		char digits[40];
		int tp = 0;
		const bool negative = x < 0;
		do {
			int d = static_cast<int>(x % 10);
			if (d < 0) d = -d;  // remainder of a negative value is negative
			digits[tp++] = static_cast<char>('0' + d);
			x /= 10;
		} while (x != 0);
		if (negative) printc('-');
		while (tp > 0) printc(digits[--tp]);
		printc(c);
	}
	// Writes several values separated by spaces, ending with a newline.
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
using ll = long long;

// Capacities: MAXT bounds both the edge list and the segment-tree node arrays,
// MAXN bounds vertex indices, MAXM bounds the value range m; MOD is the modulus
// of every dp value.
const int MAXT = 70000;
const int MAXN = 30005;
const int MAXM = 130;
const int MOD = 10007;

// Adjacency-list entry: target vertex and index of the next edge (0 ends a list).
struct Edge {
	int to;
	int next;
};
Edge edge[MAXT];

// rt[t]: segment-tree root of the chain whose top is t; ls / rs: node children;
// val: vertex values; tot: counter used first for edges, later for tree nodes.
int rt[MAXN], ls[MAXT], rs[MAXT], val[MAXN];
int n, m, Q, tot;
// par / top / sz / wson: parent, chain top, subtree size and heavy child per vertex;
// head: first edge of each vertex's adjacency list.
int par[MAXN], top[MAXN], head[MAXN], sz[MAXN], wson[MAXN];
// base[i]: FWT of the indicator of value i; inv: inverses modulo MOD; id: position
// of a vertex inside its chain; ans: FWT of the overall answer; temp is not
// referenced elsewhere in this file.
int base[MAXM][MAXM], inv[MOD], id[MAXN], ans[MAXM], temp[MAXN];
// A residue modulo MOD that keeps zero factors in a separate counter, so that a
// multiplication by 0 can later be undone by the matching division by 0.
// Represented value: num * 0^cnt, i.e. 0 while cnt != 0 and num otherwise.
// Every operator first reduces its argument into [0, MOD).
// The array f declared with it holds, per vertex u and FWT coordinate k, u's dp
// with the heavy child left out (filled in by dfs2).
struct ModInt {
	int num, cnt;
	// Starts out as 0 (one pending zero factor).
	ModInt() { num = cnt = 1; }
	// Sets the value to x mod MOD.
	ModInt& operator=(int x) {
		x = normalize(x);
		if (x == 0) num = 1, cnt = 1;
		else num = x, cnt = 0;
		return *this;
	}
	// Multiplies by x mod MOD; a zero factor only bumps cnt so it can be removed again.
	ModInt& operator*=(int x) {
		x = normalize(x);
		if (x == 0) ++cnt;
		else num = num * x % MOD;
		return *this;
	}
	// Divides by x mod MOD using the global inverse table inv[]. Dividing by 0 just
	// drops one zero factor, so it is only meaningful as the inverse of an earlier *= 0.
	ModInt& operator/=(int x) {
		x = normalize(x);
		if (x == 0) --cnt;
		else num = num * inv[x] % MOD;
		return *this;
	}
	// Current value in [0, MOD).
	int get() const { return cnt ? 0 : num; }

private:
	// Maps any int, including negative ones, to its representative in [0, MOD).
	static int normalize(int x) {
		x %= MOD;
		return x < 0 ? x + MOD : x;
	}
} f[MAXN][MAXM];
// Prepends the directed edge u -> v to u's adjacency list. Edges are stored from
// index 1 so that index 0 can terminate a list.
void addedge(int u, int v) {
	// Standard aggregate initialization; a C99 compound literal "(Edge){...}" is
	// only accepted by C++ compilers as a non-standard extension.
	edge[++tot] = Edge{ v, head[u] };
	head[u] = tot;
}
// In-place, unscaled xor Fast Walsh-Hadamard transform of a[0..n), assuming n is a
// power of two. Intermediate sums are not reduced, so the inputs must be small
// enough for sums over n entries to fit in int; the results end up in [0, MOD).
void fwt(int *a, int n) {
	for (int half = 1; (half << 1) <= n; half <<= 1) {
		const int width = half << 1;
		for (int block = 0; block < n; block += width) {
			for (int j = block; j < block + half; ++j) {
				const int lo = a[j];
				const int hi = a[j + half];
				a[j] = lo + hi;
				a[j + half] = lo - hi;
			}
		}
	}
	for (int i = 0; i < n; ++i) {
		a[i] %= MOD;
		if (a[i] < 0) a[i] += MOD;
	}
}
// Evaluates a single entry of the unscaled inverse xor transform of a:
// the sum over k of (-1)^popcount(x & k) * a[k], where k ranges over the indices
// that differ from x only in bits n, n/2, ..., 1. Call it with n = (array size) / 2;
// the result is neither reduced mod MOD nor divided by the array size.
// Runs in time linear in the array size.
int ifwt(int *a, int n, int x) {
	if (!n) return a[x];
	const int lower = n >> 1;
	const int same = ifwt(a, lower, x);       // index with bit n as in x
	const int other = ifwt(a, lower, x ^ n);  // index with bit n flipped
	// Indices with bit n set contribute negatively exactly when x has bit n set.
	return (x & n) ? other - same : same + other;
}
// First heavy-light pass over u's subtree (fa is u's parent, 0 for the root):
// records par, accumulates subtree sizes sz and picks the heavy child wson.
void dfs1(int u, int fa) {
	par[u] = fa;
	sz[u] += 1;
	for (int e = head[u]; e != 0; e = edge[e].next) {
		const int child = edge[e].to;
		if (child == fa) continue;
		dfs1(child, u);
		sz[u] += sz[child];
		// Keep the first child whose subtree is strictly the largest.
		const bool heavier = (wson[u] == 0) || (sz[child] > sz[wson[u]]);
		if (heavier) wson[u] = child;
	}
}
// lnk[t]: vertices of the heavy chain whose top is t, listed from the top down.
vector<int> lnk[MAXN];
// Per segment-tree node and per FWT coordinate, all reduced mod MOD:
//   mul       - product of the chain's f values over the node's interval,
//   pre / suf - sum of the products over every prefix / suffix of the interval,
//   sum       - sum of the products over every sub-interval.
short pre[MAXT][MAXM];
short suf[MAXT][MAXM];
short mul[MAXT][MAXM];
short sum[MAXT][MAXM];
void pushup(int x) {
	int l = ls[x], r = rs[x];
	for (int k = 0; k < m; k++) {
		pre[x][k] = (pre[l][k] + (int)pre[r][k] * mul[l][k]) % MOD;
		suf[x][k] = (suf[r][k] + (int)suf[l][k] * mul[r][k]) % MOD;
		mul[x][k] = (int)mul[l][k] * mul[r][k] % MOD;
		sum[x][k] = (sum[l][k] + sum[r][k] + (int)suf[l][k] * pre[r][k]) % MOD;
	}
}
int build(const vector<int> &v, int l, int r) {
	int p = ++tot;
	if (l == r) {
		for (int i = 0; i < m; i++)
			pre[p][i] = suf[p][i] = mul[p][i] = sum[p][i] = f[v[l]][i].get();
		return p;
	}
	int mid = (l + r) >> 1;
	ls[p] = build(v, l, mid);
	rs[p] = build(v, mid + 1, r);
	pushup(p); return p;
}
void dfs2(int u, int fa, int t) {
	for (int i = 0; i < m; i++) f[u][i] = base[val[u]][i];
	id[u] = lnk[t].size(); lnk[t].push_back(u); top[u] = t;
	if (wson[u]) dfs2(wson[u], u, t);
	for (int i = head[u]; i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa && v != wson[u]) {
			dfs2(v, u, v);
			for (int j = 0; j < m; j++)
				f[u][j] *= pre[rt[v]][j] + base[0][j];
		}
	}
	if (u == t) {
		rt[u] = build(lnk[u], 0, lnk[u].size() - 1);
		for (int i = 0; i < m; i++) ans[i] += sum[rt[u]][i];
	}
}
void update(const vector<int> &v, int x, int l, int r, int p) {
	if (l == r) {
		for (int i = 0; i < m; i++)
			pre[p][i] = suf[p][i] = mul[p][i] = sum[p][i] = f[v[l]][i].get();
		return;
	}
	int mid = (l + r) >> 1;
	if (x <= mid) update(v, x, l, mid, ls[p]);
	else update(v, x, mid + 1, r, rs[p]);
	pushup(p);
}
char opt[10];  // operation-name token read for each operation
// Reads the tree, builds the chain-decomposed dp, then processes Q operations:
// an operation whose name starts with 'Q' prints the count (mod MOD) of connected
// subtrees whose xor is x; any other operation sets val[x] = y.
int main() {
	read(n, m);
	// inv[i] = inverse of i modulo MOD, via inv[i] = -(MOD / i) * inv[MOD % i].
	inv[1] = 1;
	for (int i = 2; i < MOD; i++)
		inv[i] = MOD - MOD / i * inv[MOD % i] % MOD;
	for (int i = 1; i <= n; i++) read(val[i]);
	for (int i = 1; i < n; i++) {
		int u, v; read(u, v);
		addedge(u, v), addedge(v, u);
	}
	// base[i] = FWT of the indicator vector of value i.
	for (int i = 0; i < m; i++) {
		base[i][i] = 1;
		fwt(base[i], m);
	}
	// tot is reset to 0: from here on it counts segment-tree nodes instead of edges.
	dfs1(1, tot = 0);
	dfs2(1, tot = 0, 1);
	for (int j = 0; j < m; j++) ans[j] %= MOD;
	read(Q);
	while (Q--) {
		reads(opt);
		if (opt[0] == 'Q') {
			// Invert the FWT at coordinate x only: unscaled inverse sum, then divide by m.
			int x; read(x); print((ifwt(ans, m >> 1, x) % MOD + MOD) * inv[m] % MOD);
		} else {
			int x, y; read(x, y);
			// Replace x's own factor base[val[x]] with base[y] in its light-part dp.
			for (int i = 0; i < m; i++) (f[x][i] /= base[val[x]][i]) *= base[y][i];
			val[x] = y;
			// Walk up chain by chain. For each chain: refresh its segment tree at x's
			// position, adjust ans by the change in the chain's sum, and swap the chain's
			// old contribution (pre + base[0]) for the new one in the light-part dp of the
			// chain top's parent t. Since x becomes the chain top, the step moves to t.
			for (; x; x = par[x]) {
				int t = par[top[x]], a = id[x];
				x = top[x];
				if (t > 0) for (int i = 0; i < m; i++)
					f[t][i] /= pre[rt[x]][i] + base[0][i];
				for (int i = 0; i < m; i++) ans[i] -= sum[rt[x]][i];
				update(lnk[x], a, 0, lnk[x].size() - 1, rt[x]);
				for (int i = 0; i < m; i++) ans[i] += sum[rt[x]][i];
				if (t > 0) for (int i = 0; i < m; i++)
					f[t][i] *= pre[rt[x]][i] + base[0][i];
			}
			// Bring ans back into [0, MOD) after the subtractions above.
			for (int i = 0; i < m; i++) ans[i] = (ans[i] % MOD + MOD) % MOD;
		}
	}
	ioflush();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/WAautomaton/article/details/87108635