O significado da pergunta: Dada uma árvore de n pontos, cada ponto tem um peso w [i], agora podemos selecionar alguns pontos conectados e adicionar os pesos dos pontos selecionados por este ponto para obter um. Encontre quais valores em [1, m] podem ser expressos como a soma dos pesos dos pontos selecionados. Use a sequência 0101 para a saída.
Ideias:
Considere dividir e conquistar o ponto. Primeiro selecione o centro de gravidade da árvore e considere a resposta que deve ser selecionada neste ponto. Suponha que eu escolha um determinado ponto, então devo escolher o pai deste ponto. Agora comece a repetir a árvore. A cada recursão para um ponto, o bitset desse ponto é inicializado para o bitset representado pelo nó pai e deslocado para a direita por w [x] bits. Seu significado é que se o ponto atual for selecionado, então seu pai deve escolher, isto é, perguntar o conjunto de respostas atuais de seu pai combinada com a resposta do peso deste ponto. Ao olhar para trás, o bitset representado pelo pai é o mesmo que o bitset representado pelo filho. Isso porque depois que os outros filhos do pai posteriormente resolverem o problema, a informação que o pai já obteve é aproveitada, e combinada, o efeito da conexão é alcançado através do pai.
Finalmente, a resposta para a árvore enraizada no centro de divisão e conquista atual é o bitset representado pelo centro de divisão e conquista, portanto, continue a resolvê-lo recursivamente.
Complexidade de tempo O (nlogn⋅mw)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
const double eps = 1e-11;
const int N = 3e3 + 10;
const int M = 1e5 + 10;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) {
if(ch == '-') f = -1;
ch = getchar();
}
while(isdigit(ch))
x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
return x * f;
}
///all表示当前子树的结点数
int n, all, sz[N], rt, rt2, a[N], siz[N], maxson[N], maxx;
vector<int>mp[N];
bitset<M>bit[N], ans;
bool vis[N];
void init() {
ans.reset();
for(int i = 1; i <= n; ++i) vis[i] = 0;
for(int i = 1; i <= n; ++i) mp[i].clear();
}
void getrt(int u, int fa) {
siz[u] = 1;
maxson[u] = 0;
int sz = mp[u].size();
for(int i = 0; i < sz; ++i) {
int v = mp[u][i];
if(v == fa || vis[v]) continue;
getrt(v, u);
siz[u] = siz[u] + siz[v];
maxson[u] = max(maxson[u], siz[v]);
}
maxson[u] = max(maxson[u], all - siz[u]);
if((maxson[u] << 1) <= all) rt2 = rt, rt = u;
}
void calc(int u, int fa) {
siz[u] = 1, bit[u] <<= a[u];
int sz = mp[u].size();
for(int i = 0; i < sz; ++i) {
int v = mp[u][i];
if(vis[v] || v == fa) continue;
bit[v] = bit[u];
calc(v, u);
siz[u] += siz[v];
bit[u] |= bit[v];
}
}
void divide(int u) {
vis[u] = 1;
bit[u].reset(), bit[u].set(0);
calc(u, 0);
ans |= bit[u];
int sz = mp[u].size();
for(int i = 0; i < sz; ++i) {
int v = mp[u][i];
if(vis[v]) continue;
rt = rt2 = 0;
maxson[rt] = all = siz[v];
getrt(v, 0);
divide(rt);
}
}
int main() {
int t, m, u, v;
t = read();
while(t--) {
n = read(), m = read();
init();
for(int i = 1; i < n; ++i) {
u = read(), v = read();
mp[u].push_back(v);
mp[v].push_back(u);
}
for(int i = 1; i <= n; ++i) a[i] = read();
rt = rt2 = 0;
maxson[rt] = all = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
for(int i = 1; i <= m; ++i) printf("%d", (int)ans[i]);
printf("\n");
}
return 0;
}