POJ3728
题意:给定一棵 n 个点的带权树和 q 个询问 (a, b)。沿 a 到 b 的简单路径依次经过各点,求先经过的点 x 与后经过的点 y 之间 val[y] - val[x] 的最大值;若不存在正的差值则输出 0。
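题解:倍增 LCA。预处理每个点向上 2^i 步区间内的最大值 maxx、最小值 minn,以及向上走的最大差 up、向下走的最大差 down;询问时以 t = lca(a, b) 为界,把 a→t 与 t→b 两段分别按二进制拆分合并,再用 t→b 段的最大值减 a→t 段的最小值更新答案。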
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;
// ---- Sizes and sentinels ----
const int inf = 0x7f7f7f7f;  // "+infinity" starting value for running minima
const int N = 50000 + 10;    // node capacity (presumably max n = 50000 plus slack; confirm vs. problem limits)
const int M = 20;            // number of binary-lifting levels (2^19 > N)

// ---- Adjacency lists as singly linked edge chains (each undirected edge uses two slots) ----
int first[N];      // first[u]: index of u's most recently added edge, -1 when u has none
int ne[N << 1];    // ne[e]: next edge in the same chain
int to[N << 1];    // to[e]: target vertex of edge e
int lg[N];         // lg[i] = floor(log2(i)), filled by pre()

int tot;           // number of edge slots used so far
int n;             // number of nodes
int q;             // number of queries

// ---- Per-node tables (nodes 1..n; node 0 is the dummy parent of the root) ----
int fa[N][M];      // fa[u][i]: 2^i-th ancestor of u
int dep[N];        // depth of u (root has depth 1)
int val[N];        // value of u
int up[N][M];      // best later-minus-earlier difference walking upward from u over 2^i steps
int down[N][M];    // same, walking downward towards u
int maxx[N][M];    // maximum value on the path u .. fa[u][i]
int minn[N][M];    // minimum value on the path u .. fa[u][i]
void pre(){
lg[1] = 0;
for(int i=2;i<N;i++)
lg[i] = lg[i/2] + 1;
}
// Appends the directed edge u -> v to u's adjacency chain.
// The new edge becomes the head of the chain; tot is the next free edge slot.
void add(int u,int v){
    to[tot] = v;
    ne[tot] = first[u];
    first[u] = tot;
    ++tot;
}
void dfs(int u,int f){
dep[u] = dep[f] + 1; fa[u][0] = f;
maxx[u][0] = max(val[u],val[f]); minn[u][0] = min(val[u],val[f]);
up[u][0] = max(0,val[f] - val[u]); down[u][0] = max(0,val[u] - val[f]);
for(int i=1;(1<<i)<=dep[u];i++){
int k = fa[u][i-1];
fa[u][i] = fa[k][i-1];
maxx[u][i] = max(maxx[u][i-1],maxx[k][i-1]);
minn[u][i] = min(minn[u][i-1],minn[k][i-1]);
up[u][i] = max(max(0,maxx[k][i-1]-minn[u][i-1]),max(up[u][i-1],up[k][i-1]));
down[u][i] = max(max(0,maxx[u][i-1]-minn[k][i-1]),max(down[u][i-1],down[k][i-1]));
}
for(int i=first[u];~i;i=ne[i]){
int v = to[i];
if(v == f) continue;
dfs(v,u);
}
}
// Returns the lowest common ancestor of x and y using the fa[][] lifting table.
int lca(int x,int y){
    // Make x the deeper of the two nodes.
    if(dep[y] > dep[x]){
        int tmp = x;
        x = y;
        y = tmp;
    }
    // Lift x to y's depth, always jumping by the largest power of two that fits.
    for(int gap = dep[x] - dep[y]; gap > 0; gap = dep[x] - dep[y])
        x = fa[x][lg[gap]];
    if(x == y) return x;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for(int k = lg[dep[x]]; k >= 0; --k){
        if(fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
void solve(){
scanf("%d",&q);
while(q--){
int u,v;
scanf("%d%d",&u,&v);
int t = lca(u,v);
int MAX = 0,MIN = inf,upmax = 0,downmax = 0;
for(int depth=18;depth>=0;depth--){ //深度由许多个二进制组成
if((dep[u]-dep[t])&(1<<depth)){
upmax = max(upmax,up[u][depth]);
upmax = max(upmax,maxx[u][depth]-MIN);
MIN = min(MIN,minn[u][depth]);
u = fa[u][depth];
}
}
for(int depth=18;depth>=0;depth--){
if((dep[v]-dep[t])&(1<<depth)){
int depth = lg[dep[v] - dep[t]];
downmax = max(downmax,down[v][depth]);
downmax = max(downmax,MAX - minn[v][depth]);
MAX = max(MAX,maxx[v][depth]);
v = fa[v][depth];
}
}
printf("%d\n",max(max(0,MAX-MIN),max(upmax,downmax)));
}
}
int main(){
pre();
memset(first,-1,sizeof(first));
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int i=1;i<=n-1;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
solve();
return 0;
}