洛谷 P2495/BZOJ 2286 消耗战题解

洛谷 P2495/BZOJ 2286 消耗战题解

题意

给定一棵有\(n\)个结点的树(\(1\leq n\leq250000\)),其中每条边都为双向边,第\(i\)条边将\(u_i\)\(v_i\)相连,权值为\(c_i\)(\(1\leq u,v\leq n,1\leq c\leq 10^5\))。接下来有\(m\)(\(m\geq 1\))个不相关的询问,每次给定一个正整数\(k\)(\(1\leq k\leq n-1\)),以及\(k\)个点\(h_1,h_2,……,h_k\),求将这\(k\)个点全部切断与1号点的联系最少需要切断总权值为多少的边。(\(\Sigma k \leq 500000\))

题解

看到\(\Sigma k\)这个东西,很明显想到虚树。又注意到每个询问互相之间是独立的,所以对于每一个询问,先利用LCA和单调栈将整棵树抽象。接下来就是考虑树形DP的时间了。我们令\(f_x\)为将\(x\)与其子树内的所有指定点切断的最小代价。将方程分为两种情况:

1、\(son_x\)不是指定点:\(f_x+=max(f_{son_x},c_{x->son_x})\)

2、\(son_x\)是关键点:\(f_x+=c_{x->son_x}\)

将树抽象后直接DP即可。

Code

#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9') x=x*10+(ch^48),ch=getchar();
    x*=f;
    return;
}
template<typename T>
void write(T x)
{
    if(x<0) putchar('-'),x=-x;
    if(x>=10) write(x/10);
    putchar(x%10+'0');
    return;
}
const int MAXN=250010,MAXLG=20;
int n,m;
int tot=1;
struct node{
    int v,c;
}edge[MAXN*2];
int nxt[MAXN*2];
int hd[MAXN];
inline void add_edge(int u,int v,int c)
{
    edge[tot].v=v,edge[tot].c=c;
    nxt[tot]=hd[u];
    hd[u]=tot++;
}
int d[MAXN];
bool book[MAXN];
struct node2{
    int pos,val;
}LCA[MAXN][MAXLG+1];
queue<int> q;
inline void prework()
{
    q.push(1),book[1]=1,d[1]=1;
    while(!q.empty())
    {
        int x=q.front();
        q.pop();
        for(int i=hd[x];i;i=nxt[i])
        if(!book[edge[i].v])
        {
            int y=edge[i].v,c=edge[i].c;
            LCA[y][0].pos=x,LCA[y][0].val=c;
            d[y]=d[x]+1;
            q.push(y),book[y]=1;
            for(int j=1;j<=MAXLG;j++)
            LCA[y][j].pos=LCA[LCA[y][j-1].pos][j-1].pos,
            LCA[y][j].val=min(LCA[y][j-1].val,LCA[LCA[y][j-1].pos][j-1].val);
        }
    }
    memset(book,0,sizeof(book));
}
inline int lca(int x,int y,int &val)
{
    val=INT_MAX;
    if(d[x]<d[y]) swap(x,y);
    for(int i=MAXLG;i>=0;i--)
    if(d[LCA[x][i].pos]>=d[y]) val=min(val,LCA[x][i].val),x=LCA[x][i].pos;
    if(x==y) return x;
    for(int i=MAXLG;i>=0;i--)
    if(LCA[x][i].pos!=LCA[y][i].pos)
    {
        val=min(val,LCA[x][i].val),x=LCA[x][i].pos;
        val=min(val,LCA[y][i].val),y=LCA[y][i].pos;
    }
    val=min(val,LCA[x][0].val),val=min(val,LCA[y][0].val);
    return LCA[x][0].pos;
}
int dfn;
int id[MAXN];
void dfs(int p)
{
    id[p]=++dfn;
    for(int i=hd[p];i;i=nxt[i])
    if(!id[edge[i].v]) dfs(edge[i].v);
}
struct node3{
    int tot=1;
    int edge[MAXN*2],nxt[MAXN*2],hd[MAXN],w[MAXN*2];
    inline void add_edge(int u,int v,int c)
    {
        edge[tot]=v,w[tot]=c;
        nxt[tot]=hd[u],hd[u]=tot++;
    }
}t;
bool cmp(int a,int b)
{
    return id[a]<id[b];
}
int k;
int h[MAXN];
int st[MAXN],tp;
inline void build()
{
    int tmp;
    t.tot=1,tp=0;
    sort(h+1,h+k+1,cmp);
    st[++tp]=1,t.hd[1]=0;
    for(int i=1;i<=k;i++) book[h[i]]=1;
    for(int i=1;i<=k;i++)
    {
        int l=lca(st[tp],h[i],tmp);
        if(l!=st[tp])
        {
            while(id[l]<id[st[tp-1]])
            {
                lca(st[tp-1],st[tp],tmp);
                t.add_edge(st[tp-1],st[tp],tmp);
                tp--;
            }
            if(id[l]>id[st[tp-1]]) t.hd[l]=0,lca(l,st[tp],tmp),t.add_edge(l,st[tp],tmp),st[tp]=l;
            else lca(l,st[tp],tmp),t.add_edge(l,st[tp],tmp),tp--;
        }
        t.hd[h[i]]=0,st[++tp]=h[i];
    }
    for(int i=1;i<tp;i++) lca(st[i],st[i+1],tmp),t.add_edge(st[i],st[i+1],tmp);
}
long long f[MAXN];
void dp(int p)
{
    f[p]=0;
    for(int i=t.hd[p];i;i=t.nxt[i]) dp(t.edge[i]);
    for(int i=t.hd[p];i;i=t.nxt[i])
    if(!book[t.edge[i]]) f[p]+=min(f[t.edge[i]],(long long)t.w[i]);
    else f[p]+=(long long)t.w[i];
}
inline void init()
{
    for(int i=1;i<=k;i++) book[h[i]]=0;
}
int main()
{
    read(n);
    for(int i=1;i<n;i++)
    {
        int u,v,c;
        read(u),read(v),read(c);
        add_edge(u,v,c),add_edge(v,u,c);
    }
    prework();
    dfs(1);
    read(m);
    for(int i=1;i<=m;i++)
    {
        read(k);
        for(int j=1;j<=k;j++) read(h[j]);
        build();
        dp(1);
        init();
        write(f[1]),putchar('\n');
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/xiaoh105/p/12182216.html
今日推荐