牛客:子段乘积 (逆元or线段树)

传送门:
在这里插入图片描述

比赛的时候想到逆元了,就用逆元过了,但是后来又想到线段树写法,感觉线段树比逆元简单好多

题解:逆元法

根据费马小定理:a与mod互质,则 (1/a)%mod=a^(mod-2)%mod,该题刚好符合逆元的条件,mod=998244353 是个质数,所以利用逆元枚举区间值即可,时间复杂度O(n)

AC代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod=998244353;
const ll maxn=2e5+5;
ll a[maxn],l[maxn],r[maxn];
ll ksm(ll a,ll b,ll mod)
{
    ll res=1;
    while(b)
    {
        if(b&1)
            res=(res*a)%mod;
        b>>=1;
        a=(a*a)%mod;
    }
    return res%mod;
}
ll ksc(ll a,ll b,ll mod)
{
    ll res=0;
    while(b)
    {
        if(b&1)
            res=(res+a)%mod;
        b>>=1;
        a=(a+a)%mod;
    }
    return res%mod;
}
int main()
{
    ll n,k,pos=1;
    l[pos]=1;
    r[pos]=1;
    //cout<<ksc(1,0,mod);
    cin>>n>>k;
    for(int i=1; i<=n; i++)
    {
        cin>>a[i];
        if(a[i]==0)
        {
            r[pos]=i-1;
            l[++pos]=i+1;
        }
    }
 
    r[pos]=n;
    ll maxx=0,fla;
    for(int i=1; i<=pos; i++)
    {
        ll res=1,fla=0;
        for(int j=l[i]; j<=l[i]+k-1&&l[i]+k-1<=r[i];j++)
        {
            res=ksc(res,a[j],mod);
            fla=1;
        }
        if(fla)
        maxx=max(res,maxx);
        for(int j=l[i]+k; j<=r[i]; j++)
        {
            res=ksc(res,a[j],mod);
            res=ksc(res,ksm(a[j-k],mod-2,mod),mod);
            maxx=max(res,maxx);
        }
    }
    cout<<maxx<<endl;
 
}

题解:线段树解法

区间树自然不是浪得虚名的,我们只需要维护区间乘积,查询区间,利用for循环枚举所有符合条件的区间即可,时间复杂度O(nlogn)

AC代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll maxn=2e5+5;
const ll mod=998244353;
ll a[maxn];
struct node
{
    ll l,r,val;
}tr[maxn*4];
void pushup(ll k)
{
    tr[k].val=(tr[k<<1].val*tr[k<<1|1].val)%mod;
}
void build(ll k,ll l,ll r)
{
    tr[k].l=l,tr[k].r=r;
    if(l==r)
    {
        tr[k].val=a[l];
        return ;
    }
    ll mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    pushup(k);
}
ll ask(ll k,ll l,ll r)
{
    if(tr[k].l>=l&&tr[k].r<=r)
    {
        return tr[k].val;
    }
    ll mid=tr[k].l+tr[k].r>>1;
    if(mid>=r)
    {
        return ask(k<<1,l,r);
    }
    else if(mid<l)
    {
        return ask(k<<1|1,l,r);
    }
    else
    {
        return (ask(k<<1,l,mid)*ask(k<<1|1,mid+1,r))%mod;
    }
}
int main()
{
    ll n,k;
    cin>>n>>k;
    for(ll i=1;i<=n;i++)
        cin>>a[i];
    build(1,1,n);
    ll res=0;
    for(ll i=1;i+k-1<=n;i++)
    {
        res=max(res,ask(1,i,i+k-1));
    }
    cout<<res<<endl;
 
}
发布了222 篇原创文章 · 获赞 16 · 访问量 9734

猜你喜欢

转载自blog.csdn.net/yangzijiangac/article/details/104267726