启发式分裂
给定
个数,求满足某种条件的点对数目或最大权值,而这个最大权值与点对
的区间
的区间最大
最小值有关。
那么这时就可以考虑分治,对于区间
,找到最小
大值所在位置
然后处理横跨最小/大值所在位置的点对,再递归处理子区间
对于一个区间,找到最大/最小值的位置
可以用
预处理
然后枚举
与
的较小区间,查询处理出答案即可
例题一:多校第十场 1011 Make Rounddog Happy
大意:一个好的区间定义为
并且区间内所有数互不相同,问有多少个好的区间?
题解:根据最大值进行启发式分治,
预处理区间最大值位置
为以
往左满足数字互不相同的的最远位置,
为往右,用双指针
预处理即可
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
using pii = pair <int,int>;
const int maxn = 3e5 + 5;
int T, n, k, a[maxn];
int L[maxn], R[maxn], vis[maxn];
ll ans;
int dp_max[maxn][20], pos[maxn][20];
void ST(int n, int d[]) {
for (int i=1; i<=n; i++) {
dp_max[i][0] = d[i];
pos[i][0] = i;
}
for (int j=1; (1<<j) <= n; j++) {
for (int i=1; i+(1<<j)-1 <= n; i++) {
if(dp_max[i][j-1] >= dp_max[i + (1<<(j-1))][j-1]) pos[i][j] = pos[i][j-1];
else pos[i][j] = pos[i + (1<<(j-1))][j-1];
dp_max[i][j] = max(dp_max[i][j-1], dp_max[i + (1<<(j-1))][j-1]);
}
}
}
int RMQ_max(int l, int r) {
int k = 0;
while ((1<<(k+1)) <= r-l+1) k++;
if(dp_max[l][k] >= dp_max[r - (1<<k)+1][k]) return pos[l][k];
else return pos[r - (1<<k)+1][k];
}
void solve(int l, int r) {
if(l > r) return;
int mid = RMQ_max(l, r);
if(mid - l <= r - mid) {
for(int i=l; i<=mid; i++) {
int r1 = max(mid, a[mid] - k + i - 1);
int r2 = min(r, R[i]);
if(r1 <= r2) ans += r2 - r1 + 1;
}
} else {
for(int i=mid; i<=r; i++){
int l2 = min(mid, -a[mid] + k + i + 1);
int l1 = max(l, L[i]);
if(l1 <= l2) ans += l2 - l1 + 1;
}
}
solve(l, mid - 1), solve(mid + 1, r);
}
int main() {
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n, &k);
for(int i=1; i<=n; i++) scanf("%d", a+i);
vis[a[1]] = 1;
for(int i=1, p=2; i<=n; i++) {
while(p<=n && !vis[a[p]]) vis[a[p++]] = 1;
vis[a[i]] = 0, R[i] = p - 1;
}
vis[a[n]] = 1;
for(int i=n, p=n-1; i; i--) {
while(p && !vis[a[p]]) vis[a[p--]] = 1;
vis[a[i]] = 0, L[i] = p + 1;
}
ans = 0;
ST(n, a);
solve(1, n);
printf("%lld\n", ans);
}
}
例题二:洛谷P4755 Beautiful Pair
大意:问有多少个数对
题解:根据最大值进行启发式分裂,枚举
,找出区间有多少个
的数,可以用主席树
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 5;
int n, a[maxn], aid[maxn], tot, t[maxn*30];
int root[maxn*30], ls[maxn*30], rs[maxn*30];
vector <int> v;
ll ans;
int dp_max[maxn][20], pos[maxn][20];
void ST(int n, int d[]) {
for (int i=1; i<=n; i++) {
dp_max[i][0] = d[i];
pos[i][0] = i;
}
for (int j=1; (1<<j) <= n; j++) {
for (int i=1; i+(1<<j)-1 <= n; i++) {
if(dp_max[i][j-1] >= dp_max[i + (1<<(j-1))][j-1]) pos[i][j] = pos[i][j-1];
else pos[i][j] = pos[i + (1<<(j-1))][j-1];
dp_max[i][j] = max(dp_max[i][j-1], dp_max[i + (1<<(j-1))][j-1]);
}
}
}
int RMQ_max(int l, int r) {
int k = 0;
while ((1<<(k+1)) <= r-l+1) k++;
if(dp_max[l][k] >= dp_max[r - (1<<k)+1][k]) return pos[l][k];
else return pos[r - (1<<k)+1][k];
}
int getid(int x){
return lower_bound(v.begin(), v.end(), x) - v.begin() + 1;
}
void update(int &rt, int pre, int l, int r, int pos){
if(l>pos || r<pos) return;
rt = ++tot;
t[rt] = t[pre] + 1;
ls[rt] = ls[pre], rs[rt] = rs[pre];
if(l == r) return;
int m = l + r >> 1;
update(ls[rt], ls[pre], l, m, pos);
update(rs[rt], rs[pre], m+1, r, pos);
}
int query(int rt, int pre, int l, int r, int pos){
if(l > pos) return 0;
if(r <= pos) return t[rt] - t[pre];
int m = l + r >> 1, ret = 0;
ret += query(ls[rt], ls[pre], l, m, pos);
ret += query(rs[rt], rs[pre], m+1, r, pos);
return ret;
}
void solve(int l, int r){
if(l > r) return;
int pos = RMQ_max(l, r), mx = a[pos];
if(pos - l <= r - pos){
for(int i=l; i<=pos; i++){
int k = mx / a[i];
int kid = upper_bound(v.begin(), v.end(), k) - v.begin();
ans += query(root[r], root[pos-1], 1, v.size(), kid);
}
} else {
for(int i=pos; i<=r; i++){
int k = mx / a[i];
int kid = upper_bound(v.begin(), v.end(), k) - v.begin();
ans += query(root[pos], root[l-1], 1, v.size(), kid);
}
}
solve(l, pos - 1), solve(pos + 1, r);
}
int main() {
scanf("%d", &n);
for(int i=1; i<=n; i++) scanf("%d", a+i), v.push_back(a[i]);
v.push_back(1e9 + 7);
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
for(int i=1; i<=n; i++){
aid[i] = getid(a[i]);
update(root[i], root[i-1], 1, v.size(), aid[i]);
}
ST(n, a);
solve(1, n);
printf("%lld\n", ans);
}