HDU 4578 - Transformation (线段树)

Transformation


Time Limit: 15000/8000 MS (Java/Others) Memory Limit: 65535/65536 K (Java/Others)
Total Submission(s): 7166 Accepted Submission(s): 1813

Problem Description

Yuanfang is puzzled with the question below:
There are n integers $a_1, a_2, \dots, a_n$, all initially 0. There are four kinds of operations.
Operation 1: Add c to each number between $a_x$ and $a_y$ inclusive. In other words, do the transformation $a_k \leftarrow a_k + c$ for $k = x, x+1, \dots, y$.
Operation 2: Multiply each number between $a_x$ and $a_y$ inclusive by c. In other words, do the transformation $a_k \leftarrow a_k \times c$ for $k = x, x+1, \dots, y$.
Operation 3: Change the numbers between $a_x$ and $a_y$ inclusive to c. In other words, do the transformation $a_k \leftarrow c$ for $k = x, x+1, \dots, y$.
Operation 4: Get the sum of the p-th powers of the numbers between $a_x$ and $a_y$ inclusive. In other words, compute $a_x^p + a_{x+1}^p + \dots + a_y^p$.
Yuanfang has no idea of how to do it. So he wants to ask you to help him.

Input

There are no more than 10 test cases.
For each case, the first line contains two numbers n and m, meaning that there are n integers and m operations. 1 <= n, m <= 100,000.
Each of the following m lines contains an operation. Operations 1 to 3 are in the format "1 x y c", "2 x y c" or "3 x y c". Operation 4 is in the format "4 x y p". (1 <= x <= y <= n, 1 <= c <= 10,000, 1 <= p <= 3)
The input ends with 0 0.

Output

For each operation 4, output a single integer in one line representing the result. The answer may be quite large. You just need to calculate the remainder of the answer when divided by 10007.

Sample Input

5 5
3 3 5 7
1 2 4 4
4 1 5 2
2 2 5 8
4 3 5 3
0 0

Sample Output

307
7489

解题思路

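裸线段树题目,需要处理的情况非常多。主要难点有两个:一是乘法、加法、改值三种操作的懒标记之间的冲突;二是平方和、立方和的维护。对于第二个问题,将求和的多项式展开即可:
由 $(x+a)^2 = x^2 + 2ax + a^2$ 推得 $Sum_2' = Sum_2 + 2a \cdot Sum_1 + a^2 \cdot len$;同理由 $(x+a)^3 = x^3 + 3ax^2 + 3a^2x + a^3$ 推得 $Sum_3' = Sum_3 + 3a \cdot Sum_2 + 3a^2 \cdot Sum_1 + a^3 \cdot len$,其中 $len$ 为区间大小。右边用的都是更新前的 $Sum_1$、$Sum_2$,所以要先更新 $Sum_3$,再更新 $Sum_2$,最后更新 $Sum_1$。区间乘 c 时 $Sum_p$ 乘以 $c^p$;区间改值为 c 时 $Sum_p = c^p \cdot len$。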
对于第一个问题,三种标记的优先级为 改值 > 乘法 > 加法:打上改值标记时重置乘法和加法标记;打上乘法标记时,把已有的加法标记也乘上 c(因为 $(x \cdot k + b) \cdot c = x \cdot kc + bc$)。
ps: 标记下传的时候也要按同样的规则更新子节点的其他标记。

#define _CRT_SECURE_NO_WARNINGS
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <string>
#include <map>
#include <stack>
#define M 10007
#define ll long long 
#define se second
#define fi first
#define pb push_back
#define INF 0x3f3f3f3f
#define de(x) cout << #x << " = "<< x << endl;
#define eps 1e-6

using namespace std;
// Size of the current test case: n elements, m operations.
int n, m;
// Segment tree arrays, indexed by node id (root = 1, children of t are 2t and 2t+1).
// sum[p][t]: sum of p-th powers (p = 1..3) over node t's range, modulo M;
//            row sum[0] is never read or written except by the memset in init().
// col[t]:    pending "assign" tag, 0 = none (input values c are always >= 1).
// mul[t]:    pending "multiply" tag, 1 = none.
// add[t]:    pending "add" tag, 0 = none.
ll sum[4][1000005], col[1000005], add[1000005], mul[1000005];
// NOTE(review): lol and tot are not referenced anywhere in the visible code.
int lol, tot;
void append(int t, int m, ll z)
{
    sum[3][t] = (sum[3][t] + (3 * z * (z * sum[1][t] + sum[2][t]) % M + ((z * z * z) % M)* m))% M;
    sum[2][t] = (sum[2][t] + (2 * z * sum[1][t]) + (z * z) * m %M) % M;
    sum[1][t] = (sum[1][t] + z * m) % M;
}

void multiply(int t, ll z)
{
    sum[3][t] = sum[3][t] * (z * z * z % M) % M;
    sum[2][t] = sum[2][t] * (z * z % M) % M;
    sum[1][t] = sum[1][t] * z % M;
}

void change(int t, int m, ll z)
{
    sum[1][t] = z * m % M;
    sum[2][t] = sum[1][t] * z % M;
    sum[3][t] = sum[2][t] * z % M;
}

void push_down(int t, int m)
{
    if (col[t])
    {
        col[t << 1] = col[t];
        col[t << 1 | 1] = col[t];
        mul[t << 1] = 1;
        mul[t << 1 | 1] = 1;
        add[t << 1] = 0;
        add[t << 1 | 1] = 0;
        change(t << 1, m - (m >> 1), col[t]);
        change(t << 1 | 1, m >> 1, col[t]);
        col[t] = 0;
    }
    if (mul[t] != 1)
    {
        mul[t << 1] = mul[t << 1] * mul[t] % M;
        mul[t << 1 | 1] = mul[t << 1 | 1] * mul[t] % M;
        if (add[t << 1]) add[t << 1] = (add[t << 1] * mul[t]) % M;
        if (add[t << 1 | 1]) add[t << 1 | 1] = (add[t << 1 | 1] * mul[t]) % M;
        multiply(t << 1, mul[t]);
        multiply(t << 1 | 1, mul[t]);
        mul[t] = 1;
    }
    if (add[t])
    {
        add[t << 1] = (add[t << 1] + add[t]) % M;
        add[t << 1 | 1] = (add[t << 1 | 1] + add[t]) % M;

        append(t << 1, m - (m >> 1), add[t]);
        append(t << 1 | 1, m >> 1, add[t]);
    
        add[t] = 0;
    }
}
 
// Range update "a_k += z" for every k in [l, r].
// [l1, r1] is the range covered by node t. After recursing, node t's power
// sums are rebuilt from its children.
void Add(int l, int r, int l1, int r1, int t, int z)
{
    if (l1 >= l && r1 <= r)
    {
        // Keep the lazy tag reduced modulo M. Otherwise repeated range adds on
        // the same node let add[t] grow to ~1e9, and append() overflows
        // long long computing z^3 when that tag is later pushed down.
        add[t] = (add[t] + z) % M;
        append(t, r1 - l1 + 1, z);
        return;
    }
    push_down(t, r1 - l1 + 1);
    const int mid = (l1 + r1) >> 1;
    if (l <= mid) Add(l, r, l1, mid, t << 1, z);
    if (r > mid) Add(l, r, mid + 1, r1, t << 1 | 1, z);
    for (int i = 1; i <= 3; i++)
        sum[i][t] = (sum[i][t << 1] + sum[i][t << 1 | 1]) % M;
}

// Range update "a_k *= z" for every k in [l, r]; [l1, r1] is node t's range.
void Mul(int l, int r, int l1, int r1, int t, int z)
{
    const bool covered = l <= l1 && r1 <= r;
    if (!covered)
    {
        push_down(t, r1 - l1 + 1);
        const int mid = (l1 + r1) >> 1;
        const int lc = t << 1;
        if (l <= mid) Mul(l, r, l1, mid, lc, z);
        if (mid < r) Mul(l, r, mid + 1, r1, lc | 1, z);
        for (int p = 1; p <= 3; ++p)
            sum[p][t] = (sum[p][lc] + sum[p][lc | 1]) % M;
        return;
    }
    // Fully covered: compose the multiply tag, scale any pending add tag
    // by the same factor, and rescale this node's power sums.
    mul[t] = mul[t] * z % M;
    if (add[t] != 0) add[t] = add[t] * z % M;
    multiply(t, z);
}

// Range assignment "a_k = z" for every k in [l, r]; [l1, r1] is node t's range.
void Color(int l, int r, int l1, int r1, int t, int z)
{
    if (l <= l1 && r1 <= r)
    {
        // An assignment supersedes any pending multiply/add on this node.
        col[t] = z;
        mul[t] = 1;
        add[t] = 0;
        change(t, r1 - l1 + 1, z);
        return;
    }
    push_down(t, r1 - l1 + 1);
    const int mid = (l1 + r1) >> 1;
    const int lc = t << 1;
    const int rc = lc | 1;
    if (l <= mid) Color(l, r, l1, mid, lc, z);
    if (r > mid) Color(l, r, mid + 1, r1, rc, z);
    for (int p = 1; p <= 3; ++p)
        sum[p][t] = (sum[p][lc] + sum[p][rc]) % M;
}

// Returns (sum over k in [l, r] of a_k^z) mod M, where z is the power (1..3)
// and [l1, r1] is the range covered by node t.
ll Get(int l, int r, int l1, int r1, int t, int z)
{
    if (l <= l1 && r1 <= r)
        return sum[z][t];
    push_down(t, r1 - l1 + 1);
    const int mid = (l1 + r1) >> 1;
    const ll fromLeft = (l <= mid) ? Get(l, r, l1, mid, t << 1, z) : 0;
    const ll fromRight = (r > mid) ? Get(l, r, mid + 1, r1, t << 1 | 1, z) : 0;
    return (fromLeft + fromRight) % M;
}

// Resets the segment tree for a new test case, then reads and executes the
// case's m operations, printing the answer to every type-4 query.
// (Despite its name, this function also runs the whole test case.)
void init()
{
    memset(sum, 0, sizeof(sum));
    memset(col, 0, sizeof(col));
    memset(add, 0, sizeof(add));
    // Multiply tag 1 means "nothing pending"; 5n covers every node id of a
    // segment tree built over [1, n].
    for (int i = 1; i <= 5 * n; i++) mul[i] = 1;

    while (m--)
    {
        int op, l, r, c;
        // Stop on truncated or malformed input instead of acting on
        // uninitialized values, which could index outside the tree arrays.
        if (scanf("%d%d%d%d", &op, &l, &r, &c) != 4) break;
        if (op == 1) Add(l, r, 1, n, 1, c);
        else if (op == 2) Mul(l, r, 1, n, 1, c);
        else if (op == 3) Color(l, r, 1, n, 1, c);
        else printf("%lld\n", Get(l, r, 1, n, 1, c));
    }
}

// Processes test cases until the terminating "0 0" line or end of input.
int main()
{
    for (;;)
    {
        if (scanf("%d%d", &n, &m) == EOF) break;
        if (n == 0 && m == 0) break;
        init();
    }
    return 0;
}
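/* Regression input: the only failing case found when cross-checking (对拍) against a brute-force solution on 1000 random tests.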
9 10
1 1 2 10
2 1 5 5
4 1 7 2
4 3 4 3
4 7 8 1
2 3 4 3
4 2 3 3
2 1 4 2
1 2 7 3
3 1 2 9
*/

猜你喜欢

转载自www.cnblogs.com/seast90/p/9382244.html