SDOI2018 荣誉称号(树形dp)

题目链接

题目大意

给定一棵完全二叉树,要求任意一条不拐弯长度为 k + 1 k+1 的链(即从某个点不断往上跳 k k p a r e n t parent ),满足链上所有点的和是 m m 的倍数。
n 1 0 7 , k 10 n\le 10^7,k\le 10

题解

考虑两条相邻的链 a 0 , a 2 , . . . , a k a_0,a_2,...,a_k a 1 , a 2 , . . . , a k + 1 a_1,a_2,...,a_{k+1} ,由于他们的和都是 m m 的倍数,则显然有 a 0 a k + 1   ( m o d   m ) a_0\equiv a_{k+1}~(mod~m)
也就是说我们最后只需要考虑编号 < 2 k + 1 <2^{k+1} 的那些点。但是我们需要预处理 g [ i ] [ j ] g[i][j] 表示点 i ( i < 2 k + 1 ) i(i<2^{k+1}) 及所有需要和点 i i 相等的点,全部改成 j j 所需要的最小代价。
如果暴力的话,也就是暴力枚举 j j 再暴力枚举所有点,复杂度是 O ( n m ) O(nm) 的。但是我们可以把权值 m o d   m mod~m 相同的点一起处理,具体的,令 a l l [ i ] [ j ] all[i][j] 表示点 i ( i < 2 k + 1 ) i(i<2^{k+1}) 及所有需要和点 i i 相等的点中,权值 m o d   m = j mod~m=j 的所有点的单次修改代价之和。
于是我们把问题转化为了 m m 个点,暴力枚举 m m 次,这样就可以在 O ( 2 k m 2 ) O(2^km^2) 的时间内解决了。当然前缀和优化一下可以做到 O ( 2 k m ) O(2^km) ,但我懒。
然后我们就可以在只有 2 k + 1 1 2^{k+1}-1 个节点上的树进行dp了。 f [ i ] [ j ] f[i][j] 表示所有叶节点到 i i 的权值和 m o d   m mod~m 均为 j j 的最小修改代价,暴力枚举当前点上的值转移即可,复杂度 O ( 2 k m 2 ) O(2^km^2)
于是整道题的复杂度就是 O ( n + 2 k m 2 ) O(n+2^km^2) 了。

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 10000005, MAXK = 2050, MAXM = 205;
ll f[MAXK][MAXM], g[MAXK][MAXM], all[MAXK][MAXM];
int bel[MAXN], T, n, m, K, Q;
unsigned int SA, SB, SC; int pp, A, B;
unsigned int rng61(){
    SA ^= SA << 16;
    SA ^= SA >> 5;
    SA ^= SA << 1;
    unsigned int t = SA;
    SA = SB;
    SB = SC;
    SC ^= t ^ SA;
    return SC;
}
void gen(){
	cls(g), cls(all), memset(f, 0x3f, sizeof(f));
	read(n, K, m, pp, SA, SB, SC, A, B);
	Q = (1 << (++K)) - 1;
    for (int i = 1; i <= Q; i++) bel[i] = i;
    for (int i = Q + 1; i <= n; i++) bel[i] = bel[i >> K];
    for (int i = 1; i <= pp; i++) {
    	int a, b; read(a, b);
    	all[bel[i]][a % m] += b;
    }
    for (int i = pp + 1; i <= n; i++){
        int a = rng61() % A + 1;
        int b = rng61() % B + 1;
        all[bel[i]][a % m] += b;
    }
    for (int i = 1; i <= Q; i++) {
    	for (int j = 0; j < m; j++) {
    		for (int k = 0; k <= j; k++) g[i][j] += all[i][k] * (j - k);
    		for (int k = j + 1; k < m; k++) g[i][j] += all[i][k] * (j + m - k);
    	}
    }
}
inline void upd(ll &x, ll y) { x = min(x, y); }
int main() {
	for (read(T); T--;) {
		gen();
		for (int i = Q; i > 0; i--) {
			int ls = i << 1, rs = i << 1 | 1;
			if (ls > Q) for (int j = 0; j < m; j++) f[i][j] = g[i][j];
			else if (rs > Q) {
				for (int j = 0; j < m; j++)
				for (int k = 0; k < m; k++)
					upd(f[i][(j + k) % m], f[ls][k] + g[i][j]);
			} else {
				for (int j = 0; j < m; j++)
				for (int k = 0; k < m; k++)
					upd(f[i][(j + k) % m], f[ls][k] + f[rs][k] + g[i][j]);
			}
		}
		printf("%lld\n", f[1][0]);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/WAautomaton/article/details/87363963