T2:
树上差分、前缀和、dfs序的应用。
只询问p,q路径,可以n^2.没必要树形DP。
直接n^2枚举点对O(1)算距离即可。
考虑枚举点对i,j,O(1)统计。
求出以i为lca的路径数和经过i且不以i为lca的路径数。
u=lca(i,j)。
w=u子树和,减去i~j链上,加上u子树外,减去经过u且不以u为lca条数。
先往简单想。别过于套路,一看路径长直接树DP。
好题。
B@哥 真巨。
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> using namespace std; inline void read(int &x) { x=0;char c=getchar(); while(c<'0'||c>'9') c=getchar(); while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-48,c=getchar(); } typedef long long ll; const int maxn=3005; int n,p,q,tot,dex,fr[maxn],first[maxn],dep[maxn],f[maxn]; ll ans,res,dp[maxn],dis[maxn],g[maxn],h[maxn]; struct Road{ int u,t,nxt; }eage[maxn<<1]; void add(int x,int y) {eage[++tot]=(Road){x,y,first[x]};first[x]=tot;} struct ST_Table{ int p,f[maxn<<1][22],len[maxn<<1],bin[22]; void init() { p=log(dex)/log(2)+1; for(int i=bin[0]=1;i<=p;++i) bin[i]=bin[i-1]<<1; for(int i=1;i<=dex;++i) len[i]=log(i)/log(2); for(int j=1;j<=p;++j) for(int i=1;i+bin[j]-1<=dex;++i) { if(dep[f[i][j-1]]<dep[f[i+bin[j-1]][j-1]]) f[i][j]=f[i][j-1]; else f[i][j]=f[i+bin[j-1]][j-1]; } } int LCA(int x,int y) { if(fr[x]>fr[y]) swap(x,y); int l=len[fr[y]-fr[x]+1]; if(dep[f[fr[x]][l]]<dep[f[fr[y]-bin[l]+1][l]]) return f[fr[x]][l]; else return f[fr[y]-bin[l]+1][l]; } }ST; ll Get_dep(int x,int y) {return dep[x]+dep[y]-2*dep[ST.LCA(x,y)];} void frdfs(int x,int fa) { ST.f[fr[x]=++dex][0]=x;dep[x]=dep[fa]+1;f[x]=fa; for(int i=first[x];i;i=eage[i].nxt) if(eage[i].t!=fa) { frdfs(eage[i].t,x); ST.f[++dex][0]=x; } } void redfs(int x,int fa) { for(int i=first[x];i;i=eage[i].nxt) if(eage[i].t!=fa) { redfs(eage[i].t,x); dp[x]+=dp[eage[i].t]; } } void lsdfs(int x,int fa) { h[x]=h[fa]+g[x]; for(int i=first[x];i;i=eage[i].nxt) if(eage[i].t!=fa) { lsdfs(eage[i].t,x); g[x]+=g[eage[i].t]; } } int main() { read(n);read(p);read(q); for(int i=1,x,y;i<=n-1;++i) { read(x);read(y); add(x,y);add(y,x); } frdfs(1,0); ST.init(); for(int i=1;i<=n;++i) for(int j=1,lca;j<=n;++j) if(Get_dep(i,j)==p) { lca=ST.LCA(i,j);++g[lca]; ++dp[i];++dp[j];--dp[lca];--dp[lca]; } redfs(1,0);lsdfs(1,0); dp[0]=dp[1];g[0]=g[1]; for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) if(Get_dep(i,j)==q) { int lca=ST.LCA(i,j); ans+=g[lca]-(h[i]+h[j]-h[lca]-h[f[lca]]); ans+=g[1]-g[lca]-dp[lca]; /** 经过该点且不以该点为lca的边数。*/ } printf("%lld\n",ans); return 0; }
我的n^2递推,利用dfs序预处理lca.
下面是 B&哥 思路:g[i][j],表示以dfs序上ij位置两点为端点路径数。
用来计算u子树外的部分:由于dfs序是连续段,则i,j均不在u子树dfs序范围内即可。
g数组的处理有个细节调了半天,注释上了。
#include<bits/stdc++.h> #define F(i,a,b) for(int i=a;i<=b;++i) #define LL long long #define pf(a) printf("%d ",a) #define PF(a) printf("%lld ",a) #define phn puts("") using namespace std; int read(); #define N 3005 int n,p,q; int to[N<<1],fir[N<<1],head[N],cnt; void add(int x,int y){to[++cnt]=y;fir[cnt]=head[x];head[x]=cnt;} int fa[N],dep[N],L[N],dfn[N],R[N],tot; void dfs(int x,int Fa){ dep[x]=dep[Fa]+1;fa[x]=Fa;L[x]=++tot;dfn[tot]=x; for(int i=head[x],v;i;i=fir[i])if(!dep[v=to[i]]){ dfs(v,x); } R[x]=tot; } int lca[N][N]; LL w[N],g[N][N],sub[N],up[N],ans; void push(int x){ sub[x]=w[x];up[x]=up[fa[x]]+w[x]; for(int i=head[x],v;i;i=fir[i])if((v=to[i])^fa[x]){ push(v);sub[x]+=sub[v]; } } /*g++ 2.cpp ./a.out */ void cal(int x,int y){ int u=lca[x][y]; ans+=sub[u]; ans-=up[x]+up[y]-up[u]-up[fa[u]]; int l=L[u],r=R[u]; ans+=g[l-1][l-1]; ans+=g[l-1][n]-g[l-1][r]; ans+=g[n][l-1]-g[r][l-1]; ans+=g[n][n]-g[n][r]-g[r][n]+g[r][r]; // if(x==3&&y==5){PF(ans);} } int main(){ freopen("b.in","r",stdin); // freopen("2.in","r",stdin);//freopen("2.out","w",stdout); n=read();p=read();q=read(); for(int i=1,u,v;i<n;++i){ u=read();v=read();add(u,v);add(v,u); } dfs(1,0); F(i,1,n){ lca[i][i]=i; F(j,L[i]+1,R[i])lca[i][dfn[j]]=lca[dfn[j]][i]=i; for(int j=L[i]+1;j<=R[i];j=R[dfn[j]]+1){ for(int k=j;k<=R[dfn[j]];++k){ for(int h=R[dfn[j]]+1;h<=R[i];++h){ lca[dfn[k]][dfn[h]]=lca[dfn[h]][dfn[k]]=i; } } } } F(i,1,n){ F(j,1,n){ if(dep[i]+dep[j]-dep[lca[i][j]]*2==q){ ++w[lca[i][j]];++g[L[i]][L[j]]; /** 注意:这里是g[L[i]],加到dfs序对应位置。*/ } } } F(i,1,n){ F(j,1,n){ g[i][j]=g[i][j-1]+g[i-1][j]+g[i][j]-g[i-1][j-1]; } } // F(i,1,n)pf(dfn[i]);phn;phn; // F(i,1,n){F(j,1,n)PF(g[i][j]);phn;}phn; push(1); F(i,1,n){ F(j,1,n){ if(dep[i]+dep[j]-dep[lca[i][j]]*2==p){ cal(i,j); // pf(i);pf(j);pf(ans);phn; } } } printf("%lld\n",ans); } int read(){ int s=0,f=0;char ch=getchar(); while(!isdigit(ch))f=ch=='-',ch=getchar(); while(isdigit(ch))s=s*10+(ch^48),ch=getchar(); return f?-s:s; } /* g++ d2.cpp ./a.out g++ 2.cpp ./a.out g++ bsgs.cpp ./a.out 10 2 3 2 1 3 1 4 1 5 3 6 3 7 6 8 6 9 6 10 7 g++ 2.cpp ./a.out 8 2 2 1 2 1 3 2 4 2 5 4 6 4 7 5 8 g++ 2.cpp ./a.out 5 2 1 1 2 2 3 3 4 2 5 g++ 2.cpp ./a.out 4 1 1 1 2 2 3 3 4 */