2020第一届辽宁省赛E.线段树 ——exgcd + 逆元 + 线段树

题目链接

题意: 中文题

思路:

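题目要求求区间内两两数乘积之和。由恒等式 $\sum_{l\le i<j\le r} a_i a_j = \frac{\left(\sum_{i=l}^{r} a_i\right)^2 - \sum_{i=l}^{r} a_i^2}{2}$,可以转化为用线段树维护区间和与区间平方和(支持区间加、区间乘)。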

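公式中的"除以 2"需要用到 2 在模 P 下的逆元。该逆元只在 P 为奇数时存在:可用 exgcd 求;费马小定理 $2^{P-2}$ 只在 P 为素数时成立。另一种不依赖逆元的做法是:在模 2P 下维护区间和与平方和,得到的 $\left(S^2 - Q\right) \bmod 2P$ 必为偶数,直接除以 2 即为答案。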

// Decline is inevitable,
// Romance will last forever.
//#include <bits/stdc++.h>
#include <iostream>
#include <cmath>
#include <cstring>
#include <string>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <deque>
#include <vector>
using namespace std;
// Contest-template shorthand macros (not all of them are used below).
#define mst(a, x) memset(a, x, sizeof(a))
#define INF 0x3f3f3f3f
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define ll long long
// From this line on every `int` is really `long long`, which widens the
// intermediate products of two residues. This is also why the entry point
// has to be declared `signed main()`.
#define int long long
// Capacity of a[] (1-based, so at most maxn - 1 elements); the tree uses maxn << 2 nodes.
// NOTE(review): solve() never checks n against this bound — confirm it matches the problem's limit on n.
const int maxn = 1e2 + 10;
// Not referenced in the code shown here.
const int maxm = 1e3 + 10;
//const int P = 1e4 + 7;
// Modulus used by every tree routine; it is set per test case in solve().
int P;
// Input values, 1-based.
int a[maxn];
// Number of elements in the current test case.
int n;
// Binary exponentiation: returns a^b reduced modulo the global P.
// Starting from 1 % P makes the result 0 when P == 1.
// The loop stops only when b reaches exactly 0.
ll power(ll a, ll b) {
    ll result = 1 % P;
    ll base = a;
    while (b != 0) {
        if (b & 1)
            result = result * base % P;
        base = base * base % P;
        b >>= 1;
    }
    return result;
}
// One segment-tree node. All stored aggregates are kept reduced modulo the
// global P. Lazy-tag meaning: every value v in the node's range stands for
// v * mul + add.
struct segement_tree {
    int l, r;           // inclusive range [l, r] of a[] covered by this node
    int sum;            // sum of the values in the range
    int sum2, sum3;     // sum of squares; sum3 (cubes) is not read anywhere in this file
    int add, mul, set;  // lazy tags: pending add, pending multiply, pending assignment
                        // NOTE: spread() does not read `set`.
// Shorthand accessors for the fields of node x.
#define sum(x) tree[x].sum
#define sum2(x) tree[x].sum2
#define sum3(x) tree[x].sum3
#define add(x) tree[x].add
#define mul(x) tree[x].mul
#define set(x) tree[x].set
#define l(x) tree[x].l
#define r(x) tree[x].r
// Argument-list helpers for recursing into the left / right child
// (they expect p, l, r and mid to be in scope). The right child is p<<1|1.
#define ls p<<1, l, mid
#define rs p<<1|1, mid + 1, r
}tree[maxn << 2];
// Pull-up: recompute node p's sum and sum of squares from its two children,
// reducing modulo P.
void push(int p) {
    const int lc = p << 1;
    const int rc = lc | 1;
    sum(p) = (sum(lc) + sum(rc)) % P;
    sum2(p) = (sum2(lc) + sum2(rc)) % P;
}
// Build node p over positions [l, r] of a[]. Lazy tags are reset: multiply
// to its identity 1, add and assignment to 0.
void build(int p, int l, int r) {
    l(p) = l;
    r(p) = r;
    add(p) = 0;
    set(p) = 0;
    mul(p) = 1;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(p << 1, l, mid);
        build(p << 1 | 1, mid + 1, r);
        push(p);
        return;
    }
    // Leaf: store the value and its square, both reduced modulo P.
    sum(p) = a[l] % P;
    sum2(p) = a[l] * a[l] % P;
}
void spread(int p) {
    if(mul(p) != 1) {
        mul(p<<1) = mul(p) * mul(p<<1) % P;
        mul(p<<1|1) = mul(p) * mul(p<<1|1) % P;
        if(add(p<<1))
            add(p<<1) = add(p<<1) * mul(p) % P;
        if(add(p<<1|1))
            add(p<<1|1) = add(p<<1|1) * mul(p)%P;
        sum(p<<1) = sum(p<<1) * mul(p) % P;
        sum(p<<1|1) = sum(p<<1|1)*mul(p)%P;
        sum2(p<<1) = sum2(p<<1)*mul(p)%P*mul(p)%P;
        sum2(p<<1|1) = sum2(p<<1|1)*mul(p)%P*mul(p)%P;
        mul(p) = 1;
    }
    if(add(p)) {
        add(p<<1) = (add(p<<1)+add(p))%P;
        add(p<<1|1) = (add(p<<1|1)+add(p))%P;   //
        ll lenl = r(p<<1)-l(p<<1)+1;
        ll lenr = r(p<<1|1)-l(p<<1|1)+1;
        sum2(p<<1)=(sum2(p<<1)+(add(p)*add(p)%P)*lenl%P+2*sum(p<<1)*add(p)%P)%P;
        sum2(p<<1|1)=(sum2(p<<1|1)+(add(p)*add(p)%P)*lenr%P+2*sum(p<<1|1)*add(p)%P)%P;
        sum(p<<1) = (sum(p<<1) + lenl * add(p)%P)% P;
        sum(p<<1|1) = (sum(p<<1|1) + lenr * add(p)%P)% P;
        add(p) = 0;
    }
}
// Multiply every element of [l, r] by k (mod P); p is the current node.
void mul_change(int p, int l, int r, int k) {
    const bool covered = l <= l(p) && r(p) <= r;
    if (!covered) {
        spread(p);
        const int mid = (l(p) + r(p)) >> 1;
        if (l <= mid)
            mul_change(p << 1, l, r, k);
        if (mid < r)
            mul_change(p << 1 | 1, l, r, k);
        push(p);
        return;
    }
    // Fully covered: (x * m + a) * k = x * (m * k) + a * k, so both tags
    // are scaled by k, the sum by k and the sum of squares by k^2.
    mul(p) = mul(p) * k % P;
    sum(p) = sum(p) * k % P;
    add(p) = add(p) * k % P;
    sum2(p) = sum2(p) * k % P * k % P;
}
// Add d to every element of [l, r] (mod P); p is the current node.
void add_change(int p, int l, int r, int d) {
    if (!(l <= l(p) && r >= r(p))) {
        spread(p);
        const int mid = (l(p) + r(p)) >> 1;
        if (l <= mid)
            add_change(p << 1, l, r, d);
        if (r > mid)
            add_change(p << 1 | 1, l, r, d);
        push(p);
        return;
    }
    // Fully covered: sum (x + d)^2 = sum x^2 + len * d^2 + 2 * d * sum x.
    // sum2 is updated first because it needs the old sum.
    const int len = r(p) - l(p) + 1;
    add(p) = (add(p) + d) % P;
    sum2(p) = (sum2(p) + (d * d % P * len) % P + 2 * sum(p) * d) % P;
    sum(p) = (sum(p) + len * d) % P;
}
// Assign d to every element of [l, r] (mod P); p is the current node.
//
// spread() only propagates the mul/add tags, so an assignment is encoded
// with those tags as "multiply by 0, then add d" (x * 0 + d == d). This
// pushes down correctly through spread() and composes with later
// mul_change/add_change calls. Storing the value only in set(p) would leave
// descendants with stale values after a partially overlapping operation.
void set_change(int p, int l, int r, int d) {
    d = (d % P + P) % P;  // normalise (also handles negative d)
    if (l <= l(p) && r >= r(p)) {
        const int len = r(p) - l(p) + 1;
        set(p) = d;           // informational only; spread() ignores it
        mul(p) = 0;           // wipe the old values...
        add(p) = d;           // ...and replace them with d
        sum(p) = len * d % P;
        sum2(p) = len * d % P * d % P;
        return;
    }
    spread(p);
    int mid = (l(p) + r(p)) >> 1;
    if (l <= mid)
        set_change(p << 1, l, r, d);
    if (r > mid)
        set_change(p << 1 | 1, l, r, d);
    push(p);
}
// Return the sum (index == 1) or the sum of squares (any other index) of
// the elements in [l, r], reduced modulo P; p is the current node.
int query(int p, int l, int r, int index) {
    if (l <= l(p) && r >= r(p))
        return (index == 1 ? sum(p) : sum2(p)) % P;
    spread(p);
    const int mid = (l(p) + r(p)) >> 1;
    int total = 0;
    if (l <= mid)
        total += query(p << 1, l, r, index);
    if (r > mid)
        total += query(p << 1 | 1, l, r, index);
    return total % P;
}

int m;
void solve() {
    cin >> n >> m >> P;
    bool ok = false;
    if(P == 2) {
        ok = true;
        P = 3;
    }
    for(int i = 1; i <= n; i++)
    {
        cin >> a[i];
        a[i] %= P;
    }
    build(1, 1, n);
    while(m--) {
        int ope, l, r, v;
        cin >> ope >> l >> r;
        if(ope == 3) {
            ll ans1 = query(1, l, r, 1) * query(1, l, r, 1) % P;
            ll ans2 = query(1, l, r, 2) % P;
            ll ans;
            if(ok) {
                ans=((ans1-ans2)%P+P)%P;
                if(ans)
                    ans=1;
            }
            else
                ans=((ans1-ans2)%P+P)%P*power(2,P-2)%P;
            cout << ans << endl;
        }
        else {
            cin >> v;
            v %= P;
            if(ope == 1)
                add_change(1, l, r, v);
            else
                mul_change(1, l, r, v);
        }
    }
}
// Entry point: reads the number of test cases T and runs solve() once per case.
// Declared `signed` because `int` is macro-defined as `long long` above.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int T;
    cin >> T;
    while (T--)
        solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/m0_59273843/article/details/120856647
今日推荐