【树上点分治2】 D Tree HDU - 4812

原题链接

(原文此处为题目截图,未能保留。)
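题意:给定一棵有 N 个节点的树,每个节点 i 带有权值 $v_i$。求两个不同的节点 a、b,使得从 a 到 b 的路径上所有节点权值之积模 $10^6+3$ 等于 K;若有多组解,输出字典序最小的点对 (a < b),无解时输出 "No solution"。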

做法:点分治。以当前重心为根,用一个以乘积值为下标的数组 pd(相当于直接寻址的哈希表)记录已处理子树中"从根的某个儿子出发、不含根"的各条路径积,以及取得该路径积的最小节点编号。

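设当前子树中某点到根的路径积(不含根)为 $b$,根的权值为 $v_{root}$,则需要在之前处理过的子树中找到路径积 $a$,满足 $a \cdot b \cdot v_{root} \equiv K \pmod{p}$,即 $a \equiv K \cdot (b \cdot v_{root})^{-1} \pmod{p}$。由于 $p = 10^6+3$ 是质数,可以用线性递推 $inv[i] \equiv -\lfloor p/i \rfloor \cdot inv[p \bmod i] \pmod{p}$ 预处理出 $1 \sim p-1$ 的全部逆元,查询时 O(1) 得到除法结果。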

#pragma comment(linker,"/STACK:102400000,102400000")
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
//#define ACM_LOCAL
// Shorthand type aliases.
using ll = long long;
using ld = long double;
using PII = std::pair<int, int>;

// Problem-wide constants.
constexpr int N = 200005;          // capacity of per-vertex arrays (edges use 2*N)
constexpr int INF = 0x3f3f3f3f;    // "no pair found yet" marker; also the mx[0] sentinel
constexpr int MOD = 1000003;       // prime modulus for path products

// Input and adjacency list.
int n;          // number of vertices in the current test case
int m;          // not used in the visible code
int k;          // target path product (mod MOD)
int a[N];       // a[v] = value of vertex v (1-based)
int cnt;        // next free slot in the edge array
int h[N];       // h[v] = first edge index of v's list, -1 when empty

// Centroid-decomposition state.
int rt;         // best centroid candidate found by get_rt (0 = none yet)
int sz[N];      // subtree sizes from the most recent get_rt DFS
int mx[N];      // mx[v] = largest part left if v is removed; mx[0] is the sentinel
int vis[N];     // 1 once a vertex has been used as a centroid
int sum;        // size of the component get_rt is currently searching
int ans;        // assigned in solve() but never read in the visible code
int id[N];      // id[j] = vertex whose path product is d[j]

ll d[N];        // d[0] = count, d[1..d[0]] = path products collected by get_d
ll dep[N];      // dep[v] = product of values on the path down to v (mod MOD)

int pd[MOD + 7];    // pd[p] = smallest vertex with path product p (0 = none); 1000010 slots
ll inv[MOD + 7];    // inv[i] = modular inverse of i mod MOD; 1000010 slots

int node1, node2;   // best pair found so far, kept with node1 <= node2
// One directed half of an undirected tree edge, stored in a static
// forward-star adjacency list: the list for vertex u starts at h[u] and
// follows `next` until it reaches -1.
struct edge {
    int to;     // endpoint this edge leads to
    int next;   // index of the next edge leaving the same vertex, or -1
} e[N << 1];    // two directed halves per undirected edge

void add(int u, int v) {
    
    
    e[cnt].to = v;
    e[cnt].next = h[u];
    h[u] = cnt++;
}

// DFS from x (arriving from fa) over the not-yet-removed vertices. It fills
// sz[] with subtree sizes and mx[] with the largest part that would remain
// if each vertex were removed. `rt` tracks the vertex with the smallest
// mx[], i.e. the centroid.
// Expects `sum` to hold the size of the component being searched.
void get_rt(int x, int fa) {
    int largest = 0;
    sz[x] = 1;
    for (int i = h[x]; i != -1; i = e[i].next) {
        const int y = e[i].to;
        if (y == fa || vis[y]) continue;
        get_rt(y, x);
        sz[x] += sz[y];
        if (sz[y] > largest) largest = sz[y];
    }
    // The part on the fa side of x holds the remaining sum - sz[x] vertices.
    const int above = sum - sz[x];
    if (above > largest) largest = above;
    mx[x] = largest;
    if (largest < mx[rt]) rt = x;
}

void get_d(int x, int fa) {
    
    
    d[++d[0]] = dep[x];
    id[d[0]] = x;
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (vis[y] || y == fa) continue;
        dep[y] = dep[x] * a[y] % MOD;
        get_d(y, x);
    }
}

// Find the best answer among paths that pass through the centroid x
// (already marked vis[x] = 1 by the caller). Each neighbouring subtree is
// handled in turn:
//   1. get_d() gathers the products d[j] of the paths from the subtree's
//      root down to each vertex id[j]; x's own value is not included.
//   2. Each d[j] looks up the complementary product k * (d[j] * a[x])^-1
//      in pd[]. pd[] holds the smallest vertex seen for each product in the
//      earlier subtrees; work() pre-seeds pd[1] with x so that paths ending
//      at x are also found. The lexicographically smallest (node1, node2)
//      is kept.
//   3. Only then is this subtree merged into pd[], so the two endpoints of
//      any pair always lie in different subtrees.
// Every pd[] slot written here is cleared again before returning.
void cal(int x, int fa) {
    queue<int> touched;   // product values written into pd[], for cleanup
    dep[x] = a[x];
    for (int i = h[x]; ~i; i = e[i].next) {
        const int y = e[i].to;
        if (vis[y] || y == fa) continue;

        dep[y] = a[y];
        d[0] = 0;
        get_d(y, -1);

        // Step 2: match against the subtrees processed earlier.
        for (int j = 1; j <= d[0]; j++) {
            const int need = (1ll * k * inv[d[j] * dep[x] % MOD]) % MOD;
            const int partner = pd[need];
            if (!partner) continue;
            const int lo = min(id[j], partner);
            const int hi = max(id[j], partner);
            // node1 <= node2 always holds, so this is (lo, hi) <= (node1, node2)
            // in lexicographic order.
            if (lo < node1 || (lo == node1 && hi <= node2)) {
                node1 = lo;
                node2 = hi;
            }
        }

        // Step 3: publish this subtree's products, keeping the smallest vertex.
        for (int j = 1; j <= d[0]; j++) {
            int &slot = pd[d[j]];
            if (!slot || id[j] < slot) slot = id[j];
            touched.push(d[j]);
        }
    }
    while (!touched.empty()) {
        pd[touched.front()] = 0;
        touched.pop();
    }
}

// Centroid-decomposition driver. It removes centroid x, lets cal() find
// the best pair among paths through x, then recurses into the centroid of
// each remaining neighbouring component.
// On entry, `sum` must hold the size of x's component, and sz[] must hold
// the subtree sizes from the get_rt() call that selected x.
void work(int x) {
    // The recursive calls below overwrite the global `sum`, so keep a copy.
    const int total = sum;
    vis[x] = 1;
    pd[1] = x;   // empty product on the far side: lets cal() pair x itself with a vertex
    cal(x, -1);
    for (int i = h[x]; ~i; i = e[i].next) {
        int y = e[i].to;
        if (vis[y]) continue;
        // sz[] comes from a DFS rooted somewhere else. Neighbours below x in
        // that DFS have sz[y] < sz[x], and sz[y] is their exact size. The
        // neighbour above x (its DFS parent) has sz[y] > sz[x], but its side
        // really contains total - sz[x] vertices. Using sz[y] there would
        // overstate `sum` and let get_rt() pick a vertex that is not a true
        // centroid.
        sum = sz[y] < sz[x] ? sz[y] : total - sz[x];
        rt = 0;
        get_rt(y, -1);
        work(rt);
    }
}

void get_inv() {
    
    
    inv[1] = 1;
    for(int i = 2; i <= MOD; i++)
        inv[i] = ((MOD - MOD / i) * inv[MOD % i] + MOD) % MOD;
}

// Process test cases until scanf reports end of input. Each case gives n
// and k, then the n vertex values, then the n-1 tree edges. The case is
// solved by centroid decomposition, and the lexicographically smallest
// pair (or "No solution") is printed.
void solve () {
    get_inv();   // the inverse table depends only on MOD, so build it once

    while (~scanf("%d %d", &n, &k)) {
        // Per-case reset. pd[] is not cleared here: cal() clears the slots
        // it fills, and work() re-seeds pd[1] before every use.
        cnt = 0;
        memset(h, -1, sizeof h);
        memset(vis, 0, sizeof vis);
        node1 = node2 = INF;

        for (int v = 1; v <= n; ++v)
            scanf("%d", a + v);
        for (int left = n - 1; left > 0; --left) {
            int u, w;
            scanf("%d %d", &u, &w);
            add(u, w);
            add(w, u);
        }

        ans = 0;
        mx[0] = INF;   // sentinel: any real vertex beats rt = 0 in get_rt()
        sum = n;
        rt = 0;
        get_rt(1, -1);
        work(rt);

        if (node1 == INF && node2 == INF)
            printf("No solution\n");
        else
            printf("%d %d\n", node1, node2);
    }
}

// Entry point. With ACM_LOCAL defined, stdin/stdout are redirected to the
// files "input" and "output" for local testing.
int main() {
    // All I/O in this program goes through C stdio (scanf/printf). The C++
    // streams are deliberately left synchronised with stdio, so any
    // cout/cerr debugging output still interleaves correctly with printf.
#ifdef ACM_LOCAL
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/kaka03200/article/details/109406599