E. Monotonic Renumeration
time limit per test
2 seconds
memory limit per test
256 megabytes
input
standard input
output
standard output
You are given an array aa consisting of nn integers. Let's denote monotonic renumeration of array aa as an array bb consisting of nn integers such that all of the following conditions are met:
- b1=0b1=0;
- for every pair of indices ii and jj such that 1≤i,j≤n1≤i,j≤n, if ai=ajai=aj, then bi=bjbi=bj (note that if ai≠ajai≠aj, it is still possible that bi=bjbi=bj);
- for every index i∈[1,n−1]i∈[1,n−1] either bi=bi+1bi=bi+1 or bi+1=bi+1bi+1=bi+1.
For example, if a=[1,2,1,2,3]a=[1,2,1,2,3], then two possible monotonic renumerations of aa are b=[0,0,0,0,0]b=[0,0,0,0,0] and b=[0,0,0,0,1]b=[0,0,0,0,1].
Your task is to calculate the number of different monotonic renumerations of aa. The answer may be large, so print it modulo 998244353998244353.
Input
The first line contains one integer nn (2≤n≤2⋅1052≤n≤2⋅105) — the number of elements in aa.
The second line contains nn integers a1,a2,…,ana1,a2,…,an (1≤ai≤1091≤ai≤109).
Output
Print one integer — the number of different monotonic renumerations of aa, taken modulo 998244353998244353.
Examples
input
Copy
5 1 2 1 2 3
output
Copy
2
input
Copy
2 100 1
output
Copy
2
input
Copy
4 1 3 3 7
output
Copy
4
简单分析:b[0]=0 b[i]=b[i+1] or b[i]+1=b[i+1] 很容易想到b是一个非递减序列,首先假如没有“若a[i]=a[j]则b[i]=b[j]”的条件,
那么ans=2^(n-1) 其中n为序列长度。
现在考虑“若a[i]=a[j]则b[i]=b[j]”
给定序列 1 2 1 2 显然只有一种情况 0 0 0 0
很容易想到,两个相同数字为一个区间,区间中所有元素都相同,那么就变成了一个区间合并的问题。
这里我用的差分,应该还有其他更好的方法其实是我不会
用差分计算每个点被覆盖次数,例如1 2 1 2 ,覆盖就是1 2 2 1
那么每个区间的边界一定是1,从1到n扫一遍,记录区间端点,最后求2^(sum/2-1)即可。
注意一下细节即可。
#include "bits/stdc++.h"
using namespace std;
const int inf = 0x3f3f3f3f;
const int mod = 998244353 ;
long long qk(long long a,long long n)
{
long long ans=1;
while(n)
{
if(n&1)ans=ans*a%mod;
n>>=1;
a=a*a%mod;
}
return ans;
}
struct node
{
int w,id;
bool friend operator < (node a,node b)
{
if(a.w==b.w)return a.id<b.id;
else return a.w<b.w;
}
}a[200004];
int b[200004];
int c[200004];
int t[200004];
int main()
{
int n;
scanf("%d",&n);
memset(t,0, sizeof(t));
for (int i = 1; i <= n; ++i) {
scanf("%d",&a[i].w);
a[i].id=i;
b[i]=a[i].w;
}
sort(a+1,a+1+n);
for (int i = 1; i <= n; ++i) {
c[i]=a[i].w;
}
int l=1,r=1;
while(l<=n)
{
r=upper_bound(c+l,c+n+1,c[l])-c;
r--;
t[a[l].id]++;
t[a[r].id+1]--;
l=r+1;
}
int sum=0;
int ans=0;
for (int i = 1; i <= n; ++i) {
sum+=t[i];
t[i]=sum;
}
t[n+1]=1;
t[0]=1;
b[0]=-1;
b[n+1]=-1;
for (int i = 1; i <= n; ++i) {
if(t[i]==1){//端点的第一个条件
if(b[i]!=b[i+1]&&t[i+1]==1)ans++;//后端点
else if(b[i]!=b[i-1]&&t[i-1]==1)ans++;//前端点
}
if(t[i]==1&&t[i+1]==1&&t[i-1]==1&&b[i]!=b[i+1]&&b[i]!=b[i-1])ans++;//长度为1的区间
}
printf("%d\n",qk(2,ans/2-1));
}