BZOJ5072 小A的树(树形dp)

  容易猜到能选择的黑点个数是一个连续区间。那么设f[i][j]为i子树内选j个点形成包含根的连通块,最多有几个黑点,g[i][j]为最少有几个黑点,暴力dp是O(n2)的,求出每个连通块大小对应的黑点数量取值范围即可。

  惊觉差点不会树形背包了。注意不要出现任何非法转移,即使看上去无伤大雅。

#include<iostream> 
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define N 5010
char getc(){char c=getchar();while ((c<'A'||c>'Z')&&(c<'a'||c>'z')&&(c<'0'||c>'9')) c=getchar();return c;}
int gcd(int n,int m){return m==0?n:gcd(m,n%m);}
int read()
{
    int x=0,f=1;char c=getchar();
    while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
    while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return x*f;
}
int T,n,m,a[N],p[N],size[N],f[2][N][N],l[N],r[N],t;
struct data{int to,nxt;
}edge[N<<1];
void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;}
void dfs(int k,int from)
{
    int s=1;size[k]=1;
    for (int i=p[k];i;i=edge[i].nxt)
    if (edge[i].to!=from) dfs(edge[i].to,k),s+=size[edge[i].to];
    for (int i=0;i<=s;i++) f[0][k][i]=f[1][k][i]=0;
    f[0][k][1]=a[k],f[1][k][1]=a[k]^1;
    for (int i=p[k];i;i=edge[i].nxt)
    if (edge[i].to!=from)
    {
        size[k]+=size[edge[i].to];
        for (int j=size[k];j>=1;j--)
            for (int x=max(1,j-size[edge[i].to]);x<=min(size[k]-size[edge[i].to],j);x++)
            f[0][k][j]=max(f[0][k][j],f[0][k][x]+f[0][edge[i].to][j-x]),
            f[1][k][j]=max(f[1][k][j],f[1][k][x]+f[1][edge[i].to][j-x]);
    }
    for (int i=1;i<=size[k];i++) l[i]=min(l[i],i-f[1][k][i]);
    for (int i=1;i<=size[k];i++) r[i]=max(r[i],f[0][k][i]);
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("bzoj5072.in","r",stdin);
    freopen("bzoj5072.out","w",stdout);
    const char LL[]="%I64d\n";
#else
    const char LL[]="%lld\n";
#endif
    T=read();
    while (T--)
    {
        n=read(),m=read();
        memset(p,0,sizeof(p));t=0;
        for (int i=1;i<n;i++)
        {
            int x=read(),y=read();
            addedge(x,y),addedge(y,x);
        }
        for (int i=1;i<=n;i++) a[i]=read(),l[i]=n+1,r[i]=0;
        dfs(1,1);
        while (m--)
        {
            int x=read(),y=read();
            if (l[x]<=y&&r[x]>=y) puts("YES");
            else puts("NO");
        }
        cout<<endl;
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Gloid/p/10060629.html
今日推荐