codeforces 894D

(对二叉树有感觉的话,思路还是出得蛮快的)
题意:给定一个 n 节点二叉树( n<106 ),每条边上有一个权值,然后给出 m m<105 )个询问,求 Ai 节点在 Hi 距离内能到达的树上哪些节点,求到达这些节点经过的距离之和。

思路:如果每次处理询问都要对树上节点遍历的话,肯定超时,于是想到预处理,由于这个树是很标准的二叉树,那么我们可以试试对于每个树节点,求出每个子节点距离它的值,并保存起来(就像线段树中的push_up操作一样)。然后对于每一个询问 Ai ,我们可以利用二分查找迅速求出这个节点所能到达它下面子树节点的距离和,然后想到达剩下的节点无非要经过 Ai/2 或者 Ai1 这两个节点。然后我们也可以迅速得到 Ai1 下面子树的结果。对于 Ai/2 上面的点,我们依次遍历这些店直到树根,然后每次处理这些点对应的 Ai1 子树即可。

时间复杂度:预处理+查询: O(nlog(n)+mlog(n)2)
空间复杂度:预处理的结果: O(nlog(n))

(开始以为可能会MLE,然后发现题目中内存比较大。好久没写代码,写的时候调了半天(居然以为break语句能跳出它上面的if-else语句…)。)

#include <cstdio>
#include <algorithm>
#include <vector>
#define LL long long

using namespace std;
const int maxn = 1000050;
const int inf = 0x3f3f3f3f;

int len[maxn][2];
vector<int> vec[maxn];
vector<LL> sum[maxn];

void push_up(int p, int n) {
    // merge sort
    int l = p<<1, r = p<<1|1, i = 0, j = 0;
    int l_len = len[p][0], r_len = r<=n? len[p][1]:0;
    while(i < (int)vec[l].size()-1 && j < (int)vec[r].size()-1) {
        if(vec[l][i]+l_len < vec[r][j]+r_len)
            vec[p].push_back(vec[l][i++]+l_len);
        else
            vec[p].push_back(vec[r][j++]+r_len);
        if(vec[p].back() >= inf) break ;
    }
    while(i < (int)vec[l].size()-1 && vec[p].back() < inf)
        vec[p].push_back(vec[l][i++]+l_len);
    while(j < (int)vec[r].size()-1 && vec[p].back() < inf)
        vec[p].push_back(vec[r][j++]+r_len);
    if(vec[p].back() >= inf) vec[p].pop_back();
    // calculate prefix sum
    sum[p].push_back(vec[p][0]);
    for(int k=1; k<(int)vec[p].size(); k++) {
        sum[p].push_back(vec[p][k]);
        sum[p][k] += sum[p][k-1];
    }
    return ;
}

void build(int p, int n) {
    int lson = p<<1, rson = p<<1|1;
    if(lson <= n) build(lson, n);
    if(rson <= n) build(rson, n);
    vec[p].clear();
    vec[p].push_back(0);
    if(lson <= n || rson <= n)
        push_up(p, n);
    vec[p].push_back(inf);
    return ;
}

LL get_sum(int n, int id, int h) {
    LL ret = h;
    int pos = upper_bound(vec[id].begin(), vec[id].end(), h) - vec[id].begin() - 1;
    if(pos >= 1) ret += (LL)h*pos - sum[id][pos];
    //printf("ret1 : %I64d\n",ret);
    while(id != 1) {
        h -= len[id/2][id&1];
        if(h > 0) ret += h;
        else break ;
        //printf("ret2 : %I64d\n",ret);
        int id_2 = id ^ 1, branch = len[id/2][id_2&1];
        if(id_2 <= n && h-branch > 0) {
            ret += h-branch;
            int pos = upper_bound(vec[id_2].begin(), vec[id_2].end(), h-branch) - vec[id_2].begin() - 1;
            if(pos >= 1) ret += (LL)(h-branch)*pos - sum[id_2][pos];
            //printf("ret3 : %I64d\n",ret);
        }
        id /= 2;
    }
    return ret;
}

int main() {
    int n, m;
    scanf("%d%d",&n,&m);
    for(int i=1; i<n; i++) {
        int t, st = (i+1)/2;
        scanf("%d",&t);
        len[st][(i+1)&1] = t;
    }
    build(1, n);
    while(m --) {
        int id, h;
        scanf("%d%d",&id,&h);
        LL ans = get_sum(n, id, h);
        printf("%I64d\n",ans);
    }
    return 0;
}
发布了40 篇原创文章 · 获赞 44 · 访问量 9万+

猜你喜欢

转载自blog.csdn.net/Site1997/article/details/78645249