Test 1 T2 B 线段树合并

模拟赛的 T2，多敲了两行代码，结果成功爆掉了~

写线段树合并的时候一定要注意：合并时不能随意新开节点，应直接复用被合并的原有节点，否则节点数会额外增加，空间很容易爆掉。

code: 

#include <bits/stdc++.h>  
#define N 100009 
#define ll long long  
#define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)   
using namespace std;   
// n: number of vertices; edges: number of directed edges stored so far.
int n,edges;  
// A[u]: value of vertex u, used as the segment-tree key (expected in [1, n],
//   the key range passed to seg.update).
// hd/nex/to/val: linked-list adjacency storage; every undirected edge is
//   inserted twice (u->v and v->u), hence the N<<1 sizes.
// kk[u]: the k queried at vertex u.  rt[u]: root of vertex u's segment tree.
// ans1[u]/ans2[u]: the two per-vertex answers printed by main.
int A[N],hd[N],to[N<<1],nex[N<<1],kk[N],rt[N],ans1[N]; 
ll val[N<<1]; 
ll ans2[N];   
// Append the directed edge u -> v with weight c to the front of u's
// adjacency list (hd[u] -> nex[...] chain).
void add(int u,int v,int c)
{
    ++edges;
    to[edges]=v;
    val[edges]=c;
    nex[edges]=hd[u];
    hd[u]=edges;
}
struct Segment_Tree
{   
    #define lson p[x].ls 
    #define rson p[x].rs 
    int tot;   
    struct Node 
    {    
        int ls,rs,size;    
        ll dis,num,rt,mul,maxx;        
    }p[N*90];  
    int newnode() { return ++tot; }        
    void mark(int x,ll d) 
    {
        p[x].mul+=d;   
        p[x].rt+=d*p[x].num;                 
    }   
    void pushdown(int x,int l,int r) 
    {
        if(p[x].mul) 
        {
            int mid=(l+r)>>1;               
            if(lson) mark(lson, p[x].mul);     
            if(rson) mark(rson, p[x].mul);   
            p[x].mul=0;               
        }
    }     
    int merge(int l,int r,int u,int v) 
    {
        if(!u||!v) return u+v;                       
        pushdown(u,l,r);                                  
        pushdown(v,l,r);              
        int now=newnode();     
        p[now].dis=p[u].dis+p[v].dis+p[u].num*p[v].rt+p[u].rt*p[v].num;    
        p[now].num=p[u].num+p[v].num;   
        p[now].rt=p[u].rt+p[v].rt;            
        p[now].maxx=max(p[u].maxx, p[v].maxx);                                                    
        if(l==r) 
        {
            if(p[now].num) p[now].size=1;   
            p[now].maxx=p[now].dis;     
            return now;   
        }
        int mid=(l+r)>>1;    
        p[now].ls=merge(l,mid,p[u].ls,p[v].ls);   
        p[now].rs=merge(mid+1,r,p[u].rs,p[v].rs);     
        p[now].size=p[p[now].ls].size+p[p[now].rs].size;      
        p[now].maxx=max(p[p[now].ls].maxx, p[p[now].rs].maxx);   
        return now;   
    }    
    int solve(int l,int r,int x) 
    { 
        if(l==r) return l; 
        int mid=(l+r)>>1;               
        pushdown(x,l,r);   
        if(l<=mid && p[lson].size && p[lson].maxx==p[x].maxx) return solve(l,mid,lson);  
        else return solve(mid+1,r,rson);   
    }
    void update(int &x,int l,int r,int pp) 
    {
        if(!x) x=newnode(); 
        if(l==r) 
        {   
            p[x].size=1; 
            p[x].num=1; 
            return;  
        } 
        pushdown(x, l, r);  
        int mid=(l+r)>>1;  
        if(pp<=mid) update(lson,l,mid,pp); 
        else update(rson,mid+1,r,pp);  
        p[x].maxx=max(p[lson].maxx, p[rson].maxx);   
        p[x].size=p[lson].size+p[rson].size;   
    }
    ll dfss(int l,int r,int x,int kth) 
    { 
        if(l==r) return p[x].dis;  
        int mid=(l+r)>>1;  
        pushdown(x,l,r);    
        int sz=p[lson].size;   
        if(sz>=kth) return dfss(l,mid,lson,kth); 
        else return dfss(mid+1,r,rson,kth-sz);           
    }
    #undef lson 
    #undef rson 
}seg; 
void dfs(int u,int ff,int pp) 
{  
    seg.update(rt[u],1,n,A[u]);     
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        if(v==ff) continue;   
        dfs(v,u,val[i]);  
    }                                      
    if(seg.p[rt[u]].size<kk[u]) ans2[u]=-1; 
    else
    {
        ans2[u]=seg.dfss(1,n,rt[u],kk[u]);   
    }
    ans1[u]=seg.solve(1,n,rt[u]);  
    seg.mark(rt[u], 1ll*pp);                 
    rt[ff]=seg.merge(1,n,rt[u], rt[ff]);              
}
// Reads the tree, per-vertex values and per-vertex k, runs the merge DFS from
// vertex 1, and prints "ans1 ans2" for every vertex.
int main() 
{ 
    // setIO("input");    
    // Guard against empty input or n < 1: dfs would otherwise call
    // seg.update on the empty range [1, 0] and recurse forever.
    if(scanf("%d",&n)!=1 || n<1) return 0;
    for(int i=1;i<n;++i)
    {
        int u,v,c;
        scanf("%d%d%d",&u,&v,&c);
        add(u,v,c);
        add(v,u,c);
    }
    for(int i=1;i<=n;++i) scanf("%d",&A[i]);
    for(int i=1;i<=n;++i) scanf("%d",&kk[i]);
    dfs(1,0,0);  // root at vertex 1; the root has no parent edge (weight 0)
    for(int i=1;i<=n;++i) printf("%d %lld\n",ans1[i],ans2[i]);
    return 0; 
}

  

猜你喜欢

转载自www.cnblogs.com/guangheli/p/11635208.html