NOI模拟赛 T3 计算 calculating (线段树优化DP)

题意

数轴上有 \(m\) 个点和 \(n\) 个闭区间，每个区间可以选或者不选，求选出的区间能将所有点都覆盖的方案数（答案对 \(10^9+9\) 取模）。

题解

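（原题解为图片，此处补充文字说明。）

将点坐标排序去重，记为 \(p_1<p_2<\dots<p_m\)。设 \(f_j\) 表示在已处理的区间中，恰好覆盖了 \(p_1,\dots,p_j\) 而 \(p_{j+1}\) 未被覆盖的选法数，初始 \(f_0=1\)。

把区间按左端点从小到大依次加入，设当前区间覆盖的点为 \(p_l,\dots,p_r\)：

- 若它不覆盖任何点（\(l>r\)），选不选都不影响状态，所有 \(f_j\) 乘 \(2\)。
- 否则，对 \(j\ge r\)，选不选都停留在 \(j\)，故 \(f_j\) 乘 \(2\)。
- 对 \(l-1\le j\le r-1\)，选上它就转移到 \(r\)，即 \(f_r \mathrel{+}= \sum_{j=l-1}^{r-1} f_j\)。
- 对 \(j<l-1\)，之后所有区间的左端点都 \(\ge l\)，\(p_{j+1}\) 再也无法被覆盖，这些状态可以直接忽略。

这需要后缀乘 \(2\)、单点加、区间求和三种操作，用带乘法懒标记的线段树维护即可。答案为 \(f_m\)，时间复杂度 \(O((n+m)\log(n+m))\)。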

CODE

#include <bits/stdc++.h>
using namespace std;
// Capacity of the point/interval arrays; the segment-tree arrays use 4x this.
constexpr int MAXN = 500005;
// Modulus applied to every count.
constexpr int mod = 1000000009;
// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped; every '-' among them
// flips the sign. Reading stops at (and consumes) the first non-digit after
// the number. If EOF is reached before any digit, x becomes 0 instead of
// looping forever.
inline void read(int &x) {
    int flg = 1;
    int ch = getchar();  // int, not char: must be able to hold EOF, and isdigit needs unsigned-char values
    while(ch != EOF && !isdigit(ch)) {
        if(ch == '-') flg = -flg;
        ch = getchar();
    }
    x = 0;
    while(ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= flg;
}
// n = number of intervals, m = number of points, b[1..m] = point coordinates.
int n, m;
int b[MAXN];

// An interval [l, r] as read from input.
struct node {
    int l;
    int r;
};
node a[MAXN];

// Orders intervals by left end, breaking ties by right end.
inline bool cmp(node A, node B) {
    if(A.l != B.l) return A.l < B.l;
    return A.r < B.r;
}

// Segment-tree storage (indexed from 1):
// lz[i] = multiplier not yet pushed to node i's children,
// v[i]  = sum (mod `mod`) of node i's range, excluding multipliers still pending at ancestors.
int lz[MAXN<<2];
int v[MAXN<<2];
// Push-down: applies node i's pending multiplier to both children's sums and
// lazy tags, then clears node i's tag back to 1.
inline void pd(int i) {
    const int tag = lz[i];
    if(tag == 1) return;  // nothing pending
    for(int c = i << 1; c <= (i << 1 | 1); ++c) {
        lz[c] = 1ll * lz[c] * tag % mod;
        v[c] = 1ll * v[c] * tag % mod;
    }
    lz[i] = 1;
}
// Point update: adds val (mod `mod`) to position x of the tree rooted at node i
// covering [l, r]. Every node on the root-to-leaf path contains x, so its sum is
// bumped directly; pending multipliers are pushed down before descending so the
// child sums stay consistent.
void add(int i, int l, int r, int x, int val) {
    for(;;) {
        v[i] = (v[i] + val) % mod;
        if(l == r) break;  // reached leaf x
        pd(i);
        const int mid = (l + r) >> 1;
        if(x <= mid) {
            i = i << 1;
            r = mid;
        } else {
            i = i << 1 | 1;
            l = mid + 1;
        }
    }
}
// Suffix doubling: multiplies every position in [x, r] of node i's range
// [l, r] by 2 (mod `mod`) and refreshes the node's sum.
// A call whose suffix starts past r is a no-op. Previously it descended to a
// leaf, pushed into non-existent children and overwrote the leaf's sum.
void db(int i, int l, int r, int x) {
    if(x > r) return;  // suffix does not touch this node
    if(x <= l) {
        // Whole node lies in the suffix: apply lazily.
        lz[i] = 2ll * lz[i] % mod;
        v[i] = 2ll * v[i] % mod;
        return;
    }
    pd(i);
    int mid = (l + r) >> 1;
    if(x <= mid) db(i<<1, l, mid, x);
    db(i<<1|1, mid+1, r, x);  // the right half always intersects [x, r] here
    v[i] = (v[i<<1] + v[i<<1|1]) % mod;  // author's note: forgot this % once and spent a long time debugging
}
// Range sum: returns the sum (mod `mod`) of positions [x, y] within node i's
// range [l, r].
// An empty range (x > y) or one disjoint from [l, r] returns 0. Previously such
// a query recursed forever at a leaf, e.g. qry(0, 0) on a tree over [1, m].
int qry(int i, int l, int r, int x, int y) {
    if(y < l || r < x || x > y) return 0;
    if(x <= l && r <= y) return v[i];
    pd(i);
    int mid = (l + r) >> 1, re = 0;
    if(x <= mid) re = qry(i<<1, l, mid, x, y);
    if(y > mid) re = (re + qry(i<<1|1, mid+1, r, x, y)) % mod;
    return re;
}
void build(int i, int l, int r) {
    lz[i] = 1;
    if(l == r) return;
    int mid = (l + r) >> 1;
    build(i<<1, l, mid);
    build(i<<1|1, mid+1, r);
}
// Counts the subsets of intervals that cover every point (mod `mod`).
//
// After sorting/deduplicating the points into b[1..m], f[j] (0 <= j <= m) is the
// number of choices among the intervals processed so far in which points 1..j
// are covered and point j+1 is not. f[0] is a virtual "nothing covered yet"
// state stored at tree index 0. The original inserted a sentinel coordinate 0
// into b instead; that collides with a real point at 0 and misplaces the start
// state when points are negative.
int main () {
    freopen("calculating.in", "r", stdin);
    freopen("calculating.out", "w", stdout);
    read(n), read(m);
    for(int i = 1; i <= n; ++i) read(a[i].l), read(a[i].r);
    for(int i = 1; i <= m; ++i) read(b[i]);
    sort(b + 1, b + m + 1);
    m = unique(b + 1, b + m + 1) - b - 1;
    // Processing intervals by increasing left end guarantees that a state j
    // with j + 1 < l can never be repaired later, so it may be left untouched.
    sort(a + 1, a + n + 1, cmp);
    build(1, 0, m);
    add(1, 0, m, 0, 1);  // f[0] = 1
    for(int i = 1; i <= n; ++i) {
        const int r = upper_bound(b + 1, b + m + 1, a[i].r) - b - 1;  // last point <= R (0 if none)
        const int l = lower_bound(b + 1, b + m + 1, a[i].l) - b;      // first point >= L (m+1 if none)
        if(l > r) {
            // Covers no point: taking it or not never changes the state.
            db(1, 0, m, 0);
            continue;
        }
        // States j >= r stay at j either way (x2); states l-1..r-1 jump to r if taken.
        // Doubling first is safe because the summed range [l-1, r-1] lies below r.
        db(1, 0, m, r);
        add(1, 0, m, r, qry(1, 0, m, l - 1, r - 1));
    }
    printf("%d\n", qry(1, 0, m, m, m));
}

猜你喜欢

转载自www.cnblogs.com/Orz-IE/p/12149358.html