题意:
有 n n n个数字 a 1 , 2... n a_{1,2...n} a1,2...n,要分成k段,每段的价值为段内满足 [ a i = = a j ] a n d [ i = = j ] [a_i==a_j] and[i==j] [ai==aj]and[i==j]的 ( i , j ) (i,j) (i,j)对数。
( n , a i ≤ 1 e 5 , k ≤ m i n ( n , 20 ) ) (n,a_i\le 1e5,k\le min(n,20)) (n,ai≤1e5,k≤min(n,20))
解题思路:
有一个很显然的 O ( n 2 k ) O(n^2k) O(n2k)的做法, d p ( i , j ) dp(i,j) dp(i,j)表示把 a 1 , 2 , . . j a_{1,2,..j} a1,2,..j分成 i i i段的最小价值,转移为
d p ( i , j ) = m i n ( d p ( i − 1 , j ′ ) + c o s t ( j ′ , j ) ) [ j ′ < j ] dp(i,j)=min(dp(i-1,j')+cost(j',j))\quad[j'<j] dp(i,j)=min(dp(i−1,j′)+cost(j′,j))[j′<j]
这里 c o s t ( j ′ , j ) cost(j',j) cost(j′,j)表示 ( j ′ , j ] (j',j] (j′,j]这一段的价值。
考虑优化,往单调性方面去考虑:每个转移点的位置是否满足单调性?设 p j p_j pj为 d p ( i , j ) dp(i,j) dp(i,j)的最左端的最优转移点,对于 ( j ′ < j ) (j'<j) (j′<j)是否总有 p j ′ ≤ p j p_{j'}\le p_j pj′≤pj?
答案是Yes。现在用反证法证明:
证明思路来自:https://codeforces.com/blog/entry/55046
假设存在 j ′ < j j'<j j′<j,且 d p ( i − 1 , x ) + c o s t ( x , j ) < d p ( i − 1 , p j ′ ) + c o s t ( p j ′ , j ) dp(i-1,x)+cost(x,j)<dp(i-1,p_{j'})+cost(p_{j'},j) dp(i−1,x)+cost(x,j)<dp(i−1,pj′)+cost(pj′,j)满足 x < p j ′ x<p_{j'} x<pj′。
画图出来是这样:
那么显然的有:
c o s t ( x , j ) − c o s t ( p j ′ , j ) > c o s t ( x , j ′ ) − c o s t ( p j ′ , j ′ ) cost(x,j)-cost(p_{j'},j)>cost(x,j')-cost(p_{j'},j') cost(x,j)−cost(pj′,j)>cost(x,j′)−cost(pj′,j′)
由假设的条件我们得到
d p ( i − 1 , p j ′ ) − d p ( i − 1 , x ) > c o s t ( x , j ) − c o s t ( p j ′ , j ) > c o s t ( x , j ′ ) − c o s t ( p j ′ , j ′ ) dp(i-1,p_{j'})-dp(i-1,x)> cost(x,j)-cost(p_{j'},j)>cost(x,j')-cost(p_{j'},j') dp(i−1,pj′)−dp(i−1,x)>cost(x,j)−cost(pj′,j)>cost(x,j′)−cost(pj′,j′)
移项得到 d p ( i − 1 , p j ′ ) + c o s t ( p j ′ , j ′ ) > d p ( i − 1 , x ) + c o s t ( x , j ′ ) dp(i-1,p_{j'})+cost(p_{j'},j')>dp(i-1,x)+cost(x,j') dp(i−1,pj′)+cost(pj′,j′)>dp(i−1,x)+cost(x,j′)
这与 p j ′ p_{j'} pj′是 j ′ j' j′的最左端最优转移点矛盾
证毕。
那么知道了它满足单调性之后,就可以采用类似整体二分的分治去转移了,因为每一层的搜索区间总长度是 O ( n ) O(n) O(n)的,一共有 l o g log log层,所以总复杂度 O ( k n l o g n ) O(knlogn) O(knlogn)
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 50;
int cnt[maxn], a[maxn];
ll ans;
void add(int x){
ans += cnt[x]; cnt[x]++;
}
void del(int x){
ans -= (cnt[x]-1); cnt[x]--;
}
int lp , rp;
void go(int l, int r){
while(lp > l) add(a[--lp]);
while(rp < r) add(a[++rp]);
while(lp < l) del(a[lp++]);
while(rp > r) del(a[rp--]);
}
ll dp[21][maxn];
int n, k, t;
void sol(int l, int r, int L, int R, ll *pre, ll *cur){
if(l > r) return;
int mid = (l+r)>>1, p;
for(int i = L; i <= min(R, mid-1); ++i){
go(i+1, mid);
if(cur[mid] > pre[i] + ans) cur[mid] = pre[i] + ans, p = i;
}
sol(l, mid-1, L, p, pre, cur); sol(mid+1, r, p, R, pre, cur);
}
int main()
{
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
memset(dp, 0x3f, sizeof dp); dp[0][0] = 0;
lp = rp = 1; add(a[1]);
for(t = 1; t <= k; ++t){
dp[t][0] = 0;
sol(1, n, 0, n-1, dp[t-1], dp[t]);
}
cout<<dp[k][n]<<endl;
}
/*
20 3
1 3 2 0 1 0 2 2 2 0 1 1 1 3 1 3 3 2 3 0
*/