AtCoder Regular Contest 097F: Monochrome Cat 题解

这道题的思路非常巧妙
首先我们可以把那些全部都是黑点的子树扔掉,他们是不会被遍历的
这时,剩下的树的所有的叶子都是白色的
我们先考虑终点和起点重合的情况
这时如果我们想把它全部变黑的话必须遍历整棵树
并且我们发现,无论以哪个点为起点,答案都是一样的,首先所有的边都要走两遍,第二,一个点被走过的次数等于它的度数,所以一个点是否需要停留一秒也是可以算出来的(显然的是,一个点最多停留一秒),这样我们可以先预处理一个need数组,need[i]=1表示在i这个节点需要停留1s
这时我们得出了一个初始答案
我们再考虑终点离开起点的情况,来优化答案
我们发现当终点和起点不在一起时,从起点到终点的路径上所有的边只会走一次,还有这条路径上除了终点的其他点访问的次数都会-1,也就是说,如果原来这个节点需要停留1s,现在就不需要停留了,可以节省一秒,否则就要额外停留,时间增加1s
于是我们可以获得这样一个问题:树上的每个点都有一个数1或0(这个数就是need数组)我们要在树上找一条路径,最大化“点数+1的个数-0的个数”
我们找的这条路径是不包含终点的路径,所以点数就是边数,终点的访问次数不-1的情况也被考虑了,但需要额外注意的是,因为我最后需要将终点加到这条链的某一端,所以我找的这条链的两端不能同时是叶子
我们考虑树型dp
dp[i][0]表示以i为根的子树中,从某一个叶子出发到i的链,上述答案最大的是多少
dp[i][1]表示以i为根的子树中,不从叶子出发到i的链,上述答案最大是多少
dp2[i]表示以i为根的子树中,合法的所有链中,上述答案最大的是多少
转移非常简单,注意一下dp2[i]不能由“dp[son1][0]+dp[son2][0]+根的贡献”得到就好
还有几个细节
1. 我在搜索的时候,一定要以一个白点为根搜索,否则我的根那里可能是要被预先砍掉的
2. 理论上,只包含一个单独的叶子的路径是合法的,但要特判一下整棵树只有一个白点的情况


更新:感谢Curious_Cat_is_OIer的评论,事实上是不用考虑叶子的影响的,因为我所有的叶子都是白色的,而且叶子的相邻节点一定只有一个,所以它的need一定是0,所以选中一个叶子一定会使点数+1,然后多花1的代价,所以如果一条最大权值链里面包含叶子,把叶子去掉也是最大权值链,所以直接找最大权值链就可以了,不用像上面一样考虑奇怪的情况

#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,int>
#define pii pair<double,double>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;

const int MOD=1e9+7;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=acos(-1);

inline int getint()
{
    char ch;int res;bool f;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

int n;
vector<int> v[100048];int col[100048],d[100048],need[100048];
//0: leaf reachable; 1: leaf unreachable
int dp[100048][2],dp2[100048];
char s[100048];
int haswhite[100048];
int ans,maxminus=0;
int cnt=0;

inline void dfs(int cur,int father)
{
    haswhite[cur]=col[cur];
    int i,y;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=v[cur][i];
        if (y!=father)
        {
            dfs(y,cur);
            if (haswhite[y]) haswhite[cur]=1;
        }
    }
}

inline void solve(int cur,int father)
{
    int i,y;bool isleaf=true;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=v[cur][i];
        if (y!=father && haswhite[y])
        {
            isleaf=false;
            solve(y,cur);
        }
    }
    if (isleaf)
    {
        dp[cur][0]=1+(need[cur]==1?1:-1);
        dp[cur][1]=-INF;
        dp2[cur]=dp[cur][0];
        if (cnt>1) maxminus=max(maxminus,dp2[cur]);
        return;
    }
    dp[cur][0]=-INF;dp[cur][1]=dp2[cur]=1+(need[cur]==1?1:-1);
    int nmax=-INF,ymax=-INF;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=v[cur][i];
        if (y!=father && haswhite[y])
        {
            dp[cur][0]=max(dp[cur][0],dp[y][0]+1+(need[cur]==1?1:-1));
            dp[cur][1]=max(dp[cur][1],dp[y][1]+1+(need[cur]==1?1:-1));
            dp2[cur]=max(dp2[cur],nmax+max(dp[y][0],dp[y][1])+1+(need[cur]==1?1:-1));
            dp2[cur]=max(dp2[cur],ymax+dp[y][1]+1+(need[cur]==1?1:-1));
            nmax=max(nmax,dp[y][1]);ymax=max(ymax,dp[y][0]);
        }
    }
    maxminus=max(maxminus,dp2[cur]);
    maxminus=max(maxminus,max(dp[cur][0],dp[cur][1]));
}

int main ()
{
    int i,j,x,y;
    n=getint();
    for (i=1;i<=n-1;i++)
    {
        x=getint();y=getint();
        v[x].pb(y);v[y].pb(x);
    }
    scanf("%s",s+1);
    int root=-1;
    for (i=1;i<=n;i++) {col[i]=(s[i]=='W');if (col[i] && root==-1) root=i;}
    if (root==-1)
    {
        printf("0\n");
        return 0;
    }
    dfs(root,-1);ans=0;
    for (i=1;i<=n;i++)
        if (haswhite[i])
            for (j=0;j<int(v[i].size());j++)
                if (haswhite[v[i][j]]) ans++,d[i]++;
    for (i=1;i<=n;i++) if (haswhite[i]) cnt++,need[i]=(d[i]%2)^col[i],ans+=need[i];
    solve(root,-1);
    printf("%d\n",ans-maxminus);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/iceprincess_1968/article/details/80311467