题解 CF718C 【Sasha and Array】

题目链接

不得不说这题是线段树维护矩阵的一道好题,此外推荐\(LibreOJ\)上的一道好题「THUSCH 2017」大魔法师 也可以用线段树维护矩阵

Solution [CF718C] Sasha and Array

题目大意:请你维护一个数列,支持一下两种操作:

\(1\).将区间\([l,r]\)内的数加上\(x\)

\(2\).求\(\sum_{i =l}^{r}f(a_{i})\),其中\(f(x)\)表示斐波那契数列的第\(x\)

做法:既然数据范围已经达到了\(n,m \leq 1e5\)级别,那么我们就得考虑\(O(nlogn)\)级别的算法了

回想一下我们是如何以\(O(logn)\)的优秀复杂度求斐波那契数列的.矩阵乘法对吧?

我们设一个矩阵\(orgin = \begin{bmatrix} 0&1\end{bmatrix}\),以及一个转移矩阵\(w = \begin{bmatrix} 1&1\\1&0 \end{bmatrix}\)

那么很显然\(f(x) = orgin\;*\;w^x\) 即对于表示\(f(x)\)的矩阵乘上\(w^k\)就求得了\(f(x + k)\)

很显然我们已经解决了操作\(1\)

对于操作\(2\),我们只需要维护区间矩阵之和:正确性显而易见,矩阵满足结合律.所以有:

\(\sum_{i = l}^rw*a_{i} = w * \sum_{i = l}^{r}\)

所以这题就是一个区间乘区间求和的模板题了,注意开\(long\;long\)就好

#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
const int maxn = 100100;
const int mod = 1e9 + 7;
struct matrix{//矩阵模板
    ll val[4][4];
    int x,y;
    matrix operator * (const matrix &rhs)const{
        matrix ret;
        ret.x = x;
        ret.y = rhs.y;
        for(int i = 1;i <= ret.x;i++)
            for(int j = 1;j <= ret.y;j++){
                ret.val[i][j] = 0;
                for(int k = 1;k <= y;k++)
                    ret.val[i][j] += val[i][k] * rhs.val[k][j],ret.val[i][j] %= mod;
            }
        return ret;
    }
    matrix operator + (const matrix &rhs)const{
        matrix ret;
        ret.x = rhs.x;
        ret.y = rhs.y;
        for(int i = 1;i <= ret.x;i++)
            for(int j = 1;j <= ret.y;j++)
                ret.val[i][j] = (val[i][j] + rhs.val[i][j]) % mod;
        return ret;
    }
    inline void print(){
        for(int i = 1;i <= x;i++){
            for(int j = 1;j <= y;j++)
                printf("%lld ",val[i][j]);
            printf("\n");
        }
        printf("\n");
    }
}orgin,w,unit;
inline void init(){//初始化,orgin上文已经提到,w为转移矩阵,unit为单位矩阵
    orgin.x = 1;
    orgin.y = 2;
    orgin.val[1][1] = 0;
    orgin.val[1][2] = 1;
    w.x = w.y = 2;
    w.val[1][1] = 1;
    w.val[1][2] = 1;
    w.val[2][1] = 1;
    w.val[2][2] = 0;
    unit.x = unit.y = 2;
    unit.val[1][1] = 1;
    unit.val[2][2] = 1;
}
namespace ST{//线段树模板
    struct Node{
        int l,r;
        matrix mark,val;
    }tree[maxn << 2];
    #define lson (root << 1)
    #define rson (root << 1 | 1)
    inline void maintain(int root){
        tree[root].val = tree[lson].val + tree[rson].val;
    }
    inline void pushdown(int root){
        tree[lson].val = tree[lson].val * tree[root].mark;
        tree[lson].mark = tree[lson].mark * tree[root].mark;
        tree[rson].val = tree[rson].val * tree[root].mark;
        tree[rson].mark = tree[rson].mark * tree[root].mark;
        tree[root].mark = unit;
    }
    inline void build(int a,int b,int root = 1){
        tree[root].l = a;
        tree[root].r = b;
        tree[root].mark = unit;
        if(a == b){
            tree[root].val = orgin;
            return;
        }
        int mid = (tree[root].l + tree[root].r) >> 1;
        build(a,mid,lson);
        build(mid + 1,b,rson);
        maintain(root);
    }
    inline ll query(int a,int b,int root = 1){
        if(a <= tree[root].l && b >= tree[root].r)return tree[root].val.val[1][1];
        pushdown(root);
        ll ret = 0;
        int mid = (tree[root].l + tree[root].r) >> 1;
        if(a <= mid)ret += query(a,b,lson),ret %= mod;
        if(b >= mid + 1)ret += query(a,b,rson),ret %= mod;
        return ret;
    }
    inline void modify(int a,int b,const matrix val,int root = 1){
        if(a <= tree[root].l && b >= tree[root].r){
            tree[root].val = tree[root].val * val;
            tree[root].mark = tree[root].mark * val;
            return;
        }
        pushdown(root);
        int mid = (tree[root].l + tree[root].r) >> 1;
        if(a <= mid)modify(a,b,val,lson);
        if(b >= mid + 1)modify(a,b,val,rson);
        maintain(root);
    }
    #undef lson
    #undef rson
}
inline matrix power(const matrix &a,int b){//快速幂
    matrix ret = unit,base = a;
    while(b){
        if(b & 1)ret = ret * base;
        base = base * base;
        b >>= 1;
    }
    return ret;
}
int n,m;
int main(){
#ifdef LOCAL
    freopen("fafa.in","r",stdin);
#endif
    init();
    scanf("%d %d",&n,&m);
    ST::build(1,n);
    for(int x,i = 1;i <= n;i++)//实在没想到啥更好的初始化办法了,我还是tcl
        scanf("%d",&x),ST::modify(i,i,power(w,x));
    for(int i = 1;i <= m;i++){
        int opt,l,r,x;
        scanf("%d %d %d",&opt,&l,&r);
        if(opt == 1)scanf("%d",&x),ST::modify(l,r,power(w,x));
        else printf("%lld\n",ST::query(l,r));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/colazcy/p/11514996.html
今日推荐