newker训练营4

T2:

树上差分、前缀和、dfs序的应用。

只询问p,q路径,可以n^2.没必要树形DP。

直接n^2枚举点对O(1)算距离即可。

考虑枚举点对i,j,O(1)统计。

求出以i为lca的路径数和经过i且不以i为lca的路径数。

u=lca(i,j)。

w=u子树和,减去i~j链上,加上u子树外,减去经过u且不以u为lca条数。

先往简单想。别过于套路,一看路径长直接树DP。

好题。

B@哥 真巨。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
inline void read(int &x)
{
    x=0;char c=getchar();
    while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-48,c=getchar();
}
typedef long long ll;
const int maxn=3005;
int n,p,q,tot,dex,fr[maxn],first[maxn],dep[maxn],f[maxn];
ll ans,res,dp[maxn],dis[maxn],g[maxn],h[maxn];
struct Road{
    int u,t,nxt;
}eage[maxn<<1];
void add(int x,int y) {eage[++tot]=(Road){x,y,first[x]};first[x]=tot;}
struct ST_Table{
    int p,f[maxn<<1][22],len[maxn<<1],bin[22];
    void init()
    {
        p=log(dex)/log(2)+1;
        for(int i=bin[0]=1;i<=p;++i) bin[i]=bin[i-1]<<1;
        for(int i=1;i<=dex;++i) len[i]=log(i)/log(2);
        for(int j=1;j<=p;++j)
            for(int i=1;i+bin[j]-1<=dex;++i)
            {
                if(dep[f[i][j-1]]<dep[f[i+bin[j-1]][j-1]]) f[i][j]=f[i][j-1];
                else f[i][j]=f[i+bin[j-1]][j-1];
            }
    }
    int LCA(int x,int y)
    {
        if(fr[x]>fr[y]) swap(x,y);
        int l=len[fr[y]-fr[x]+1];
        if(dep[f[fr[x]][l]]<dep[f[fr[y]-bin[l]+1][l]]) return f[fr[x]][l];
        else return f[fr[y]-bin[l]+1][l];
    }
}ST;
ll Get_dep(int x,int y) {return dep[x]+dep[y]-2*dep[ST.LCA(x,y)];}
void frdfs(int x,int fa)
{
    ST.f[fr[x]=++dex][0]=x;dep[x]=dep[fa]+1;f[x]=fa;
    for(int i=first[x];i;i=eage[i].nxt)
        if(eage[i].t!=fa)
        {
            frdfs(eage[i].t,x);
            ST.f[++dex][0]=x;
        }
}
void redfs(int x,int fa)
{
    for(int i=first[x];i;i=eage[i].nxt)
        if(eage[i].t!=fa)
        {
            redfs(eage[i].t,x);
            dp[x]+=dp[eage[i].t];
        }
}
void lsdfs(int x,int fa)
{
    h[x]=h[fa]+g[x];
    for(int i=first[x];i;i=eage[i].nxt)
        if(eage[i].t!=fa)
        {
            lsdfs(eage[i].t,x);
            g[x]+=g[eage[i].t];
        }
}
int main()
{
    read(n);read(p);read(q);
    for(int i=1,x,y;i<=n-1;++i)
    {
        read(x);read(y);
        add(x,y);add(y,x);
    }
    frdfs(1,0);
    ST.init();
    for(int i=1;i<=n;++i)
        for(int j=1,lca;j<=n;++j)
            if(Get_dep(i,j)==p)
            {
                lca=ST.LCA(i,j);++g[lca];
                ++dp[i];++dp[j];--dp[lca];--dp[lca];
            }
    redfs(1,0);lsdfs(1,0);
    dp[0]=dp[1];g[0]=g[1];
    for(int i=1;i<=n;++i)
        for(int j=1;j<=n;++j)
            if(Get_dep(i,j)==q)
            {
                int lca=ST.LCA(i,j);
                ans+=g[lca]-(h[i]+h[j]-h[lca]-h[f[lca]]);
                ans+=g[1]-g[lca]-dp[lca];
                /** 经过该点且不以该点为lca的边数。*/
            }
    printf("%lld\n",ans);
    return 0;
}
lnc代码

我的n^2递推,利用dfs序预处理lca.

下面是 B&哥 思路:g[i][j],表示以dfs序上ij位置两点为端点路径数。

用来计算u子树外的部分:由于dfs序是连续段,则i,j均不在u子树dfs序范围内即可。

