HDU4578 线段树维护区间平方和立方和

题目链接

参考blog1

参考blog2

// Decline is inevitable,
// Romance will last forever.
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros in the style of a competitive-programming template.
// Most of them (mst, INF, mp, pii, fi, se) are not used by the visible code.
// mst: memset the whole array `a` to the byte value `x`.
#define mst(a, x) memset(a, x, sizeof(a))
// Conventional "infinity" constant.
#define INF 0x3f3f3f3f
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define ll long long
// From this line on, every `int` token in the file expands to `long long`
// (which is why the entry point is spelled `signed main`).
#define int long long
// Capacity of the value array; the segment tree array is sized at 4x this.
const int maxn = 1e5 + 10;
const int maxm = 1e3 + 10;
// Modulus applied to every sum kept in the tree.
const int P = 1e4 + 7;
// Initial element values read by build() at the leaves. Nothing in the visible
// code writes to it, so it keeps its zero initialization.
int a[maxn];
// One node of the lazy segment tree. Nodes live in the global array `tree`,
// with node p's children stored at indices p<<1 and p<<1|1.
struct segement_tree {
    int l, r;            // closed interval [l, r] covered by this node
    int sum;             // interval sum, tracked modulo P
    int sum2, sum3;      // interval sum of squares and sum of cubes, modulo P
    int add, mul, set;   // lazy tags: pending add, multiply and assign
                         // (add == 0, mul == 1, set == 0 mean "nothing pending")
// Field accessors: NAME(x) is the field NAME of node x.
#define sum(x) tree[x].sum
#define sum2(x) tree[x].sum2
#define sum3(x) tree[x].sum3
#define add(x) tree[x].add
#define mul(x) tree[x].mul
#define set(x) tree[x].set
#define l(x) tree[x].l
#define r(x) tree[x].r
// Argument triples (node, lo, hi) for recursing into the left / right child.
// They expect `p`, `l`, `r` and `mid` to be in scope at the point of use.
// The right child is p<<1|1; p<<1 would address the left child a second time.
#define ls p<<1, l, mid
#define rs p<<1|1, mid + 1, r
}tree[maxn << 2];
// Recompute the three aggregates of node p from its two children.
// Every result is reduced modulo P.
void push(int p) {
    const int lc = p << 1;
    const int rc = p << 1 | 1;
    sum(p) = (sum(lc) + sum(rc)) % P;
    sum2(p) = (sum2(lc) + sum2(rc)) % P;
    sum3(p) = (sum3(lc) + sum3(rc)) % P;
}
// Build the subtree rooted at node p over the closed interval [l, r].
// Leaves take their value from the global array `a`; inner nodes are
// assembled from their children. All lazy tags start out neutral.
void build(int p, int l, int r) {
    l(p) = l;
    r(p) = r;
    // Neutral tags: nothing to add, multiply by one, no pending assignment.
    add(p) = 0;
    set(p) = 0;
    mul(p) = 1;
    if(l != r) {
        const int mid = (l + r) >> 1;
        build(p << 1, l, mid);
        build(p << 1 | 1, mid + 1, r);
        push(p);
        return;
    }
    // Leaf: store the value, its square and its cube, each reduced modulo P.
    sum(p) = a[l] % P;
    sum2(p) = a[l] * a[l] % P;
    sum3(p) = a[l] * a[l] % P * a[l] % P;
}
void spread(int p) {
    if(set(p)) {
        set(p<<1) = set(p<<1|1) = set(p);
        add(p<<1) = add(p<<1|1) = 0;
        mul(p<<1) = mul(p<<1|1) = 1;
        ll lenl = r(p<<1) - l(p<<1) + 1;
        ll lenr = r(p<<1|1) - l(p<<1|1) + 1;
        ll tmp = set(p) * set(p) % P * set(p) % P;
        sum(p<<1) = lenl % P * set(p) % P;
        sum(p<<1|1) = lenr % P * set(p) % P;
        sum2(p<<1) = lenl % P * set(p) % P * set(p) % P;
        sum2(p<<1|1) = lenr % P *set(p)%P*set(p)%P;
        sum3(p<<1) = lenl % P * tmp % P;
        sum3(p<<1|1) = lenr % P * tmp % P;
        set(p) = 0;
    }
    if(mul(p) != 1) {
        mul(p<<1) = mul(p) * mul(p<<1) % P;
        mul(p<<1|1) = mul(p) * mul(p<<1|1) % P;
        if(add(p<<1))
            add(p<<1) = add(p<<1) * mul(p) % P;
        if(add(p<<1|1))
            add(p<<1|1) = add(p<<1|1) * mul(p)%P;
        sum(p<<1) = sum(p<<1) * mul(p) % P;
        sum(p<<1|1) = sum(p<<1|1)*mul(p)%P;
        sum2(p<<1) = sum2(p<<1)*mul(p)%P*mul(p)%P;
        sum2(p<<1|1) = sum2(p<<1|1)*mul(p)%P*mul(p)%P;
        sum3(p<<1) = sum3(p<<1)*mul(p)%P*mul(p)%P*mul(p)%P;
        sum3(p<<1|1) = sum3(p<<1|1)*mul(p)%P*mul(p)%P*mul(p)%P;
        mul(p) = 1;
    }
    if(add(p)) {
        add(p<<1) = (add(p<<1)+add(p))%P;
        add(p<<1|1) = (add(p<<1|1)+add(p))%P;   //
        ll tmp = add(p)*add(p)%P*add(p)%P;  //注意sum3 sum2 sum更新顺序
        ll lenl = r(p<<1)-l(p<<1)+1;
        ll lenr = r(p<<1|1)-l(p<<1|1)+1;
        sum3(p<<1)=((sum3(p<<1)+(tmp*lenl%P))%P+3*add(p)*(sum2(p<<1)+sum(p<<1)*add(p)%P)%P)% P;
//        sum3(p<<1)=((sum3(p<<1)+3 * add(p) * sum2(p<<1) % P);
        sum3(p<<1|1)=((sum3(p<<1|1)+(tmp*lenr%P))%P+3*add(p)*(sum2(p<<1|1)+sum(p<<1|1)*add(p)%P)%P)%P;
        sum2(p<<1)=(sum2(p<<1)+(add(p)*add(p)%P)*lenl%P+2*sum(p<<1)*add(p)%P)%P;
        sum2(p<<1|1)=(sum2(p<<1|1)+(add(p)*add(p)%P)*lenr%P+2*sum(p<<1|1)*add(p)%P)%P;
        sum(p<<1) = (sum(p<<1) + lenl * add(p)%P)% P;
        sum(p<<1|1) = (sum(p<<1|1) + lenr * add(p)%P)% P;
        add(p) = 0;
    }
}
// Multiply every element in [l, r] by k, working on the subtree rooted at p.
void mul_change(int p, int l, int r, int k) {
    const bool covered = (l(p) >= l) && (r(p) <= r);
    if(!covered) {
        // Partial overlap: flush pending tags, recurse, then rebuild this node.
        spread(p);
        const int mid = (l(p) + r(p)) >> 1;
        if(mid >= l) mul_change(p << 1, l, r, k);
        if(mid < r) mul_change(p << 1 | 1, l, r, k);
        push(p);
        return;
    }
    // Node lies fully inside the range: scale its tags and its sums in place.
    mul(p) = mul(p) * k % P;
    add(p) = add(p) * k % P;
    sum(p) = sum(p) * k % P;
    sum2(p) = sum2(p) * k % P * k % P;
    sum3(p) = sum3(p) * k % P * k % P * k % P;
}
void add_change(int p, int l, int r, int d) {
    if(l <= l(p) && r >= r(p)) {
        add(p) = (add(p) + d) % P;
        //注意321更新顺序
        ll tmp = (((d * d) % P * d) % P * (r(p) - l(p) + 1)) % P;    //(r - l + 1) * c^3
        sum3(p)=(sum3(p)+tmp+3*d*((sum2(p) + sum(p)*d)%P)%P)%P;
        sum2(p)=(sum2(p)+(d*d%P* (r(p)-l(p)+1))%P + 2 * sum(p) * d) % P;
        sum(p) = (sum(p) + (r(p)-l(p)+1)*d) % P;
        return;
    }
    spread(p);
    int mid = (l(p) + r(p)) >> 1;
    if(l <= mid) add_change(p<<1, l, r, d);
    if(r > mid) add_change(p<<1|1, l, r, d);
    push(p);
}
// Assign the value d to every element in [l, r], working on the subtree
// rooted at p.
//
// On a fully covered node the pending add/mul tags are discarded and the sums
// are rebuilt from the interval length. All stored values are reduced modulo
// P, like in the other update routines.
//
// A value congruent to 0 modulo P cannot be stored in the `set` tag, because
// spread() treats set == 0 as "no pending assignment". That case is recorded
// as a pending multiplication by 0 instead, which spread() pushes to the
// children with the same effect.
void set_change(int p, int l, int r, int d) {
    if(l <= l(p) && r >= r(p)) {
        const int v = (d % P + P) % P;              // normalized into [0, P)
        const int len = (r(p) - l(p) + 1) % P;
        add(p) = 0;     // discard pending add
        if(v == 0) {
            set(p) = 0;
            mul(p) = 0; // "everything becomes 0", expressed as multiply-by-zero
            sum(p) = 0;
            sum2(p) = 0;
            sum3(p) = 0;
            return;
        }
        set(p) = v;
        mul(p) = 1;     // discard pending multiply
        sum(p) = len * v % P;
        sum2(p) = len * v % P * v % P;
        sum3(p) = len * v % P * v % P * v % P;
        return;
    }
    // Partial overlap: flush pending tags, recurse, then rebuild this node.
    spread(p);
    int mid = (l(p) + r(p)) >> 1;
    if(l <= mid)
        set_change(p<<1, l, r, d);
    if(r > mid)
        set_change(p<<1|1, l, r, d);
    push(p);
}
// Return an aggregate over [l, r] modulo P, working on the subtree rooted at p.
// index selects the aggregate: 1 = sum, 2 = sum of squares, anything else =
// sum of cubes.
int query(int p, int l, int r, int index) {
    if(l(p) >= l && r(p) <= r) {
        // Node lies fully inside the range: answer straight from its sums.
        int whole = sum3(p);
        if(index == 1)
            whole = sum(p);
        else if(index == 2)
            whole = sum2(p);
        return whole % P;
    }
    spread(p);
    const int mid = (l(p) + r(p)) >> 1;
    int left_part = 0;
    int right_part = 0;
    if(l <= mid) left_part = query(p << 1, l, r, index);
    if(r > mid) right_part = query(p << 1 | 1, l, r, index);
    return (left_part + right_part) % P;     // reduce the combined result
}


// Read test cases until EOF or a "0 0" line. Each case gives n (array size)
// and m (number of operations), followed by m lines "op x y c":
//   op 1: add c to every element in [x, y]
//   op 2: multiply every element in [x, y] by c
//   op 3: assign c to every element in [x, y]
//   otherwise: print the sum of c-th powers over [x, y] modulo P
// Processing stops if an operation line cannot be read.
void solve() {
    int n, m;
    while(cin >> n >> m) {
        if(n == 0 && m == 0) break;
        build(1, 1, n);
        while(m--) {
            int op, lo, hi, val;
            // Stop on truncated or malformed input instead of replaying stale values.
            if(!(cin >> op >> lo >> hi >> val)) return;
            if(op == 1)
                add_change(1, lo, hi, val);
            else if(op == 2)
                mul_change(1, lo, hi, val);
            else if(op == 3)
                set_change(1, lo, hi, val);
            else
                cout << query(1, lo, hi, val) << '\n';   // '\n' avoids a flush per query
        }
    }
}
signed main() {
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/m0_59273843/article/details/120826179