HDU 6348 序列计数 (树状数组 + DP)

序列计数

Time Limit: 4500/4000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 348    Accepted Submission(s): 117


Problem Description
度度熊了解到, 1,2,…,n 的排列一共有 n!=n×(n1)××1 个。现在度度熊从所有排列中等概率随机选出一个排列 p1,p2,…,pn,你需要对 k=1,2,3,…,n 分别求出长度为 k 的上升子序列个数,也就是计算满足 1a1 < a2 < … < ak n 且 pa1 <pa2< … < pak 的 k 元组 (a1,a2,…,ak) 的个数。

由于结果可能很大,同时也是为了 ruin the legend, 你只需要输出结果对 1000000007(=109+7) 取模后的值。
 
Input
第一行包含一个整数  T,表示有 T 组测试数据。

接下来依次描述 T 组测试数据。对于每组测试数据:

第一行包含一个整数 n,表示排列的长度。

第二行包含 n 个整数 p1,p2, …, pn,表示排列的 n 个数。

保证 1T1001n104T 组测试数据的 n 之和 105p1,p2,…,pn 是 1,2,…,n 的一个排列。

除了样例,你可以认为给定的排列是从所有 1,2,…,n 的排列中等概率随机选出的。
 
Output
对于每组测试数据,输出一行信息 "Case #x:  c1 c2 ... cn"(不含引号),其中 x 表示这是第 x 组测试数据,ci 表示长度为 i 的上升子序列个数对 1000000007(=109+7) 取模后的值,相邻的两个数中间用一个空格隔开,行末不要有多余空格。
 
Sample Input
2 4 1 2 3 4 4 1 3 2 4
 
Sample Output
Case #1: 4 6 4 1 Case #2: 4 5 2 0
 
Source
 
Recommend
chendu
 
Statistic |  Submit |  Discuss |  Note

析:当时这个题目,我的第一感觉就是 LIS 加组合数,然后就是枚举长度为 i 的上升子序列有多少个,然后可以再枚举每个数,计算以第 j 个数为结束的上升子序列有多少个,这个是可以递推的,然后计算前面有多少个数比第 j 个数小,并且长度为 i - 1 的上升子序列有多少个,这个复杂度是O(n^3),怎么可能过呢,但是题目说了,这个排序是随机给的,虽然我不知道 LIS 最长是多少,但是肯定不会很大,因为我们平时做的题目时间要算最坏的是因为,后台基本是会有最坏的数据的,毕竟出题人要卡你时间么,这个题目说了是随机的,所以这个时间复杂度可能是在O(n^(5/2)) 左右吧(猜的),这样还是过不了的,但是还可以进行优化,在求长度为 i - 1 的上升子序列有多少个的时候,可以使用树状数组来进行优化,当然其他数据结构也是可以啦,现在复杂度应该就是O(n(3/2)*log(n)),最坏的话 1e4 * 1e2 * 14 左右,大约 1e7 ,时间上差不多。交上去一遍就过了。。。

代码如下:

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstdio>
#include <string>
#include <cstdlib>
#include <cmath>
#include <iostream>
#include <cstring>
#include <set>
#include <queue>
#include <algorithm>
#include <vector>
#include <map>
#include <cctype>
#include <cmath>
#include <stack>
#include <sstream>
#include <list>
#include <assert.h>
#include <bitset>
#include <numeric>
#define debug() puts("++++")
#define gcd(a, b) __gcd(a, b)
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define fi first
#define se second
#define pb push_back
#define sqr(x) ((x)*(x))
#define ms(a,b) memset(a, b, sizeof a)
#define sz size()
#define be begin()
#define ed end()
#define pu push_up
#define pd push_down
#define cl clear()
#define lowbit(x) -x&x
//#define all 1,n,1
#define FOR(i,n,x)  for(int i = (x); i < (n); ++i)
#define freopenr freopen("in.in", "r", stdin)
#define freopenw freopen("out.out", "w", stdout)
using namespace std;
 
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> P;
const int INF = 0x3f3f3f3f;
const LL LNF = 1e17;
const double inf = 1e20;
const double PI = acos(-1.0);
const double eps = 1e-8;
const int maxn = 1e4 + 20;
const int maxm = 1e6 + 10;
const int mod = 1000000007;
const int dr[] = {-1, 1, 0, 0, 1, 1, -1, -1};
const int dc[] = {0, 0, 1, -1, 1, -1, 1, -1};
const char *de[] = {"0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111"};
int n, m;
const int mon[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
const int monn[] = {0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
inline bool is_in(int r, int c) {
  return r >= 0 && r < n && c >= 0 && c < m;
}
inline int readInt(){ int x;  scanf("%d", &x);  return x; }

int sum[2][maxn];

void add(int i, int x, LL c){
  while(x <= n){
    sum[i][x] += c;
    if(sum[i][x] >= mod)  sum[i][x] -= mod;
    x += lowbit(x);
  }
}

int query(int i, int x){
  int ans = 0;
  while(x){
    ans += sum[i][x];
    if(ans >= mod)  ans -= mod;
    x -= lowbit(x);
  }
  return ans;
}

int dp[maxn], a[maxn];


int main(){
  int T;  cin >> T;
  for(int kase = 1; kase <= T; ++kase){
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i){
      scanf("%d", a + i);
      dp[i] = 1;
    }
    printf("Case #%d:", kase);
    int cur = 0;
    printf(" %d", n);
    cur = 1;
    for(int i = 2; i <= n; ++i, cur ^= 1){
      ms(sum[cur^1], 0);
      int ans = 0;
      for(int j = 1; j <= n; ++j){
        int tmp = query(cur^1, a[j]);
        add(cur^1, a[j], dp[j]);
        dp[j] = tmp;
        ans += tmp;
        if(ans >= mod)  ans -= mod;
      }
      if(ans == 0){
        for(int j = i; j <= n; ++j)  printf(" 0");
        break;
      }
      printf(" %d", ans);
    }
    printf("\n");
  }
  return 0;
}

  

猜你喜欢

转载自www.cnblogs.com/dwtfukgv/p/9483074.html