[9.26模拟赛]T2

T2

Description

机房竟然被小偷光顾了!小偷竟然还光明正大地走出了校门!现在我们掌握了一些信息,希望你能帮我们抓住小偷。
假定该城市内有\(N\)个地点,并按\(1\)——\(N\)编号。任意两点都有且只有一条路径相连,每条边的长度均为\(1\)。现在已知小偷准备在事情平息后,到\(M\)个手机店去出售偷到的手机,这\(M\)个手机店的编号分别为\(X1\),\(X2\),\(X3\),\(X4\)……由于小偷很懒惰,他的躲藏之处到每个手机店的距离都不超过\(D\)
现在希望你能帮助我们求出小偷可能躲藏的地方有几个。(\(PS\):本题除了前两句之外纯属虚构)

Input

第一行\(3\)个整数\(N\),\(M\),\(D\)。接下来一行\(M\)个整数,分别表示手机店的编号。再接下来\(N-1\)行,每行\(2\)个整数\(X\),\(Y\)。表示\(X\)\(Y\)有一条无向边。

Output

一行表示可能地点的数目。

Sample Input

6 2 3
1 2
1 5
2 3
3 4
4 5
5 6

Sample Output

3

Data Constraint

对于\(30%\)的数据\(1<=N<=3000\),
对于\(50%\)的数据\(1<=N<=10000\),\(1<=M<=1000\)
对于\(100%\)的数据\(1<=M<=N<=30000\),\(0<=D<=N-1\)

Solution

我们先用一个类似于求树的直径的方法,求出距离最远的两个手机店\(A\)\(B\)
于是假设一个点\(x0\),对于每个点\(x\)都要使\(Dist{x,x0}<=D\),等价于\(Dist{A,x0}<=D\)\(Dist{B,x0}<=D\)
所以对于\(A\)\(B\),将求出他们分别距离小于\(D\)的点,然后求一个交集
然后就是答案啦!

Code

#include <iostream>
#define MAXN 30010
using namespace std;
struct rec{
    int ver, nxt;
} t[MAXN << 1];
int cnt, head[MAXN], node, Node, Max, n, m, d, a[MAXN], b[MAXN], flag[MAXN], ans, u, v;

inline int read(){
    int s = 0, w = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
    return s * w;
}

void add(int u, int v) {
    t[++cnt] = (rec){v, head[u]}, head[u] = cnt;
}

void DFS(int u, int fa, int sum){
    if (sum > Max && flag[u]) Max = sum, node = u;
    for (register int i = head[u]; i; i = t[i].nxt) {
        int v = t[i].ver;
        if (v != fa) DFS(v, u, sum + 1);
    }
}

void DFS1(int u, int fa, int sum){
    a[u] = 1;
    if (sum == d) return;
    for (register int i = head[u]; i; i = t[i].nxt) {
        int v = t[i].ver;
        if (v != fa) DFS1(v, u, sum + 1);
    }
}

void DFS2(int u, int fa, int sum){
    b[u] = 1;
    if (sum == d) return;
    for (register int i = head[u]; i; i = t[i].nxt) {
        int v = t[i].ver;
        if (v != fa) DFS2(v, u, sum + 1);
    }
}

int main(){
    //freopen("pb.in", "r", stdin);
    //freopen("pb.out", "w", stdout);
    n = read(), m = read(), d = read();
    flag[node = read()] = 1;
    for (register int i = 2; i <= m; i++)
        flag[read()] = 1;
    for (register int i = 1; i <= n - 1; i++)
        u = read(), v = read(), add(u, v), add(v, u);
    Max = -1;
    DFS(node, 0, 0);
    Max = -1, Node = node; 
    DFS(node, 0, 0);
    DFS1(node, 0, 0);
    DFS2(Node, 0, 0);
    for (register int i = 1; i <= n; i++)
        if (a[i] && b[i]) ans++;
    printf("%d\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Agakiss/p/11605962.html