小魂和他的数列(dp+树状数组优化)

链接:https://ac.nowcoder.com/acm/contest/3566/C
来源:牛客网


Sometimes, even if you know how something’s going to end, that doesn’t mean you can’t enjoy the ride.
有时候,即使你知道了故事的结局,也不代表你不可以享受它的过程。

小魂和他的数列

时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 131072K,其他语言262144K
64bit IO Format: %lld

题目描述

一天,小魂正和一个数列玩得不亦乐乎。
小魂的数列一共有n个元素,第i个数为Ai。
他发现,这个数列的一些子序列中的元素是严格递增的。
他想知道,这个数列一共有多少个长度为K的子序列是严格递增的。
请你帮帮他,答案对998244353取模。
对于100%的数据,1≤ n ≤ 500,000,2≤ K ≤ 10,1≤ Ai ≤ 10^9。
输入描述:
第一行包含两个整数n,K,表示数列元素的个数和子序列的长度。
第二行包含n个整数,表示小魂的数列。
输出描述:
一行一个整数,表示长度为K的严格递增子序列的个数对998244353取模的值。

示例1

输入
复制

5 3
2 3 3 5 1

输出

2

说明

两个子序列分别是2 3 3 5 1和2 3 3 5 1。

思路:

确定状态:
dp[i][j]:以i结尾长度为j的严格上升子序列的个数。
那么有状态转移方程:
d p [ i ] [ j ] = k = 1 i 1 ( a [ k ] < a [ i ] ) d p [ k ] [ j 1 ] dp[i][j] = \sum_{k=1}^{i-1} (a[k] < a[i]) * dp[k][j-1]
根据状态转移方程可以写出:

for(int i = 1; i <= n; i++)
{
    for(int k = 1; k < i; k++)
    {
        for(int j = 1; j <= K; j++)
        {
            dp[i][j] += (a[k ] < a[i]) * dp[k][j-1];
        }
    }
}

时间复杂度: O(k*n^2)
很明显时间复杂度不达标。
此时想优化:
上面第二层循环(k < i)就是求前缀和,K最大为10,
所以可以利用树状数组来优化,开K棵树状数组,
每个树状组数组存长度为j的严格上升子序列的个数。
即:

for(int i = 1; i <= n; i++)
{
    for(int j = 1; j <= K; j++)
    {
        dp[i][j] += query(id,pos);
    }
}

时间复杂度: O(k*nlogn)
此时已经可以在规定的复杂度里求出结果。

对于每一个数插入树状数组前的处理:

第一种方式是直接排序(特别注意排序规则),不去重,具体见下面的code1
第二种是离散化(a[i]比较大,可能出现很多重复的值)+去重处理,具体见下面的code2

AC代码:

个人觉得code1更好理解

code1:

 /*
 dp[i][j]:以i结尾长度为j的严格上升子序列的个数。
 状态转移方程:dp[i][j] +=(a[j]<a[i])*dp[i-1][j-1];
 状态转移方程有约束条件:a[j]<a[i] && j<i;
 对于没有去重,这里排完序(a[j]<a[i]--->从小到大排序),
 相同大小的(没有特判,程序会统一当作a[j]<a[i]处理),
 只有破坏j<i这个条件才不会统计错误(多统计),
 所以让相同大小的按index大的优先排序。
 */
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5+5;
const int mod = 998244353;
struct Node
{
    int x;
    int pos;
    bool operator<(const Node& n)const
    {
        if(x == n.x)
        {
            return pos > n.pos;
        }
        return x < n.x;
    }
} a[N];
int n,k;
int sum[12][N];
inline int lowbit(int x)
{
    return x&-x;
}
void add(int id,int pos,int v)
{
    while(pos <= n)
    {
        sum[id][pos] = (sum[id][pos] + v) % mod;
        pos += lowbit(pos);
    }
}
int query(int id,int pos)
{
    int res = 0;
    while(pos)
    {
        res = (res + sum[id][pos]) % mod;
        pos -= lowbit(pos);
    }
    return res;
}
int main()
{
    scanf("%d%d",&n,&k);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d",&a[i].x);
        a[i].pos = i;
    }
    sort(a+1,a+n+1);
    int ans = 0;
    //按输入的顺序插入树状数组
    for(int i = 1; i <= n; i++)
    {
        add(1,a[i].pos,1);
        for(int j = 2; j <= k; j++)
        {
            if(j > a[i].pos) break;//序列长度达不到j
            int v = query(j-1,a[i].pos-1);
            add(j,a[i].pos,v);
            if(j == k) ans = (ans + v) % mod;
        }
    }
    printf("%d\n",ans);
    return 0;
}

code2:


/*
如果去重(sort+unique+erase),
就要知道插入那个数它排第几
(low_bound(aim_arr_begin,aim_arr_end,x):
 返回im_arr中第一个大于等于x的地址)
*/

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5+5;
const int mod = 998244353;
int n,k;
int a[N];
vector<int>vec;
int sum[12][N];
inline int lowbit(int x)
{
    return x&-x;
}
void add(int id,int pos,int v)
{
    while(pos <= n)
    {
        sum[id][pos] = (sum[id][pos] + v) % mod;
        pos += lowbit(pos);
    }
}
int query(int id,int pos)
{
    int res = 0;
    while(pos)
    {
        res = (res + sum[id][pos]) % mod;
        pos -= lowbit(pos);
    }
    return res;
}
int main()
{

    scanf("%d%d",&n,&k);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d",&a[i]);
        vec.push_back(a[i]);
    }
    //离散化处理
    sort(vec.begin(),vec.end());
    vec.erase(unique(vec.begin(),vec.end()),vec.end());
    //确定每个数的大小名次
    for(int i = 1; i <= n; i++)
    {
        //这里要注意:+1,要不然下面(a[i]-1)会数组越界,而且树状数组也是从1开始
        a[i] = lower_bound(vec.begin(),vec.end(),a[i])-vec.begin()+1;
    }
    int ans = 0;
    //这里按每个数的名次插入树状数组即可求前缀和
    for(int i = 1; i <= n; i++)
    {
        add(1,a[i],1);
        for(int j = 2; j <= k; j++)
        {
            int v = query(j-1,a[i]-1);
            add(j,a[i],v);
            if(j == k) ans = (ans + v) % mod;
        }
    }
    printf("%d\n",ans);
    return 0;
}

发布了301 篇原创文章 · 获赞 38 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/tb_youth/article/details/104333201
今日推荐