poj3417【数上差分】【树上倍增】【公共祖先lca】

题目传送门

题目大意,n-1条边构成一棵树,给出m个操作,没个操作为俩个点,把俩个点在树上的路径加一,
输出 树中有边属性

先用树上倍增求出f数组,为lca函数准备
俩个点之间的路径+1,可以使用树上差分,cnt【x】+1,cnt【y】+1
cnt【lca(x,y)】-=2; 实现边差分操作;
差分之后求前缀和,cnt【】数组;
使用dfs 记录每个点的子节点之和;

bool vis[N];
void dfs(int x) //求子树的权值和;
{
    vis[x]=true;
    for(int i=h[x];i!=-1;i=ne[i])
    {
        int y=e[i];
        if(vis[y]) continue;
        dfs(y);//求x的分枝中一个的权值和,
        cnt[x]+=cnt[y];
    }
}

#include<iostream>
#include<algorithm>
#include<math.h>
#include<queue>
#include<string.h>
using namespace std;
const int N=1e6;
int h[N],ne[N],e[N],idx;
int cnt[N];
void add(int a,int b)
{
    e[idx]=b;
   // w[idx]=c;
    ne[idx]=h[a];
    h[a]=idx++;
}
int n,m;
//最近公共祖先bfs部分,f[][]数组;
int f[N][30],t,d[N];

void bfs()
{
    queue<int>q;
    d[1]=1;
    q.push(1);
    while(q.size())
    {
        int x=q.front();
        q.pop();

        for(int i=h[x];i!=-1;i=ne[i])
        {
            int y=e[i];
            if(d[y]) continue;

            d[y]=d[x]+1;
            f[y][0]=x;
            for(int i=1;i<=t;i++)
            {
                f[y][i]=f[f[y][i-1]][i-1];
            }
            q.push(y);
        }
    }
}

int lca(int x,int y)
{
    if(d[x]<d[y]) swap(x,y);
    //单独跳;
    for(int i=t;i>=0;i--)
    {
        if(d[f[x][i]]>=d[y]) x=f[x][i];
    }
    if(x==y) return x;
    //一起跳
    for(int i=t;i>=0;i--)
    {
        //未相会,跳
        if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    }
    return f[x][0];
}
bool vis[N];
void dfs(int x)
{
    vis[x]=true;
    for(int i=h[x];i!=-1;i=ne[i])
    {
        int y=e[i];
        if(vis[y]) continue;
        dfs(y);
        cnt[x]+=cnt[y];
    }
}

int main()
{
    memset(h,-1,sizeof h);
    ios::sync_with_stdio(false);
    cin>>n>>m;
    t=(int)(log(n)/log(2))+1;
    for(int i=1;i<n;i++)
    {
        int a,b;
        cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    bfs();
    for(int i=1;i<=m;i++)
    {
        int x,y;
        cin>>x>>y;
        cnt[lca(x,y)]-=2;
        cnt[x]+=1;
        cnt[y]+=1;
    }
    dfs(1);
    int  ans=0;
    for(int i=2;i<=n;i++)
    {
        if(cnt[i]==0) ans+=m;
        if(cnt[i]==1) ans++;
    }
    cout<<ans<<endl;
    return 0;
}

发布了152 篇原创文章 · 获赞 4 · 访问量 3872

猜你喜欢

转载自blog.csdn.net/qq_43716912/article/details/100767671