NOIp2018 D2T3 defense——树上倍增

题目:https://www.luogu.org/problemnew/show/P5024

考场上只会写n,m<=2000的暴力,还想了想A1和A2的情况,不过好像只得了A1的分。然后仔细一看,原来是把dp2[ ][ ]写成dp[ ][ ]了。改一下,就能得到A1和A2的分。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
const ll INF=1e10+5;
int n,m,p[N],hd[N],xnt,to[N<<1],nxt[N<<1];
int q0,q1,f0,f1;
ll dp[N][3],dp2[N][3];
char ch[5];
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
void add(int x,int y)
{
  to[++xnt]=y; nxt[xnt]=hd[x]; hd[x]=xnt;
}
ll Mn(ll a,ll b){return a<b?a:b;}
ll Mx(ll a,ll b){return a>b?a:b;}
void dfs(int cr,int fa)
{
  dp[cr][0]=0; dp[cr][1]=p[cr];
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs(v,cr);
    dp[cr][0]+=dp[v][1];
    dp[cr][1]+=Mn(dp[v][0],dp[v][1]);
      }
  if(cr==q0)dp[cr][!f0]=INF;
  if(cr==q1)dp[cr][!f1]=INF;
}
bool chk()
{
  bool flag=0;
  for(int i=hd[q0];i;i=nxt[i])
    if(to[i]==q1){flag=1;break;}
  if(flag&&!f0&&!f1)
    {
      puts("-1");return true;
    }
  return false;
}
void solve1()
{
  for(int i=1;i<=m;i++)
    {
      q0=rdn();f0=rdn();q1=rdn();f1=rdn();
      if(chk())continue;
      dfs(1,0);
      printf("%lld\n",Mn(dp[1][0],dp[1][1]));
    }
}
void solve2()
{
  dp[1][1]=p[1]; dp[1][0]=INF;
  for(int i=2;i<=n;i++)
    {
      dp[i][1]=Mn(dp[i-1][0],dp[i-1][1])+p[i];
      dp[i][0]=dp[i-1][1];
    }
  dp2[n][1]=p[n]; dp2[n][0]=0;
  for(int i=n-1;i;i--)
    {
      dp2[i][1]=Mn(dp2[i+1][0],dp2[i+1][1])+p[i];
      dp2[i][0]=dp2[i+1][1];
    }
  for(int i=1;i<=m;i++)
    {
      q0=rdn();f0=rdn();q1=rdn();f1=rdn();
      if(chk())continue;
      printf("%lld\n",dp[q1][f1]+dp2[q1][f1]-(f1?p[q1]:0));
    }
}
void solve3()
{
  dp[1][1]=p[1]; dp[1][0]=0;
  for(int i=2;i<=n;i++)
    {
      dp[i][1]=Mn(dp[i-1][0],dp[i-1][1])+p[i];
      dp[i][0]=dp[i-1][1];
    }
  dp2[n][1]=p[n]; dp2[n][0]=0;
  for(int i=n-1;i;i--)
    {
      dp[i][1]=Mn(dp[i+1][0],dp[i+1][1])+p[i];
      dp[i][0]=dp[i+1][1];
    }
  for(int i=1;i<=m;i++)
    {
      q0=rdn();f0=rdn();q1=rdn();f1=rdn();
      if(chk())continue;
      if(q0>q1)swap(q0,q1),swap(f0,f1);
      printf("%lld\n",dp[q0][f0]+dp2[q1][f1]);
    }
}
int main()
{
  freopen("defense.in","r",stdin);
  freopen("defense.out","w",stdout);
  n=rdn();m=rdn();scanf("%s",ch+1);
  for(int i=1;i<=n;i++)p[i]=rdn();
  for(int i=1,u,v;i<n;i++)
    {
      u=rdn(); v=rdn(); add(u,v); add(v,u);
    }
  if(n<=2000)solve1();
  else if(ch[1]=='A'&&ch[2]=='1')solve2();
  else if(ch[1]=='A'&&ch[2]=='2')solve3();
  else solve1();
  return 0;
}
View Code

然后得知A的分好像就是一个线段树。想一想,就记录一下该区间两端的是0还是1就行了。

正解的一种是倍增。

  先做出正常的dp[ ][ 0/1 ]数组。考虑倍增,f[ cr ][ i ][0/1][0/1]表示自己到第 i 个祖先的路上的贡献(不含自己及自己子树,含祖先,含路上的点以及它们的分叉,不含祖先上面的部分);只要把dp数组累加起来就行了;累加的时候注意把自己这一条减去,就是如果用父亲的dp[ ][1]的话,就减去自己的min(dp[ ][0],dp[ ][1]),不然就减去自己的dp[ ][1],然后把父亲的这个减去之后的东西放到f[ ][0][ ][ ]里;i 的其他值正常倍增合并就行了。

  考虑统计,两个端点就用它们的dp值就行;路上的就倍增地走,用 f 值就行;lca处需要一个以该点为根的值,减去走来的那两条,然后累加到答案里。因为在lca处的那两条一定是原树的两个子树,所以可以换一遍根记录每个点作为根的值,用的时候减去那两条的dp值即可。

  思路似乎也很简单?倍增真是好物。考虑两个点的问题时应该想一想倍增、lca之类的。然后考虑倍增数组维护什么,想一想统计答案的时候分为几种不同的部分,应该就差不多了。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5,M=20; const ll INF=1e10+5;
