HDU - 4809 树形dp

找了半天bug 发现把q打成了p。。。

思路:用dp[ i ][ j ][ k ] 表示在 i 这个点 这个点的状态为 j (0:不选 1:属于奇联通块 2:属于偶联通块) 且 奇联通块 - 偶联通块 = k 的方案数, 这个统计的方案数不包括i这个点的联通块。

dp[u][0][x+y]=dp[u][0][x]*dp[v][0][y]+dp[u][0][x]*dp[v][1][y-1]+dp[u][0][x]*dp[v][2][y+1]
dp[u][1][x+y]=dp[u][1][x]*dp[v][0][y]+dp[u][1][x]*dp[v][2][y]+dp[u][2][x]*dp[v][1][y]
dp[u][2][x+y]=dp[u][2][x]*dp[v][0][y]+dp[u][1][x]*dp[v][1][y]+dp[u][2][x]*dp[v][2][y]

 1 #include<bits/stdc++.h>
 2 #define LL long long
 3 #define fi first
 4 #define se second
 5 #define mk make_pair
 6 #define pii pair<int,int>
 7 #define piii pair<int, pair<int,int> >
 8 
 9 using namespace std;
10 
11 const int N = 300 + 7;
12 const int M = 10 + 7;
13 const int inf = 0x3f3f3f3f;
14 const LL INF = 0x3f3f3f3f3f3f3f3f;
15 const int mod = 1e9 + 7;
16 const int zero = 160;
17 
18 int n, dp[N][3][480], tmp[3][480], down[N], up[N];
19 vector<int> edge[N];
20 
21 void inline add(int &x, int y) {
22     x += y; if(x >= mod) x -= mod;
23 }
24 
25 void dfs(int u, int p) {
26     dp[u][0][zero] = 2;
27     dp[u][1][zero] = 1;
28     up[u] = down[u] = 0;
29     for(int v : edge[u]) {
30         if(v == p) continue;
31         dfs(v, u);
32         memset(tmp, 0, sizeof(tmp));
33         for(int p = down[u]; p <= up[u]; p++) {
34             for(int q = down[v] - 1; q <= up[v] + 1; q++) {
35                 if((p + q > n) || (p + q < (-n) / 2 - 2)) continue;
36                 add(tmp[0][p+q+zero], 1ll * dp[u][0][p+zero] * dp[v][0][q+zero] % mod);
37                 add(tmp[0][p+q+zero], 1ll * dp[u][0][p+zero] * dp[v][1][q+zero-1] % mod);
38                 add(tmp[0][p+q+zero], 1ll * dp[u][0][p+zero] * dp[v][2][q+zero+1] % mod);
39 
40                 add(tmp[1][p+q+zero], 1ll * dp[u][1][p+zero] * dp[v][0][q+zero] % mod);
41                 add(tmp[1][p+q+zero], 1ll * dp[u][1][p+zero] * dp[v][2][q+zero] % mod);
42                 add(tmp[1][p+q+zero], 1ll * dp[u][2][p+zero] * dp[v][1][q+zero] % mod);
43 
44                 add(tmp[2][p+q+zero], 1ll * dp[u][2][p+zero] * dp[v][0][q+zero] % mod);
45                 add(tmp[2][p+q+zero], 1ll * dp[u][2][p+zero] * dp[v][2][q+zero] % mod);
46                 add(tmp[2][p+q+zero], 1ll * dp[u][1][p+zero] * dp[v][1][q+zero] % mod);
47 
48                 if(tmp[0][p + q + zero]) down[u] = min(down[u], p + q), up[u] = max(up[u], p + q);
49                 if(tmp[1][p + q + zero]) down[u] = min(down[u], p + q), up[u] = max(up[u], p + q);
50                 if(tmp[2][p + q + zero]) down[u] = min(down[u], p + q), up[u] = max(up[u], p + q);
51             }
52         }
53 
54         for(int i = down[u]; i <= up[u]; i++)
55             for(int j = 0; j < 3; j++)
56                 dp[u][j][i + zero] = tmp[j][i + zero];
57     }
58 }
59 
60 void init() {
61     for(int i = 1; i <= n; i++)
62         edge[i].clear();
63     memset(dp, 0, sizeof(dp));
64 }
65 
66 int main() {
67     while(scanf("%d", &n) != EOF) {
68         init();
69         for(int i = 1; i < n; i++) {
70             int u, v; scanf("%d%d", &u, &v);
71             edge[u].push_back(v);
72             edge[v].push_back(u);
73         }
74         dfs(1, 0);
75         int ans = 0;
76         for(int i = down[1]; i <= up[1]; i++) {
77             add(ans, 1ll * max(0, i) * dp[1][0][i + zero] % mod);
78             add(ans, 1ll * max(0, i + 1) * dp[1][1][i + zero] % mod);
79             add(ans, 1ll * max(0, i - 1) * dp[1][2][i + zero] % mod);
80         }
81         add(ans, 2 * ans % mod);
82         printf("%d\n", ans);
83     }
84     return 0;
85 }
86 /*
87 */

猜你喜欢

转载自www.cnblogs.com/CJLHY/p/9108668.html