题目链接:http://codeforces.com/contest/981/problem/G
线段树维护乘法和加法,对于每次修改,首先将区间[l,r]都乘2,然后查询[l,r]当中有哪些子区间没有出现过x,对这部分区间乘INV2(2的逆元)并且加1,查询子区间可以用set维护
代码:
#include<bits/stdc++.h> #define mp make_pair #define xx first #define yy second using namespace std; const int MAXN=2e5+5; const int MOD=998244353; const int INV2=MOD-MOD/2; typedef long long ll; typedef pair<int,int> pii; template<typename Tp> struct Seg { #define lson l,mid,rt<<1 #define rson mid+1,r,rt<<1|1 #define root l,r,rt Tp tr[MAXN<<2],w[MAXN<<2],c[MAXN<<2]; inline void push_up(int rt) { tr[rt]=(tr[rt<<1]+tr[rt<<1|1])%MOD; } inline void push_down(int l,int r,int rt) { if(w[rt]!=1) { tr[rt<<1]=(tr[rt<<1]*w[rt])%MOD; tr[rt<<1|1]=(tr[rt<<1|1]*w[rt])%MOD; c[rt<<1]=(c[rt<<1]*w[rt])%MOD; c[rt<<1|1]=(c[rt<<1|1]*w[rt])%MOD; w[rt<<1]=(w[rt<<1]*w[rt])%MOD; w[rt<<1|1]=(w[rt<<1|1]*w[rt])%MOD; w[rt]=1; } if(c[rt]) { int mid=(l+r)>>1; tr[rt<<1]=(tr[rt<<1]+(mid-l+1)*c[rt]%MOD)%MOD; tr[rt<<1|1]=(tr[rt<<1|1]+(r-mid)*c[rt]%MOD)%MOD; c[rt<<1]=(c[rt<<1]+c[rt])%MOD; c[rt<<1|1]=(c[rt<<1|1]+c[rt])%MOD; c[rt]=0; } } void build(int l,int r,int rt) { tr[rt]=0; c[rt]=0;w[rt]=1; if(l==r) return; int mid=(l+r)>>1; build(lson); build(rson); push_up(rt); } void update(int L,int R,Tp mul,Tp add,int l,int r,int rt) { if(L>R) return ; if(L<=l&&r<=R) { tr[rt]=(tr[rt]*mul)%MOD; tr[rt]=(tr[rt]+(r-l+1)*add%MOD)%MOD; w[rt]=(w[rt]*mul)%MOD; c[rt]=(c[rt]*mul)%MOD; c[rt]=(c[rt]+add)%MOD; return ; } push_down(l,r,rt); int mid=(l+r)>>1; if(L<=mid) update(L,R,mul,add,lson); if(mid<R) update(L,R,mul,add,rson); push_up(rt); } Tp query(int L,int R,int l,int r,int rt) { if(L>R) return 0; if(L<=l&&r<=R) return tr[rt]; push_down(l,r,rt); int mid=(l+r)>>1; Tp ret=0; if(L<=mid) ret+=query(L,R,lson); if(mid<R) ret+=query(L,R,rson); ret%=MOD; return ret; } }; Seg<ll> se; set<pii> seg[MAXN]; void split(int l,int x) { auto it=lower_bound(seg[x].begin(),seg[x].end(),mp(l,l)); if(it==seg[x].begin()) return; it--; pii tmp=*(it); if(tmp.yy>=l) { seg[x].erase(it); if(tmp.xx<=l-1) seg[x].insert(mp(tmp.xx,l-1)); seg[x].insert(mp(l,tmp.yy)); } } inline char nc() { static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } inline void rea(int &x) { char c=nc();x=0; for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc()); } int main() { //freopen("in.txt","r",stdin); //freopen("out.txt","w",stdout); int n,q; //scanf("%d%d",&n,&q); rea(n);rea(q); se.build(1,n,1); for(int i=1;i<=n;i++) seg[i].insert(mp(1,n)); while(q--) { int op,l,r,x; //scanf("%d",&op); rea(op); if(op==1) { //scanf("%d%d%d",&l,&r,&x); rea(l);rea(r);rea(x); se.update(l,r,2,0,1,n,1); split(l,x);split(r+1,x); while(1) { auto it=lower_bound(seg[x].begin(),seg[x].end(),mp(l,l)); if(it==seg[x].end()||it->xx>r) break; se.update(it->xx,it->yy,INV2,1,1,n,1); seg[x].erase(it); } } else { //scanf("%d%d",&l,&r); rea(l);rea(r); printf("%lld\n",se.query(l,r,1,n,1)); } } return 0; }