【UTR #1】ydc的大树

【UTR #1】ydc的大树

全网唯一一篇题解我看不懂

所以说一下我的O(nlogn)做法:

以1号点为根节点

一个黑点如果有多个相邻的节点出去都能找到最远的黑点,那么这个黑点就是无敌的

所以考虑每个黑点x的最远距离和最远点是否仅在一个“方向”

然后这个方向的一些连续白点割掉可以使得x不高兴

1.如果都在一个方向,假设是x的子树,那就是这个子树最远黑点们的lca到x路径上的任意白点割掉,都可以使得x不高兴

2.如果都在往父亲的方向,找到最浅的点p,使得每个最远黑点到x的路径都经过p,p到x的路径上的任意白点割掉,都可以使得x不高兴

树形DP即可。

struct,记录最远距离、最远的方向个数、决策位置(1的lca或者是2的p)

转移较麻烦

树上差分打标记即可。

求lca,所以O(nlogn)

写了四个dfs。。。

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define fi first
#define se second
#define mk(a,b) make_pair(a,b)
#define numb (ch^'0')
#define pb push_back
#define solid const auto &
#define enter cout<<endl
#define pii pair<int,int>
using namespace std;
typedef long long ll;
template<class T>il void rd(T &x){
    char ch;x=0;bool fl=false;while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);(fl==true)&&(x=-x);}
template<class T>il void output(T x){if(x/10)output(x/10);putchar(x%10+'0');}
template<class T>il void ot(T x){if(x<0) putchar('-'),x=-x;output(x);putchar(' ');}
template<class T>il void prt(T a[],int st,int nd){for(reg i=st;i<=nd;++i) ot(a[i]);putchar('\n');}
namespace Modulo{
const int mod=998244353;
int ad(int x,int y){return (x+y)>=mod?x+y-mod:x+y;}
void inc(int &x,int y){x=ad(x,y);}
int mul(int x,int y){return (ll)x*y%mod;}
void inc2(int &x,int y){x=mul(x,y);}
int qm(int x,int y=mod-2){int ret=1;while(y){if(y&1) ret=mul(x,ret);x=mul(x,x);y>>=1;}return ret;}
}
//using namespace Modulo;
namespace Miracle{
const int N=1e5+5;
const int inf=0x3f3f3f3f;
int n,m;
int b[N];
struct node{
    int nxt,to;
    int val;
}e[2*N];
int hd[N],cnt;
void add(int x,int y,int z){
    e[++cnt].nxt=hd[x];
    e[cnt].to=y;e[cnt].val=z;
    hd[x]=cnt;
}
struct po{
    int mx,cnt,pos;
    po(){mx=-inf,cnt=1,pos=0;}//warning!! -inf
    po(int v,int c,int p){mx=v;cnt=c;pos=p;}
    po friend operator +(po a,po b){
        if(a.mx==b.mx) return po(a.mx,a.cnt+b.cnt,a.pos);
        else if(a.mx>b.mx) return a;
        else return b;
    }
}pr[N],bc[N],f[N],g[N];
int sta[N],top;
int fa[N][17];
int dep[N],vf[N];
void pre(int x){
    dep[x]=dep[fa[x][0]]+1;
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x][0]) continue;
        fa[y][0]=x;
        pre(y);
    }
}
void dfs(int x){
    int st=top;
    if(b[x]) f[x]=po(0,1,x);
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x][0]) continue;
        vf[y]=e[i].val;
        dfs(y);
        sta[++top]=y;
        pr[y]=f[x];
        f[x]=f[x]+po(f[y].mx+e[i].val,1,f[y].pos);
    }
    if(f[x].mx<0){
        f[x].cnt=1;f[x].pos=0;
    }else{
        if(f[x].cnt>1) f[x].cnt=1,f[x].pos=x;
    }
    po now;
    while(top!=st){
        bc[sta[top]]=now;
        now=now+po(f[sta[top]].mx+vf[sta[top]],1,f[sta[top]].pos);
        --top;
    }
}
void gf(int x){
    if(fa[x][0]){
        int pa=fa[x][0];
        g[x]=g[pa]+pr[x]+bc[x];
        g[x].mx+=vf[x];
        if(g[x].mx<0){
            g[x].cnt=1;g[x].pos=0;
        }else{
            if(g[x].cnt>1) g[x].cnt=1,g[x].pos=pa;
        }
    }
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x][0]) continue;
        gf(y);
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(reg j=16;j>=0;--j){
        if(dep[fa[x][j]]>=dep[y]) x=fa[x][j];
    }
    if(x==y) return x;
    for(reg j=16;j>=0;--j){
        if(fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
    }
    return fa[x][0];
}
int tag[N];
int ans,tot;
void fin(int x){
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x][0]) continue;
        fin(y);
        tag[x]+=tag[y];
    }
    if(!b[x]){
        if(tag[x]>ans) {
            ans=tag[x];tot=1;
        }else if(tag[x]==ans){
            ++tot;
        }
    }
}
int main(){
    rd(n);rd(m);
    for(reg i=1;i<=m;++i) {
        int x;rd(x);b[x]=1;
    }
    int x,y,z;
    for(reg i=1;i<n;++i){
        rd(x);rd(y);rd(z);
        add(x,y,z);add(y,x,z);
    }
    pre(1);
    dfs(1);
    gf(1);
    for(reg j=1;j<=16;++j){
        for(reg i=1;i<=n;++i){
            fa[i][j]=fa[fa[i][j-1]][j-1];
            // cout<<" fa "<<i<<" "<<j<<" : "<<fa[i][j]<<endl;
        }
    }
    for(reg i=1;i<=n;++i){
        if(b[i]){
            int x=i;
            po now=f[x]+g[x];
            if(now.cnt==1){
                // cout<<" tag? "<<x<<" "<<now.pos<<endl;
                int anc=lca(x,now.pos);
                // cout<<" anc "<<anc<<endl;
                ++tag[x];++tag[now.pos];
                --tag[anc];--tag[fa[anc][0]];
            }
        }
    }
    ans=-inf;
    fin(1);
    ot(ans);ot(tot);
    return 0;
}

}
signed main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
*/

猜你喜欢

转载自www.cnblogs.com/Miracevin/p/10970347.html