牛客算法周周练15 树上求和(dfs序线段树)

(dfs序线段树)树上求和

题目链接
题目意思明了,看起来像线段树,但是又不像线段树,线段树的单一题目一般都直接给出一个数列,然后要你求区间之间的一些值,维护一些值,但是这个题目又不是给你一个数列,而是一颗树,那么就要想办法改成线段树,这样dfs序就来了,因为dfs序中一个点的进来的位置与出去的位置之间恰好是他的子树所在区间,和题目要求的东西刚刚好。(至于为什么是这样百度一下dfs序就知道了)。
在这里插入图片描述

int in[maxn],out[maxn],cnt,vis[maxn],fir[maxn];//in进去的位置,out出来的位置,vis第cnt个位置是哪个点,fir建图的东西
struct node
{
    int to,next;
}e[maxn<<1];
int fir[maxn<<1];
void add(int u,int v)//建树
{
    e[cnt].to=v;
    e[cnt].next=fir[u];
    fir[u]=cnt++;
}
void dfs(int x,int fa)//求解dfs序
{
    in[x]=++cnt,vis[cnt]=x;
    for(int i=fir[x];i!=-1;i=e[i].next)
    {
        int v=e[i].to;
        if(fa==v)continue;
        dfs(v,x);
    }
    out[x]=cnt;
}

将普通树上求子树和和修改子树变成了线段树求解区间平方和,和区间修改问题了,区间平方和的求解十分简单就不说了,现在来说一下区间修改问题:假设一个区间的平方和是a^2 + b^2 + c^2,区间和是a + b + c,那么在这个区间上加一个数v,平方和就变成了(a+v)^2 +(b+v)^2 + (c+v)^2, 区间和就变成了a+b+c+3v,化简一下可以推出a^2 + b^2 +c^2 +2v(a+b+c)+3*v^2;然后规律就很显然了

#include<iostream>
#include<stdio.h>
#include<math.h>
#include<string.h>
#include<string>
#include<vector>
#include<queue>
#include<algorithm>
#include<deque>
#include<map>
#include<stdlib.h>
#include<set>
#include<iomanip>
#include<stack>
#define ll long long
#define ull unsigned long long
#define ms(a,b) memset(a,b,sizeof(a))
#define lowbit(x) x & -x
#define fi first
#define se second
#define lson num<<1
#define rson num<<1|1
#define bug cout<<"----acac----"<<endl
#define IOS ios::sync_with_stdio(false), cin.tie(0),cout.tie(0)
using namespace std;
const int maxn = 1e5+ 50;
const double eps = 1e-7;
const int inf = 0x3f3f3f3f;
const ll  lnf  = 0x3f3f3f3f3f3f3f3f;
const ll mod = 23333;
const  double pi=3.141592653589;
ll a[maxn];
int in[maxn],out[maxn],cnt,vis[maxn];
struct node
{
    int to,next;
}e[maxn<<1];
int fir[maxn<<1];
void add(int u,int v)
{
    e[cnt].to=v;
    e[cnt].next=fir[u];
    fir[u]=cnt++;
}
void dfs(int x,int fa)//求解dfs序
{
    in[x]=++cnt,vis[cnt]=x;
    for(int i=fir[x];i!=-1;i=e[i].next)
    {
        int v=e[i].to;
        if(fa==v)continue;
        dfs(v,x);
    }
    out[x]=cnt;
}
struct node2
{
    int l,r;
    ll sum,val,lazy;
}tr[maxn<<2];
void pushup(int num)
{
    tr[num].sum=(tr[lson].sum+tr[rson].sum)%mod;
    tr[num].val=(tr[lson].val+tr[rson].val)%mod;
}
void pushdown(int num)//这里一定要小心别写错了,不然找bug贼难
{
    if(tr[num].lazy)
    {
        tr[lson].lazy=(tr[lson].lazy+tr[num].lazy)%mod;
        tr[rson].lazy=(tr[rson].lazy+tr[num].lazy)%mod;
        int l=tr[lson].r-tr[lson].l+1,r=tr[rson].r-tr[rson].l+1;
        tr[lson].sum=(tr[lson].sum+2*tr[lson].val%mod*tr[num].lazy%mod+(l%mod*tr[num].lazy%mod*tr[num].lazy%mod))%mod;
        tr[rson].sum=(tr[rson].sum+2*tr[rson].val%mod*tr[num].lazy%mod+(r%mod*tr[num].lazy%mod*tr[num].lazy%mod))%mod;
        tr[lson].val=(tr[lson].val+tr[num].lazy*l%mod)%mod;
        tr[rson].val=(tr[rson].val+tr[num].lazy*r%mod)%mod;
        tr[num].lazy=0;
    }
}
void build(int l,int r,int num)
{
    tr[num].l=l;
    tr[num].r=r;
    if(l==r)
    {
        tr[num].val=a[vis[l]]%mod;//这个地方一定要注意不是a[num],至于为什么自己看看dfs理解后就知道了
        tr[num].sum=tr[num].val*tr[num].val%mod;
        return ;
    }
    int mid=(l+r)>>1;
    build(l,mid,lson);
    build(mid+1,r,rson);
    pushup(num);
}
void update(int l,int r,int v,int num)//更新
{
    if(tr[num].l>=l&&tr[num].r<=r)
    {
        tr[num].sum=(tr[num].sum+v%mod*v%mod*(tr[num].r-tr[num].l+1)%mod+tr[num].val*2*v%mod)%mod;
        tr[num].val=(tr[num].val+(tr[num].r-tr[num].l+1)*v%mod)%mod;
        tr[num].lazy+=v;
        return;
    }
    pushdown(num);
    int mid=(tr[num].l+tr[num].r)>>1;
    if(mid>=l)update(l,r,v,lson);
    if(mid<r)update(l,r,v,rson);
    pushup(num); 
}
ll query(int l,int r,int num)//找区间的平方和
{
    if(tr[num].l>=l&&tr[num].r<=r)
    {
        return tr[num].sum%mod;
    }
    ll  ans=0;
    pushdown(num);
    int mid=(tr[num].l+tr[num].r)>>1;
    if(mid>=l)ans=(ans+query(l,r,lson))%mod;
    if(mid<r)ans=(ans+query(l,r,rson))%mod;
    return ans%mod;
}
int main()
{
    ms(fir,-1);
    int n,q;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&a[i]);
    }
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
    cnt=0;
    dfs(1,-1);
    build(1,n,1);
    for(int i=1;i<=q;i++)
    {
        int w;
        scanf("%d",&w);
        if(w==1)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            update(in[x],out[x],y,1);
        }
        else
        {
            int x;
            scanf("%d",&x);
            printf("%lld\n",query(in[x],out[x],1)%mod);
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qcccc_/article/details/107348738
今日推荐