POJ3417 LCA+树dp

http://poj.org/problem?id=3417

题意:先给出一棵无根树,然后下面再给出m条边,把这m条边连上,然后每次你能毁掉两条边,规定一条是树边,一条是新边,问有多少种方案能使树断裂。

我们考虑加上每一条新边的情况,当一条新边加上之后,原本的树就会成环,环上除了所有的树边要断的话必然要砍掉这条新边才可行。

每一条新边成的环就是u - lca(u,v) - v,对每一条边的覆盖次数++

考虑所有的树边,被覆盖 == 0的时候,意味着单独砍掉这条树边即可,其他随便选一个新边就是一种方案,贡献值 += M;

被覆盖 == 1的时候,意味着砍掉这条树边必须砍掉另一条与他匹配的新边,贡献值 ++

被覆盖 >= 2的时候,这条树边被砍掉是没有意义的,因为不可能同时砍掉两条以上的新边

下面的问题就变成了如何求每一条边的被覆盖次数,我们只要对dp[lca] -= 2,dp[u]++,dp[v]++从根节点向下推,到叶子节点之后回溯,更新dp值即可

这就变成了一个喜闻乐见的树dp、

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)  
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))  
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Scl(x) scanf("%lld",&x);  
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);  
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long  
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second 
typedef vector<int> VI;
const double eps = 1e-9;
const int maxn = 1e5 + 10;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7; 
int N,M,tmp,K; 
int head[maxn],tot,cnt;
bool vis[maxn];
int F[maxn * 2],P[maxn],rmq[maxn * 2];
struct Edge{
    int to,next;
}edge[maxn * 2];
int dp[maxn];
LL sum;
struct ST{
    int dp[maxn * 2][20];
    int mm[maxn * 2];
    void init(int n){
        mm[0] = -1;
        for(int i = 1; i <= n ; i ++){
            mm[i] = ((i & (i - 1)) == 0)?mm[i - 1] + 1:mm[i - 1];
            dp[i][0] = i;
        }
        for(int j = 1; j <= mm[n]; j ++){
            for(int i = 1; i + (1 << j) - 1 <= n ; i ++){
                dp[i][j] = rmq[dp[i][j - 1]] < rmq[dp[i + (1 << (j - 1))][j - 1]]?dp[i][j - 1]:dp[i + (1 << (j - 1))][j - 1];
            }
        }
    }
    int query(int a,int b){
        if(a > b) swap(a,b);
        int k = mm[b - a + 1];
        return rmq[dp[a][k]] <= rmq[dp[b - (1 << k) + 1][k]]?dp[a][k]:dp[b - (1 << k) + 1][k];
    }
}st;
void init(){
    Mem(head,-1);
    tot = 0;
}
void add(int u,int v){
    edge[tot].next = head[u];
    edge[tot].to = v;
    head[u] = tot++;
}
void dfs(int u,int pre,int dep){
    F[++cnt] = u;
    rmq[cnt] = dep;
    P[u] = cnt;
    for(int i = head[u]; ~i; i = edge[i].next){
        int v = edge[i].to;
        if(v == pre ) continue;
        dfs(v,u,dep + 1);
        F[++cnt] = u;
        rmq[cnt] = dep;
    }
}
void LCA_init(int root){
    cnt = 0;
    dfs(root,root,0);
    st.init(2 * N - 1);
}
int lca(int u,int v){
    return F[st.query(P[u],P[v])];
}
int dfs2(int x,int last){
    for(int i = head[x]; ~i ; i = edge[i].next){
        int to = edge[i].to;
        if(to == last) continue;
        dfs2(to,x);
        dp[x] += dp[to];
        if(dp[to] == 1) sum++;
        else if(!dp[to]) sum += M;
    }
    return dp[x];
}
int main()
{
    Sca2(N,M);
    init();
    For(i,1,N - 1){
        int u,v; Sca2(u,v);
        add(u,v); add(v,u);
    }
    LCA_init(1);
    For(i,1,M){
        int u,v; Sca2(u,v);
        dp[u]++; dp[v]++; dp[lca(u,v)] -= 2;
    }
    dfs2(1,-1);
    Prl(sum);
    #ifdef VSCode
    system("pause");
    #endif
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Hugh-Locke/p/9664450.html