3717. 【NOI2014模拟7.2】火车(train)

Description

A国有n个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一路径。现在有火车在城市a,需要经过m个城市。火车按照以下规则行驶:每次行驶到还没有经过的城市中在m个城市中最靠前的。现在小A想知道火车经过这m个城市后所经过的道路数量。

Input

第一行三个整数n、m、a,表示城市数量、需要经过的城市数量,火车开始时所在位置。

接下来n-1行,每行两个整数x和y,表示x和y之间有一条双向道路。

接下来一行m个整数,表示需要经过的城市。

Output

一行一个整数,表示火车经过的道路数量。

Sample Input

5 4 2

1 2

2 3

3 4

4 5

4 3 1 5

Sample Output

9

Data Constraint

这里写图片描述

Solution

首先要注意本题的一个大坑,题目说每次行驶到还没有经过的城市中在m个城市中最靠前的。也就是说如果在走第一个点的时候经过了第二个点,那么第二个点就不用再走了。
考虑用动态树维护所有点在序列中的位置,未出现则为无穷大。
每次找到下一个没有被访问的点,不断找链上最小值,标记为已访问并把权值改为无穷大。
但实际上可以直接用倍增lca,然后再用并查集维护这个点是否走过。这个操作其实很简单,首先我们将所有的节点的并查集父亲赋成这棵树的父亲,然后再将必须要经过的点的并查集的父亲赋成它自己,因为这样做的时候,我们每次找并查集的父亲的时候如果找到了一个必须要走的点就必须退出找父亲,那么就可以直接将它标记了,至于那些我们不需要经过的点就可以不用管了。

Code1

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define ll long long
#define N 500010
using namespace std;
int n,m,x,y,now,a[N],f[N][20],fa[N],lca;
int t[N*2],nx[N*2],ls[N],d[N],s[N];
bool bz[N];
ll ans=0;
void add(int x,int y){
    t[++t[0]]=y;
    nx[t[0]]=ls[x];
    ls[x]=t[0];
}
int getlca(int x,int y){
    if(d[y]>d[x]) swap(x,y);
    int dep=d[x]-d[y];
    for(int i=19;i>=0;i--) 
        if(dep&(1<<i)) x=f[x][i];
    if(x==y) return x;
    for(int i=19;i>=0;i--)
        if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
int get(int x){
    if(fa[x]==x) return x;
    return fa[x]=get(fa[x]);
}
void update(int x,int y){
    while(d[x]>d[y]){
        bz[x]=1;
        fa[x]=f[x][0];
        x=get(x);
    }
}
int main(){
    freopen("train.in","r",stdin);
    freopen("train.out","w",stdout);
    scanf("%d%d%d",&n,&m,&now);
    for(int i=1;i<=n-1;i++){
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    s[1]=d[1]=1;
    int i=0,j=1;
    while(i++<j){
        int x=s[i];
        for(int k=ls[x];k;k=nx[k]){
            if(t[k]!=f[x][0]){
                fa[t[k]]=f[t[k]][0]=x;
                d[t[k]]=d[x]+1;
                s[++j]=t[k];
            }
        }
    }
    for(int i=1;i<=m;i++){
        scanf("%d",&a[i]);
        fa[a[i]]=a[i];
    } 
    for(int i=1;i<=19;i++){
        for(int j=1;j<=n;j++) f[j][i]=f[f[j][i-1]][i-1];
    }
    memset(bz,0,sizeof(bz));
    for(int i=1;i<=m;i++){
        x=a[i];
        if(bz[x]) continue;
        lca=getlca(now,x);
        bz[x]=bz[now]=bz[lca]=1;
        ans+=d[x]+d[now]-2*d[lca];
        update(x,lca);
        update(now,lca);
        now=x;
    }
    printf("%lld\n",ans);
    return 0;
}

Code2

var
        n,m,s,i,j,x,y,tot,fx,fy:longint;
        ans:int64;
        last,deep,b:array[0..500000] of longint;
        f:array[0..500000,0..20] of longint;
        a:array[1..1000000,1..3] of longint;
        g:array[0..500000] of longint;
procedure link(x,y:longint);
begin
        inc(tot); a[tot,1]:=last[x]; a[tot,2]:=y; last[x]:=tot;
end;
procedure build(x:longint);
var
        i,y:longint;
begin
        i:=last[x];
        while i>0 do
        begin
                y:=a[i,2];
                if y<>f[x,0] then
                begin
                        f[y,0]:=x; deep[y]:=deep[x]+1;
                        build(y);
                end;
                i:=a[i,1];
        end;
end;
function lca(x,y:longint):longint;
var
        i:longint;
begin
        for i:=20 downto 0 do
        begin
                while deep[f[x,i]]>=deep[y] do x:=f[x,i];
        end;
        for i:=20 downto 0 do
        begin
                while deep[x]<=deep[f[y,i]] do y:=f[y,i];
        end;
        if x=y then exit(x);
        for i:=20 downto 0 do
        begin
                while f[x,i]<>f[y,i] do
                begin
                        x:=f[x,i]; y:=f[y,i];
                end;
        end;
        exit(f[x,0]);
end;

function find(x:longint):longint;
begin
        if g[x]=x then exit(x);
        find:=find(g[x]); g[x]:=find;
end;
begin
        assign(input,'train.in');
        assign(output,'train.out');
        reset(input);
        rewrite(output);
        read(n,m,s);
        for i:=1 to n-1 do
        begin
                read(x,y); link(x,y); link(y,x);
        end;
        deep[1]:=1;
        build(1);
        for j:=1 to 20 do
                for i:=1 to n do
                        f[i,j]:=f[f[i,j-1],j-1];
        for i:=1 to n do g[i]:=i;

        for i:=1 to m do
        begin
                read(y);
                if g[y]<>y then continue;
                x:=lca(s,y);
                ans:=ans+deep[s]+deep[y]-deep[x]*2;
                fx:=find(s); fy:=find(y);
                while fx<>fy do
                begin
                        if deep[fx]<deep[fy] then
                        begin
                                tot:=fx; fx:=fy; fy:=tot;
                        end;
                        g[fx]:=find(f[fx,0]);
                        fx:=g[fx];
                end;
                if find(x)=x then g[x]:=f[x,0];
                s:=y;
        end;
        write(ans);
        close(input);
        close(output);
end.

作者:zsjzliziyang
QQ:1634151125
转载及修改请注明
本文地址:https://blog.csdn.net/zsjzliziyang/article/details/81843269

猜你喜欢

转载自blog.csdn.net/zsjzliziyang/article/details/81843269