int n,m,p[N],pr[N][M],hd[N],xnt,to[N<<1],nxt[N<<1],lm,dep[N];
ll dp[N][2],info[N][2],f[N][M][2][2];
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
ll Mn(ll a,ll b){return a<b?a:b;}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void dfs1(int cr,int fa)
{
  dp[cr][1]=p[cr];  dep[cr]=dep[fa]+1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs1(v,cr);
    dp[cr][1]+=Mn(dp[v][0],dp[v][1]);
    dp[cr][0]+=dp[v][1];
      }
}
void dfs2(int cr,int fa,ll w0,ll w1)
{
  info[cr][0]=dp[cr][0]+w1;
  info[cr][1]=dp[cr][1]+Mn(w0,w1);
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs2(v,cr,info[cr][0]-dp[v][1],info[cr][1]-Mn(dp[v][0],dp[v][1]));
      }
}
void dfsx(int cr,int fa)
{
  for(int i=1,d;i<=lm&&pr[pr[cr][i-1]][i-1];i++)
    {
      d=pr[cr][i-1];
      pr[cr][i]=pr[d][i-1];
      for(int j=0;j<=1;j++)
    for(int k=0;k<=1;k++)
      {
        f[cr][i][j][k]=Mn(f[cr][i-1][j][0]+f[d][i-1][0][k],f[cr][i-1][j][1]+f[d][i-1][1][k]);//not dec p[d]
      }
    }
  for(int i=hd[cr],v;i;i=nxt[i])//f:without son
    if((v=to[i])!=fa)
      {
    pr[v][0]=cr;
    for(int j=0;j<=1;j++)
      {
        f[v][0][j][0]=dp[cr][0]-dp[v][1];//not pls p[v]
        f[v][0][j][1]=dp[cr][1]-Mn(dp[v][0],dp[v][1]);//not pls p[v]
      }
    f[v][0][0][0]=INF;
    dfsx(v,cr);
      }
}
ll cz(int x,int f0,int y,int f1)
{
  ll d00=0,d01=0,d10=0,d11=0;
  if(dep[x]<dep[y])swap(x,y),swap(f0,f1);
  int x0=x,y0=y;
  ll y00=0,y01=0;
  if(f0)y00=INF; else y01=INF;
  for(int i=lm;i>=0;i--)
    if(dep[pr[x][i]]>dep[y])//> not >=
      {
    d00=Mn(y00+f[x][i][0][0],y01+f[x][i][1][0]);
    d01=Mn(y00+f[x][i][0][1],y01+f[x][i][1][1]);
    y00=d00; y01=d01;
    x=pr[x][i];
      }
  ll ret=0;
  if(pr[x][0]==y)
    {
      ret=info[y][f1];
      if(f1)  ret-=Mn(dp[x][0],dp[x][1]),ret+=Mn(y00,y01);//y** not d** for init
      else  ret-=dp[x][1],ret+=y01;
      ret+=dp[x0][f0];
      return ret;
    }
  if(dep[x]!=dep[y])
    {
      d00=Mn(y00+f[x][0][0][0],y01+f[x][0][1][0]);
      d01=Mn(y00+f[x][0][0][1],y01+f[x][0][1][1]);
      y00=d00; y01=d01;
      x=pr[x][0];
    }

  ll y10=0,y11=0;
  if(f1)y10=INF; else y11=INF;
  for(int i=lm;i>=0;i--)
    if(pr[x][i]!=pr[y][i])
      {
    d00=Mn(y00+f[x][i][0][0],y01+f[x][i][1][0]);
    d01=Mn(y00+f[x][i][0][1],y01+f[x][i][1][1]);
    y00=d00; y01=d01; x=pr[x][i];
    d10=Mn(y10+f[y][i][0][0],y11+f[y][i][1][0]);
    d11=Mn(y10+f[y][i][0][1],y11+f[y][i][1][1]);
    y10=d10; y11=d11; y=pr[y][i];
      }
  int d=pr[x][0];
  ret=info[d][1]-Mn(dp[x][0],dp[x][1])-Mn(dp[y][0],dp[y][1])+Mn(y00,y01)+Mn(y10,y11);//y** not d** for init
  ll rt2=info[d][0]-dp[x][1]-dp[y][1]+y01+y11;
  ret=Mn(ret,rt2)+dp[x0][f0]+dp[y0][f1];
  return ret;
}
int main()
{
  freopen("defense.in","r",stdin);
  freopen("defense.out","w",stdout);
  char ch[5];
  n=rdn(); m=rdn(); scanf("%s",ch);
  for(;(1<<lm)<n;lm++);
  for(int i=1;i<=n;i++)p[i]=rdn();
  for(int i=1,u,v;i<n;i++)
    {
      u=rdn(); v=rdn(); add(u,v); add(v,u);
    }
  dfs1(1,0);dfs2(1,0,0,0);dfsx(1,0);
  for(int i=1,a,b,x,y;i<=m;i++)
    {
      a=rdn();x=rdn();b=rdn();y=rdn();
      ll ans=cz(a,x,b,y);
      printf("%lld\n",ans>=INF?-1:ans);
    }
  return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Narh/p/9999483.html