题目链接:小魂和他的数列
题意
给你一个长度为n的数列a,问数列一共有多少个长度为K的子序列是严格递增的。
题解
题目有个关键信息:1≤n≤500000,1≤k≤10。
看到k这么小我首先想到的是能否用dp去解决,不出所料,感觉还真可以。
定义dp[i][j] :以i结尾,长度为j严格递增的子序列有多少种。
状态转移: d p [ a [ k ] ] [ j ] = ∑ i = 1 i < a [ k ] d p [ i ] [ j − 1 ] {dp[a[k]][j]=\sum_{i=1}^{i<a[k]}dp[i][j-1]} dp[a[k]][j]=∑i=1i<a[k]dp[i][j−1]
答案 : ∑ i = 1 n d p [ i ] [ k ] {\sum_{i=1}^{n}dp[i][k]} ∑i=1ndp[i][k]
由于题目中a[i]≤1e9,dp数组无法开那么大,所以需要离散化。但还有一个问题转移方程的时间复杂度太大,需化简。看到前缀和,我们可以想到用树状数组去化简。
我们可以维护k个树状数组bit,bit[i]:长度为i的严格递增的序列。
很容易我们可以得出: d p [ a [ k ] ] [ j ] = ∑ i = 1 i < a [ k ] d p [ i ] [ j − 1 ] = b i t [ j − 1 ] . s u m ( a [ k ] − 1 ) {dp[a[k]][j]=\sum_{i=1}^{i<a[k]}dp[i][j-1]=bit[j-1].sum(a[k]-1)} dp[a[k]][j]=∑i=1i<a[k]dp[i][j−1]=bit[j−1].sum(a[k]−1)
代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<bitset>
#include<cassert>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<deque>
#include<iomanip>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
using namespace std;
//extern "C"{void *__dso_handle=0;}
typedef long long ll;
typedef long double ld;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define lowbit(x) x&-x
const double PI=acos(-1.0);
const double eps=1e-6;
const ll mod=998244353;
const int inf=0x3f3f3f3f;
const int maxn=5e5+10;
const int maxm=100+10;
#define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
int n,k,a[maxn],dp[maxn][15],book[maxn],b[maxn];
int rem[maxn];
struct BIT{
int c[maxn];
void add(int x,int v)
{
while (x<=n) {
c[x]=(c[x]+v)%mod;
x+=lowbit(x);
}
}
int sum(int x)
{
int ans=0;
while(x>0)
{
ans=(ans+c[x])%mod;
x-=lowbit(x);
}
return ans;
}
}bit[15];
int main()
{
scanf("%d%d",&n,&k);
for(int i=0;i<=k;i++) memset(bit[i].c, 0, sizeof(bit[i].c));
for(int i=1;i<=n;i++) {
scanf("%d",&a[i]); b[i]=a[i]; }
sort(b+1,b+1+n);
int len=unique(b+1, b+1+n)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1, b+1+len, a[i])-b;
for(int i=1;i<=n;i++)
{
book[a[i]]++;
dp[a[i]][1]=book[a[i]];
bit[1].add(a[i], 1);
for(int j=2;j<=k;j++)
{
dp[a[i]][j]=(dp[a[i]][j]+bit[j-1].sum(a[i]-1))%mod;
}
for(int j=2;j<=k;j++)
{
bit[j].add(a[i],bit[j-1].sum(a[i]-1));
}
}
int ans=0;
for(int i=1;i<=n;i++) ans=(ans+dp[i][k])%mod;
printf("%d\n",ans);
}