For a set of sequences of integers{a1,a2,a3,…an}, we define a sequence{ai1,ai2,ai3…aik}in which 1<=i1<i2<i3<…<ik<=n, as the sub-sequence of {a1,a2,a3,…an}. It is quite obvious that a sequence with the length n has 2^n sub-sequences. And for a sub-sequence{ai1,ai2,ai3…aik},if it matches the following qualities: k >= 2, and the neighboring 2 elements have the difference not larger than d, it will be defined as a Perfect Sub-sequence. Now given an integer sequence, calculate the number of its perfect sub-sequence.
Input
Multiple test cases The first line will contain 2 integers n, d(2<=n<=100000,1<=d=<=10000000) The second line n integers, representing the suquence
Output
The number of Perfect Sub-sequences mod 9901
Sample Input
4 2
1 3 7 5
Sample Output
4
题意:
求所有长度至少为2的子序列,满足子序列相邻两个元素的差值不大于d
思路:
与Crazy Thairs POJ - 3378 这道题很像,这道题求的是5元上升组的个数。本题中加了一个相邻元素不大于d的限制,而且求的是大于1元上升组的和。
那道题维护了5个树状数组,分别表示5元上升序列的长度。
本题只需要维护2个就可以了,一个是一元,一个是大于一元。
转移的条件就是abs(a[i] - a[j]) ≤ d。再用二分找出j的最大值和最小值。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 7;
const int mod = 9901;
int a[maxn],b[maxn],c[3][maxn];
int n,d;
struct Node
{
int id,v;
bool operator < (const Node &rhs)const
{
return v < rhs.v;
}
}nodes[maxn];
int bin_l(int x)
{
int l = 1,r = n,val = b[x];
int ans = 0;
while(l <= r)
{
int mid = (l + r) >> 1;
if(val - nodes[mid].v <= d)
{
ans = mid;
r = mid - 1;
}
else
{
l = mid + 1;
}
}
return a[nodes[ans].id];
}
int bin_r(int x)
{
int l = 1,r = n,val = b[x];
int ans = 0;
while(l <= r)
{
int mid = (l + r) >> 1;
if(nodes[mid].v - val <= d)
{
ans = mid;
l = mid + 1;
}
else
{
r = mid - 1;
}
}
return a[nodes[ans].id];
}
void add(int num,int x,int v)
{
while(x <= n)
{
c[num][x] += v;
c[num][x] %= mod;
x += x & (-x);
}
}
int query(int num,int x)
{
int res = 0;
while(x)
{
res += c[num][x];
res %= mod;
x -= x & (-x);
}
return res;
}
int main()
{
while(~scanf("%d%d",&n,&d))
{
for(int i = 1;i <= n;i++)
{
scanf("%d",&b[i]);
nodes[i].id = i;nodes[i].v = b[i];
}
sort(nodes + 1,nodes + 1 + n);
int cnt = 0;
a[nodes[1].id] = ++cnt;
for(int i = 2;i <= n;i++)
{
if(nodes[i].v == nodes[i - 1].v)a[nodes[i].id] = cnt;
else a[nodes[i].id] = ++cnt;
}
memset(c,0,sizeof(c));
for(int i = 1;i <= n;i++)
{
int l = bin_l(i);
int r = bin_r(i);
int tmp = query(1,r) - query(1,l - 1) + query(2,r) - query(2,l - 1);
tmp = (tmp + mod) % mod;
add(1,a[i],1);add(2,a[i],tmp);
}
int ans = query(2,n);
printf("%d\n",ans);
}
return 0;
}