g数组的处理有个细节调了半天,注释上了。

#include<bits/stdc++.h>
#define F(i,a,b) for(int i=a;i<=b;++i)
#define LL long long
#define pf(a) printf("%d ",a)
#define PF(a) printf("%lld ",a)
#define phn puts("")
using namespace std;
int read();
#define N 3005
int n,p,q;
int to[N<<1],fir[N<<1],head[N],cnt;
void add(int x,int y){to[++cnt]=y;fir[cnt]=head[x];head[x]=cnt;}
int fa[N],dep[N],L[N],dfn[N],R[N],tot;
void dfs(int x,int Fa){
    dep[x]=dep[Fa]+1;fa[x]=Fa;L[x]=++tot;dfn[tot]=x;
    for(int i=head[x],v;i;i=fir[i])if(!dep[v=to[i]]){
        dfs(v,x);
    }
    R[x]=tot;
}
int lca[N][N];
LL w[N],g[N][N],sub[N],up[N],ans;
void push(int x){
    sub[x]=w[x];up[x]=up[fa[x]]+w[x];
    for(int i=head[x],v;i;i=fir[i])if((v=to[i])^fa[x]){
        push(v);sub[x]+=sub[v];
    }
}
/*g++ 2.cpp
./a.out
*/
void cal(int x,int y){
    int u=lca[x][y];
    ans+=sub[u];
    ans-=up[x]+up[y]-up[u]-up[fa[u]];
    int l=L[u],r=R[u];
    ans+=g[l-1][l-1];
    ans+=g[l-1][n]-g[l-1][r];
    ans+=g[n][l-1]-g[r][l-1];
    ans+=g[n][n]-g[n][r]-g[r][n]+g[r][r];
//    if(x==3&&y==5){PF(ans);}
}
int main(){
   freopen("b.in","r",stdin);
//    freopen("2.in","r",stdin);//freopen("2.out","w",stdout);
    n=read();p=read();q=read();
    for(int i=1,u,v;i<n;++i){
        u=read();v=read();add(u,v);add(v,u);
    } 
    dfs(1,0);
    F(i,1,n){
        lca[i][i]=i;
        F(j,L[i]+1,R[i])lca[i][dfn[j]]=lca[dfn[j]][i]=i;
        for(int j=L[i]+1;j<=R[i];j=R[dfn[j]]+1){
            for(int k=j;k<=R[dfn[j]];++k){
                for(int h=R[dfn[j]]+1;h<=R[i];++h){
                    lca[dfn[k]][dfn[h]]=lca[dfn[h]][dfn[k]]=i;
                }
            }
        }
    }
    F(i,1,n){
        F(j,1,n){
            if(dep[i]+dep[j]-dep[lca[i][j]]*2==q){
                ++w[lca[i][j]];++g[L[i]][L[j]];
                /** 注意:这里是g[L[i]],加到dfs序对应位置。*/
            }
        }
    }
    F(i,1,n){
        F(j,1,n){
            g[i][j]=g[i][j-1]+g[i-1][j]+g[i][j]-g[i-1][j-1];
        }
    }
 //   F(i,1,n)pf(dfn[i]);phn;phn;
  //  F(i,1,n){F(j,1,n)PF(g[i][j]);phn;}phn;
    push(1);
    F(i,1,n){
        F(j,1,n){
            if(dep[i]+dep[j]-dep[lca[i][j]]*2==p){
                cal(i,j);
  //              pf(i);pf(j);pf(ans);phn;
            }
        }
    }
    printf("%lld\n",ans);
}
int read(){
    int s=0,f=0;char ch=getchar();
    while(!isdigit(ch))f=ch=='-',ch=getchar();
    while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
    return f?-s:s;
}
/*
g++ d2.cpp
./a.out
g++ 2.cpp
./a.out
g++ bsgs.cpp
./a.out
10 2 3 
2 1 
3 1 
4 1 
5 3 
6 3 
7 6 
8 6 
9 6 
10 7
g++ 2.cpp
./a.out
8 2 2
1 2
1 3
2 4
2 5
4 6
4 7
5 8
g++ 2.cpp
./a.out
5 2 1
1 2
2 3
3 4
2 5
g++ 2.cpp
./a.out
4 1 1
1 2
2 3
3 4
*/
View Code

猜你喜欢

转载自www.cnblogs.com/seamtn/p/11803304.html