Fish eating fruit - Shenyang online contest (tree DP)

Fish eating fruit

\[ Time Limit: 1000 ms \quad Memory Limit: 262144 kB \]

Problem summary

Given a weighted tree, take the distance between every (ordered) pair of vertices and group these distances by their value \(\bmod 3\). For each residue \(0, 1, 2\), output the sum of the distances in that group (modulo \(10^9+7\)).

Approach

Use two DP arrays, each paired with an auxiliary counting array (the tree is rooted at vertex 1):
\(dp1[i][j]\): the sum of the distances from \(i\) down to the vertices of its subtree whose distance from \(i\) is \(\equiv j \pmod 3\).
\(cnt1[i][j]\): the number of vertices in the subtree of \(i\) whose distance from \(i\) is \(\equiv j \pmod 3\).
\(dp2[i][j]\): the sum of the distances from \(i\) to the vertices reached by first stepping up to its parent (i.e. all vertices outside the subtree of \(i\)) whose distance from \(i\) is \(\equiv j \pmod 3\).
\(cnt2[i][j]\): the number of those vertices outside the subtree of \(i\) whose distance from \(i\) is \(\equiv j \pmod 3\).


Each pair of arrays is filled by its own DFS.


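\(dp1\) is the easier one and is computed by a single downward DFS. The answer starting from \(u\) is the answer starting from each child \(v\) plus the contribution of the edge \(u\)–\(v\) with weight \(w\). That contribution is \(w\) added once for every vertex counted in \(v\)'s subtree:
\[ dp1[u][(j+w) \bmod 3] = \sum_{v} \bigl(dp1[v][j] + cnt1[v][j] \cdot w\bigr), \qquad cnt1[u][(j+w) \bmod 3] = \sum_{v} cnt1[v][j] \]
Additionally, \(cnt1[u][0]\) starts at 1 to count \(u\) itself.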


\(dp2\) is more involved. The vertices that \(u\) reaches by first stepping up to its parent \(fa\) are:
- everything \(fa\) reaches upward (\(dp2[fa]\)),
- plus everything in \(fa\)'s subtree (\(dp1[fa]\)),
- minus the part of \(fa\)'s subtree that leads back down into \(u\)'s subtree.

First compute the true vertex counts and distance sums measured from \(fa\), then shift them across the edge to obtain the values for \(u\).
Let \(faw\) be the weight of the edge between \(u\) and \(fa\).
True number of vertices (distances measured from \(fa\)):
\[ c[j] = cnt2[fa][j] + cnt1[fa][j], \qquad c[(j+faw) \bmod 3] \mathrel{-}= cnt1[u][j] \]
True sum of distances (measured from \(fa\)):
\[ d[j] = dp2[fa][j] + dp1[fa][j], \qquad d[(j+faw) \bmod 3] \mathrel{-}= dp1[u][j] + cnt1[u][j] \cdot faw \]
Finally, \(dp2[u]\) follows from \(c\) and \(d\), since every one of those vertices is \(faw\) farther from \(u\) than from \(fa\):
\[ dp2[u][(j+faw) \bmod 3] = c[j] \cdot faw + d[j], \qquad cnt2[u][(j+faw) \bmod 3] = c[j] \]

/*************************************************************** 
    > File Name    : a.cpp
    > Author       : Jiaaaaaaaqi
    > Created Time : Mon 16 Sep 2019 08:55:33 PM CST
 ***************************************************************/

#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>
// Small competitive-programming helpers. Every macro argument is
// parenthesised so the expansion binds correctly inside larger expressions
// (e.g. `lowbit(x) == 2` or `lowbit(a + b)`).
#define  lowbit(x)  ((x) & (-(x)))   // lowest set bit of x
#define  mes(a, b)  memset((a), (b), sizeof(a))
#define  fi         first
#define  se         second
#define  pb         push_back
#define  pii        pair<int, int>

using ull = unsigned long long int;
using ll  = long long int;
const int    maxn = 1e5 + 10;   // capacity of the per-vertex tables (vertices are 1-based)
const int    maxm = 1e5 + 10;
const ll     mod  = 1e9 + 7;    // distance sums are reported modulo this prime
const ll     INF  = 1e18 + 100;
const int    inf  = 0x3f3f3f3f;
const double pi   = acos(-1.0);
const double eps  = 1e-8;
using namespace std;

// n is the vertex count of the current test case; m, cas, tol and T appear
// to be template leftovers that this file does not use.
int n, m;
int cas, tol, T;

// vv[u] holds (neighbour, edge weight) pairs of the undirected tree.
vector< pii > vv[maxn];
// For vertex i and residue j in {0, 1, 2}, over vertices whose distance from
// i is congruent to j (mod 3):
//   cnt1 / dp1: count / distance sum (mod `mod`) inside i's subtree
//               (tree rooted at vertex 1, i itself included);
//   cnt2 / dp2: count / distance sum (mod `mod`) outside i's subtree.
ll cnt1[maxn][3], cnt2[maxn][3];
ll dp1[maxn][3], dp2[maxn][3];

void dfs1(int u, int fa) {
    cnt1[u][0] = 1;
    for(auto i : vv[u]) {
        int v = i.fi, w = i.se;
        if(v == fa) continue;
        dfs1(v, u);
        dp1[u][(0+w)%3] += (cnt1[v][0]*w%mod+dp1[v][0])%mod;
        dp1[u][(1+w)%3] += (cnt1[v][1]*w%mod+dp1[v][1])%mod;
        dp1[u][(2+w)%3] += (cnt1[v][2]*w%mod+dp1[v][2])%mod;
        for(int j=0; j<3; j++)  dp1[u][j] %= mod;
        cnt1[u][(0+w)%3] += cnt1[v][0];
        cnt1[u][(1+w)%3] += cnt1[v][1];
        cnt1[u][(2+w)%3] += cnt1[v][2];
    }
}

void dfs2(int u, int fa) {
    if(u!=1) {
        int faw;
        for(auto i : vv[u]) {
            if(i.fi == fa) {
                faw = i.se;
                break;
            }
        }
        int c[3] = { 0 };
        c[0] = cnt2[fa][0]+cnt1[fa][0];
        c[1] = cnt2[fa][1]+cnt1[fa][1];
        c[2] = cnt2[fa][2]+cnt1[fa][2];
        c[(0+faw)%3] -= cnt1[u][0];
        c[(1+faw)%3] -= cnt1[u][1];
        c[(2+faw)%3] -= cnt1[u][2];
        ll d[3] = { 0 };
        d[0] = (dp2[fa][0]+dp1[fa][0])%mod;
        d[1] = (dp2[fa][1]+dp1[fa][1])%mod;
        d[2] = (dp2[fa][2]+dp1[fa][2])%mod;
        d[(0+faw)%3] = ((d[(0+faw)%3] - (cnt1[u][0]*faw%mod+dp1[u][0])%mod+mod)%mod+mod)%mod;
        d[(1+faw)%3] = ((d[(1+faw)%3] - (cnt1[u][1]*faw%mod+dp1[u][1])%mod+mod)%mod+mod)%mod;
        d[(2+faw)%3] = ((d[(2+faw)%3] - (cnt1[u][2]*faw%mod+dp1[u][2])%mod+mod)%mod+mod)%mod;
        
        dp2[u][(0+faw)%3] = (c[0]*faw%mod+d[0])%mod;
        dp2[u][(1+faw)%3] = (c[1]*faw%mod+d[1])%mod;
        dp2[u][(2+faw)%3] = (c[2]*faw%mod+d[2])%mod;
        cnt2[u][(0+faw)%3] += c[0];
        cnt2[u][(1+faw)%3] += c[1];
        cnt2[u][(2+faw)%3] += c[2];
    }
    for(auto i : vv[u]) {
        int v = i.fi, w = i.se;
        if(v == fa) continue;
        dfs2(v, u);
    }
}

int main() {
    // freopen("in", "r", stdin);
    while(~scanf("%d", &n)) {
        for(int i=1; i<=n; i++) {
            vv[i].clear();
        }
        mes(dp1, 0), mes(dp2, 0);
        mes(cnt1, 0), mes(cnt2, 0);
        for(int i=1, u, v, w; i<n; i++) {
            scanf("%d%d%d", &u, &v, &w);
            u++, v++;
            vv[u].pb(make_pair(v, w));
            vv[v].pb(make_pair(u, w));
        }
        dfs1(1, 0);
        dfs2(1, 0);
        // for(int i=1; i<=n; i++) {
        //     for(int j=0; j<3; j++) {
        //         printf("dp1[%d][%d] = %lld, cnt1[%d][%d] = %lld\n", i, j, dp1[i][j], i, j, cnt1[i][j]);
        //     }
        // }
        // cout << "-----------------" << endl;
        // for(int i=1; i<=n; i++) {
        //     for(int j=0; j<3; j++) {
        //         printf("dp2[%d][%d] = %lld, cnt2[%d][%d] = %lld\n", i, j, dp2[i][j], i, j, cnt2[i][j]);
        //     }
        // }
        ll ans0, ans1, ans2;
        ans0 = ans1 = ans2 = 0;
        for(int i=1; i<=n; i++) {
            ans0 = (ans0+dp1[i][0]+dp2[i][0])%mod;
            ans1 = (ans1+dp1[i][1]+dp2[i][1])%mod;
            ans2 = (ans2+dp1[i][2]+dp2[i][2])%mod;
        }
        printf("%lld %lld %lld\n", ans0, ans1, ans2);
    }
    return 0;
}

Guess you like

Origin www.cnblogs.com/Jiaaaaaaaqi/p/11530717.html