k连续子段乘积

https://ac.nowcoder.com/acm/contest/3005/C

题意:求长度为k的连续子段乘积。

解法1逆元:前缀乘积(不含0),记录前缀0的个数,并求不含0的长度为k的连续子段乘积中取最大的。

注意0没有逆元,所以在考虑递推方法时0要特殊处理。

#include <bits/stdc++.h>
#define ME(x , y) memset(x , y , sizeof(x))
#define SC scanf
#define rep(i ,j , n) for(int i = j ; i < n ; i ++)
#define red(i , n , j) for(int i = n-1 ; i >= j ; i--)
#define INF  0x3f3f3f3f
#define mod 998244353
#define PI acos(-1)
#define lson k<<1,l,mid
#define rson k<<1|1,mid+1,r
using namespace std;
typedef long long ll ;
const int MX = 1e5+9;
ll pre[200009];
ll a[200009];
int cnt0[200009];//前缀0的个数
ll quickpow(ll a , ll b){
    ll ans = 1 ;
    while(b){
        if(b&1) ans = ans%mod * a % mod;
        b >>= 1 ;
        a = a%mod * a %mod;
    }
    return ans;
}
 int main() {
 
    int n , k ;
    cin >> n >> k;
    pre[0] = 1   ;
    rep(i , 1 , n+1){
        scanf("%lld" , &a[i]);
    }
    rep(i , 1 , n+1){
        if(a[i] == 0){
            pre[i] = pre[i-1];
            cnt0[i] = cnt0[i-1]+1;
        }else{
            pre[i] = (pre[i-1]*a[i])%mod;
            cnt0[i] = cnt0[i-1];
        }
    }
 
    ll ans = 0;
    rep(i , k , n+1){
        int j = i-k+1;
        if(cnt0[i] == cnt0[j-1]){//长度k连续子段不含0
            ll temp = pre[i]*quickpow(pre[j-1] , mod-2)%mod;
            ans = max(ans , temp);
        }
    }
    cout << ans << endl;
    return 0 ;
 }

 解法2:线段树O(nlogn)求每一段连续k乘积

#include <bits/stdc++.h>
#define ME(x , y) memset(x , y , sizeof(x))
#define SC scanf
#define rep(i ,j , n) for(int i = j ; i < n ; i ++)
#define red(i , n , j) for(int i = n-1 ; i >= j ; i--)
#define INF  0x3f3f3f3f
#define mod 998244353
#define PI acos(-1)
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
using namespace std;
typedef long long ll ;
const int maxn = 2e5+9;
ll val[maxn<<2];

void build(int l , int r , int root){
    if(l == r){
        scanf("%lld" , &val[root]);
        return ;
    }
    int mid = (l + r)>>1;
    build(lson);
    build(rson);
    val[root] = val[root<<1]%mod * val[root<<1|1]%mod ;
}

ll query(int l , int r , int root , int L , int R){
    if(l >= L && r <= R){
        return val[root]%mod;
    }
    ll sum = 1 ;
    int mid = (l+r) >> 1;
    if(L <= mid)
        sum = sum%mod * query(lson , L , R)%mod;
    if(mid < R)
        sum = sum%mod * query(rson , L , R)%mod;
    return sum%mod ;
}

int main()
{
    int n , k ;
    cin >> n >> k;
    build(1 , n , 1);
    ll ans = 0 ;
    rep(i , k , n+1){
        ans = max(ans , query(1 , n , 1 , i-k+1 , i));
    }
    cout << ans << endl;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/nonames/p/12297546.html