uoj#388. 【UNR #3】配对树(线段树合并)

传送门

先考虑一个贪心,对于一条边来说,如果当前这个序列中在它的子树中的元素个数为奇数个,那么这条边就会被一组匹配经过,否则就不会

考虑反证法,如果在这条边两边的元素个数都是偶数,那么至少有两组匹配经过它,那么把这两条路径都删去这条边可以更优。如果两边是奇数,一定至少有一条路径经过它,去掉这组匹配之后就变成了偶数的情况。证毕

然后是一个神仙的转化,我们对于一颗子树中的元素,在序列里标记为\(1\),否则为\(0\),那么这条边出现次数就是序列中长度为偶数且区间和为奇数的区间个数

考虑用线段树合并优化,对于每个节点,记\(t[p][0/1][0/1]\)表示节点\(p\)代表的区间中前缀和为偶数\(/\)奇数,下标为偶数\(/\)奇数的下标个数,然后线段树合并就行了

然而咱还是搞不明白为啥线段树上的区间要设为\([1,m+1]\)……有哪位知道为什么的请告诉咱一声……

//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
    R int res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
const int N=1e5+5,M=N<<5,P=998244353;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
    R int res=1;
    for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
    return res;
}
struct eg{int v,nx,w;}e[N<<1];int head[N],tot;
inline void add_edge(R int u,R int v,R int w){e[++tot]={v,head[u],w},head[u]=tot;}
int sum[M],ls[M],rs[M],t[M][2][2],rt[N];
int n,m,ans,cnt,u,v,w;
void upd(int p,int l,int r){
    sum[p]=0;
    if(ls[p])sum[p]+=sum[ls[p]];
    if(rs[p])sum[p]+=sum[rs[p]];
    int x=ls[p]?sum[ls[p]]&1:0;
    fp(i,0,1)fp(j,0,1){
        t[p][i][j]=0;
        if(ls[p])t[p][i][j]+=t[ls[p]][i][j];
        if(rs[p])t[p][i][j]+=t[rs[p]][i^x][j];
    }
    int mid=(l+r)>>1;
    if(!ls[p])t[p][0][0]+=(mid>>1)-((l-1)>>1),t[p][0][1]+=((mid+1)>>1)-(l>>1);
    if(!rs[p])t[p][x][0]+=(r>>1)-(mid>>1),t[p][x][1]+=((r+1)>>1)-((mid+1)>>1);
}
void ins(int &p,int l,int r,int x){
    if(!p){
        p=++cnt;
        t[p][0][0]=(r>>1)-((l-1)>>1);
        t[p][0][1]=((r+1)>>1)-(l>>1);
    }
    if(l==r)return ++sum[p],void();
    int mid=(l+r)>>1;
    x<=mid?ins(ls[p],l,mid,x):ins(rs[p],mid+1,r,x);
    upd(p,l,r);
}
int merge(int x,int y,int l,int r){
    if(!x||!y)return x|y;
    int mid=(l+r)>>1;
    ls[x]=merge(ls[x],ls[y],l,mid);
    rs[x]=merge(rs[x],rs[y],mid+1,r);
    upd(x,l,r);
    return x;
}
void dfs(int u,int fa){
    go(u)if(v!=fa){
        dfs(v,u);
        ans=add(ans,mul(e[i].w,1ll*t[rt[v]][0][0]*t[rt[v]][1][0]%P+1ll*t[rt[v]][0][1]*t[rt[v]][1][1]%P));
        rt[u]=merge(rt[u],rt[v],1,m+1);
    }
}
int main(){
//  freopen("testdata.in","r",stdin);
    n=read(),m=read();
    fp(i,1,n-1)u=read(),v=read(),w=read(),add_edge(u,v,w),add_edge(v,u,w);
    fp(i,1,m)u=read(),ins(rt[u],1,m+1,i);
    dfs(1,0);
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/bztMinamoto/p/10286479.html