【POJ3667】Hotel

一道线段树区间合并的问题。

我们需要每次用线段树查找是否存在一个连续的为0的且长度不短与x的子序列、查找完成后需要返回这个子序列的左端点,并且将该子序列全部赋值为1,还需要用线段树完成对一个子序列的赋值。

我们在线段树的每一个节点上维护四个量:tag,sum,l,r分别表示该区间的值(0表示全部为0、1表示全部为1、-1表示其他情况)、这个区间最大连续为0的子序列(下文简称合法序列)的长度、以这个区间左端点为起点的合法序列的长度、以这个区间右端点为起点的合法序列的长度、这个区间的合法序列的长度。

关于这四个量的维护,可以参阅这里

我们重点讨论一下查询操作。设查询的长度为x,如果整个序列的sum小于x则无解,直接输出0即可。否则,如果当前区间的左子区间的合法序列的长度大于等于x,那么递归左子区间,在左子区间中找答案。如果答案区间横跨左右子区间,我们特殊处理。剩余的情况,我们到其右子区间寻找答案即可。

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 #include <algorithm>
 5 using namespace std;
 6 struct node {
 7     int l,r,sum,tag;
 8 }a[50010<<2];
 9 int n,m;
10 inline void pushup(int now,int l,int r) {
11     a[now].l=a[now<<1].tag==0?a[now<<1].l+a[now<<1|1].l:a[now<<1].l;
12     a[now].r=a[now<<1|1].tag==0?a[now<<1|1].r+a[now<<1].r:a[now<<1|1].r;
13     a[now].sum=max(a[now<<1].r+a[now<<1|1].l,max(a[now<<1].sum,a[now<<1|1].sum));
14     if(a[now].sum==0) a[now].tag=1;
15     else if(a[now].sum==r-l+1) a[now].tag=0;
16     else a[now].tag=-1;
17 }
18 inline void pushdown(int now,int l,int r) {
19     if(a[now].tag==-1) return ;
20     a[now<<1].tag=a[now<<1|1].tag=a[now].tag;
21     a[now<<1].l=a[now<<1].r=a[now<<1].sum=a[now<<1].tag==0?((r-l+1)-(r-l+1>>1)):0;
22     a[now<<1|1].l=a[now<<1|1].r=a[now<<1|1].sum=a[now<<1|1].tag==0?(r-l+1>>1):0;
23 }
24 inline void build(int now,int l,int r) {
25     if(l==r) {
26         a[now].tag=0;
27         a[now].l=a[now].r=a[now].sum=1;
28         return ;
29     }
30     int mid=l+r>>1;
31     build(now<<1,l,mid);
32     build(now<<1|1,mid+1,r);
33     pushup(now,l,r);
34 }
35 void updata(int now,int l,int r,int x,int y,int val) {
36     if(x<=l&&r<=y) {
37         a[now].tag=val;
38         a[now].l=a[now].r=a[now].sum=val==0?(r-l+1):0;
39         return ;
40     }
41     int mid=l+r>>1;
42     pushdown(now,l,r);
43     if(x<=mid) updata(now<<1,l,mid,x,y,val);
44     if(y>mid) updata(now<<1|1,mid+1,r,x,y,val);
45     pushup(now,l,r);
46 }
47 int query(int now,int l,int r,int x) {
48     if(l==r) return l;
49     pushdown(now,l,r);
50     int mid=l+r>>1;
51     if(a[now<<1].sum>=x) return query(now<<1,l,mid,x);
52     else if(a[now<<1].r+a[now<<1|1].l>=x) return mid-a[now<<1].r+1;
53     return query(now<<1|1,mid+1,r,x);
54 }
55 int main() {
56     scanf("%d%d",&n,&m);
57     build(1,1,n);
58     while(m--) {
59         int op,x,y;
60         scanf("%d",&op);
61         if(op==1) {
62             scanf("%d",&x);
63             if(a[1].sum<x) {
64                 puts("0");
65                 continue ;
66             }
67             int ans=query(1,1,n,x);
68             updata(1,1,n,ans,ans+x-1,1);
69             printf("%d\n",ans);
70         }
71         else {
72             scanf("%d%d",&x,&y);
73             updata(1,1,n,x,x+y-1,0);
74         }
75     }
76     return 0;
77 }
AC Code

猜你喜欢

转载自www.cnblogs.com/shl-blog/p/10922358.html