树上求和【牛客小白月赛9-D】【树链剖分】

题目链接

还不懂树链剖分的小伙伴可以看这里,以及我的学习笔记都在里面


  一道很基本的树链剖分题，难点在于推导区间加之后如何维护平方和 $\sum_{i=1}^{n} X_i^2$：由 $(X+Y)^2 = X^2 + 2XY + Y^2$，给区间内每个数都加上 $Y$ 后，逐项展开再求和得到 $\sum_{i=1}^{n}(X_i+Y)^2 = \sum_{i=1}^{n} X_i^2 + 2Y\sum_{i=1}^{n} X_i + nY^2$，其中 $n$ 是区间长度。因此线段树每个结点除了平方和，还要同时维护区间和，并且必须先用旧的区间和更新平方和，再更新区间和。

然后,上代码具体看一下吧:


#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
// Template helpers, rewritten from object-like macros: a macro named `e`
// textually replaces every later identifier `e` (e.g. a loop variable),
// whereas these typed definitions follow normal C++ scoping.
template <class T>
constexpr T lowbit(T x) { return x & (-x); }
constexpr double pi = 3.141592653589793;
constexpr double e = 2.718281828459045;
using namespace std;
using ull = unsigned long long;
using ll = long long;
const int mod = 23333;    // every reported answer is taken modulo this value
const int maxN = 100010;  // capacity of vertex-indexed arrays
// Tree / heavy-light decomposition state (vertices are 1-based):
//   N, Q   - vertex count and query count of the current test case
//   w      - input weight of each vertex
//   head   - first edge index in each vertex's adjacency list (-1 = none)
//   cnt    - number of directed edges stored so far
//   depth  - depth of each vertex (main roots the tree at 1 with depth 0)
//   root   - parent of each vertex (main passes the root as its own parent)
//   size   - subtree size. NOTE: with `using namespace std;` under C++17,
//            an unqualified `size` is ambiguous with std::size; refer to
//            this array as `::size`.
//   W_son  - heavy child of each vertex (0 = no child)
//   top    - top vertex of the heavy chain containing each vertex
//   id     - DFS position of each vertex; a subtree occupies the contiguous
//            range [id[x], id[x] + size[x] - 1]
//   num    - DFS position counter
int N, Q, w[maxN], head[maxN], cnt, depth[maxN], root[maxN], size[maxN], W_son[maxN], top[maxN], id[maxN], num;
// new_W[i] - weight of the vertex placed at DFS position i.
// Segment tree over DFS positions (4x capacity):
//   sum - range sum, multi - range sum of squares, lazy - pending addition.
ll new_W[maxN], sum[maxN<<2], multi[maxN<<2], lazy[maxN<<2];
// One node of a linked-list adjacency structure. Edges leaving the same
// vertex are chained through `nex`; the value -1 ends the chain (the DFS
// loops stop at -1 and head[] is reset to -1).
struct Eddge
{
    int nex;  // index of the next edge with the same source, or -1
    int to;   // destination vertex
    Eddge(int a = -1, int b = 0)
        : nex{a}, to{b}
    {
    }
} edge[maxN << 1];  // each undirected tree edge is stored twice
// Stores the directed edge u -> v in the next free slot of edge[] and makes
// it the new first entry of u's adjacency list, so lists are walked
// newest-first.
void addEddge(int u, int v)
{
    const int slot = cnt++;
    edge[slot].nex = head[u];
    edge[slot].to = v;
    head[u] = slot;
}
void dfs1(int u, int fa, int deep)
{
    root[u] = fa;
    depth[u] = deep;
    size[u] = 1;
    int maxSon = -1;
    for(int i=head[u]; i!=-1; i=edge[i].nex)
    {
        int v = edge[i].to;
        if(v == fa) continue;
        dfs1(v, u, deep+1);
        size[u] += size[v];
        if(size[v] > maxSon)
        {
            maxSon = size[v];
            W_son[u] = v;
        }
    }
}
// Second DFS of the heavy-light decomposition: assigns DFS positions (id[]),
// copies each vertex's weight to its position (new_W[]), and records the
// chain top (top[]). The heavy son is visited first so it continues the
// current chain; each light child starts a new chain headed by itself.
// Every subtree therefore ends up in one contiguous id range.
void dfs2(int x, int topf)
{
    top[x] = topf;
    ++num;
    id[x] = num;
    new_W[num] = w[x];

    const int heavy = W_son[x];
    if (heavy == 0)
        return;  // leaf: no children to number
    dfs2(heavy, topf);

    for (int i = head[x]; ~i; i = edge[i].nex)  // ~i == 0 exactly when i == -1
    {
        const int child = edge[i].to;
        if (child != heavy && child != root[x])
            dfs2(child, child);
    }
}
// Recomputes node rt's range sum and range sum of squares from its two
// children, modulo mod.
void pushup(int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;
    sum[rt] = (sum[lc] + sum[rc]) % mod;
    multi[rt] = (multi[lc] + multi[rc]) % mod;
}
// Builds the segment tree node rt covering DFS positions [l, r] from
// new_W[], clearing pending additions on the way.
// Each leaf stores its weight and its squared weight, both modulo mod.
void buildTree(int rt, int l, int r)
{
    lazy[rt] = 0;
    if(l == r)
    {
        // Normalise into [0, mod): C++ % keeps the sign of a negative weight,
        // which would otherwise leak negative residues into the answers.
        sum[rt] = (new_W[l] % mod + mod) % mod;
        multi[rt] = sum[rt] * sum[rt] % mod;
        return;
    }
    int mid = (l + r)>>1;
    buildTree(rt<<1, l, mid);
    buildTree(rt<<1|1, mid+1, r);
    pushup(rt);
}
void pushdown(int rt, int l, int r)
{
    if(lazy[rt])
    {
        lazy[rt]%=mod;
        lazy[rt<<1] = ( lazy[rt<<1] + lazy[rt] )%mod;
        lazy[rt<<1|1] = ( lazy[rt<<1|1] + lazy[rt] )%mod;
        int mid = (l + r)>>1;
        multi[rt<<1] = ( multi[rt<<1] + 2*sum[rt<<1]*lazy[rt]%mod + lazy[rt]*lazy[rt]%mod*(mid - l + 1)%mod )%mod;
        multi[rt<<1|1] = ( multi[rt<<1|1] + 2*sum[rt<<1|1]*lazy[rt]%mod + lazy[rt]*lazy[rt]%mod*(r - mid)%mod )%mod;
        sum[rt<<1] = ( sum[rt<<1] + (mid - l + 1)*lazy[rt] )%mod;
        sum[rt<<1|1] = ( sum[rt<<1|1] + (r - mid)*lazy[rt] )%mod;
        lazy[rt] = 0;
    }
}
// Range add on the segment tree: adds val to every DFS position in [ql, qr],
// starting from node rt, which covers [l, r].
// Assumes ql <= qr and that val is already reduced modulo mod (update_Son
// reduces it). For a fully covered node of length len, the sum of squares
// becomes sum(x^2) + 2*val*sum(x) + len*val^2. That uses the old sum, so
// multi is updated before sum.
void update(int rt, int l, int r, int ql, int qr, ll val)
{
    if(ql<=l && qr>=r)
    {
        lazy[rt] = ( lazy[rt] + val )%mod;
        // Reduce modulo mod like every other write, so stored values stay residues.
        multi[rt] = ( multi[rt] + 2*sum[rt]*val%mod + val*val%mod*(r - l + 1)%mod )%mod;
        sum[rt] = ( sum[rt] + val*(r - l + 1)%mod )%mod;
        return;
    }
    pushdown(rt, l, r);
    int mid = (l + r)>>1;
    if(ql <= mid) update(rt<<1, l, mid, ql, qr, val);
    if(qr > mid) update(rt<<1|1, mid+1, r, ql, qr, val);
    pushup(rt);
}
// Returns the sum of squares over DFS positions [ql, qr] modulo mod,
// starting from node rt, which covers [l, r]. Pending additions are pushed
// down along the visited path.
ll query(int rt, int l, int r, int ql, int qr)
{
    if (l >= ql && r <= qr)
        return multi[rt] % mod;

    pushdown(rt, l, r);
    const int mid = (l + r) >> 1;
    const int lc = rt << 1;
    const int rc = lc | 1;

    if (ql > mid)
        return query(rc, mid + 1, r, ql, qr);
    if (qr <= mid)
        return query(lc, l, mid, ql, qr);

    // The range straddles mid: query each half separately, left part first.
    const ll left = query(lc, l, mid, ql, mid);
    const ll right = query(rc, mid + 1, r, mid + 1, qr);
    return (left + right) % mod;
}
void update_Son(int x, ll val)
{
    val%=mod;
    update(1, 1, N, id[x], id[x]+size[x]-1, val);
}
// Returns the sum of squared weights over the subtree of x, modulo mod.
// The subtree is the contiguous DFS range [id[x], id[x] + size[x] - 1].
// `::size` avoids the C++17 ambiguity with std::size under
// `using namespace std;`.
ll query_Son(int x)
{
    return query(1, 1, N, id[x], id[x] + ::size[x] - 1);
}
// Resets per-test-case state before reading a new tree:
//   - edge and DFS counters restart at 0;
//   - no heavy sons are recorded;
//   - every adjacency list is empty (head = -1).
void init()
{
    cnt = 0;
    num = 0;
    memset(W_son, 0, sizeof W_son);
    memset(head, -1, sizeof head);
}
// Reads test cases until EOF. Each case consists of:
//   - N and Q;
//   - N vertex weights;
//   - N - 1 undirected edges;
//   - Q operations: "1 x y" adds y to every vertex in the subtree of x,
//     and any other op "x" prints the subtree's sum of squares modulo mod.
int main()
{
    while (scanf("%d%d", &N, &Q) != EOF)
    {
        init();
        for (int v = 1; v <= N; ++v)
            scanf("%d", &w[v]);

        // Tree edges are stored in both directions.
        for (int k = 1; k < N; ++k)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            addEddge(a, b);
            addEddge(b, a);
        }

        // Root at vertex 1, decompose, then build the segment tree over DFS order.
        dfs1(1, 1, 0);
        dfs2(1, 1);
        buildTree(1, 1, N);

        while (Q--)
        {
            int op;
            scanf("%d", &op);
            if (op != 1)
            {
                int x;
                scanf("%d", &x);
                printf("%lld\n", query_Son(x));
                continue;
            }
            int x, y;
            scanf("%d%d", &x, &y);
            update_Son(x, y);
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41730082/article/details/84582306