【XSY2429】【BZOJ4361】isn

\(Description\)

题目描述

给出一个长度为\(n\)的序列\(A(A_{1},A_{2}...A_{n})\)。如果序列\(A\)不是非降的,你必须从中删去一个数,

这一操作,直到\(A\)非降为止。求有多少种不同的操作方案,答案模\(10^9+7\)


\(Input\)

第一行一个整数\(n\)

接下来一行\(n\)个整数,描述\(A\)


\(Output\)

一行一个整数,描述答案。


\(Sample\) \(Input\)

4
1 7 5 3


\(Sample\) \(Output\)

18


\(Hint\)

\(1<=N<=2000\)

\(a_{i}\)没超\(long\) \(long\)就对了


\(Source\)

练习题 树3-树状数组


思路

我们看到这道题:求有多少种不同的操作方案,答案模\(10^9+7\)

我们首先设\(sum[i]\),表示长度为\(i\)的非降序列的全部数量

我们考虑对于每一个\(sum[i]\)

因为每一个长度为\(i\)的非降序列,它都要经过删掉另外\(n-i\)个点达到,另外\(n-i\)个点也是无序的,所以长度为所贡献答案为\(sum[i]\)

接着,我们需要考虑一个情况:对于每个长度为\(i\)的非降序列,可能是通过长度为\(i+1\)的非降序列减掉一个数得到的,但是题目要求这一操作,直到\(A\)非降为止,所以这种情况是不合法的,我们需要把它减掉

那减掉多少呢

我们考虑长度为\(i+1\)的非降序列的贡献答案为\(sum[i+1]*(n-i-1)!\),然后在这\(i+1\)个数中选择一个数删掉得到长度为\(i\)的序列,所以一共需要减掉\(sum[i+1]*(n-i-1)!*(i+1)\)

于是答案\(ans=\sum_{i=1}^nsum[i]*jc[n-i]-sum[i+1]*jc[n-i-1]*(i+1)\)

上式的\(jc[i]\)表示\(i!\),可以通过\(O(n)\)预处理出来

接着,\(sum[i]\)可以通过十分简单的\(dp\)处理出来,设\(dp[i][j]\)表示以第\(i\)个数结尾,长度为\(j\)的方案数

\(sum[j]=\sum_{i=1}^ndp[i][j]\)

然后,我们处理\(dp\)数组的时候,可以通过树状数组优化处理\(dp\)的过程


代码

#include<bits/stdc++.h>
#define lowbit(x) (x&(-x))
#define int long long
using namespace std;
const int N=2010,mod=1e9+7;
int n;
int p[N],a[N];
int dp[N][N];
int c[N][N];
int sum[N];
int jc[N];
void add(int len,int x,int y)
{
    for(;x<=n;x+=lowbit(x))c[len][x]=(c[len][x]+y)%mod;
}
int ask(int len,int x)
{
    int out=0ll;
    for(;x;x-=lowbit(x))out=(out+c[len][x])%mod;
    return out;
}
signed main()
{
    scanf("%lld",&n);
    //离散化处理
    for(int i=1;i<=n;i++)scanf("%lld",&p[i]),a[i]=p[i];
    sort(p+1,p+n+1);
    int len=unique(p+1,p+n+1)-p-1;
    add(0,1,1);
    //树状数组优化dp
    for(int i=1;i<=n;i++)
    {
        a[i]=lower_bound(p+1,p+len+1,a[i])-p;
        for(int j=i;j>=1;j--)
        {
            dp[i][j]=ask(j-1,a[i])%mod;
            add(j,a[i],dp[i][j]);
        }
    }
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            sum[j]=(sum[j]+dp[i][j])%mod;
    //处理阶乘
    jc[0]=1;//注意这里一定要设jc[0]=1,要不然当下面统计答案是当i=n-1时,答案会为0
    for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mod;
    int ans=(ans+sum[n])%mod;//当i=n时,不能访问到jc[n-i-1],所以预先处理
    for(int i=1;i<n;i++)ans=(ans+sum[i]*jc[n-i]%mod-sum[i+1]*jc[n-i-1]%mod*(i+1)%mod+mod)%mod;
    printf("%lld",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ShuraEye/p/11619443.html