https://codeforces.com/contest/645/problem/E
给一个字符串,要添加 m 个字符使不同的子序列数量尽可能多,求这个数量。
首先考虑计算一个字符串的不同子序列数目,令 dp[i] 为以字母 i 结尾的不同子序列个数,那么新添加一个 j 时可以接在所有其他字母后,或者单独成一个新的子序列,于是 \(dp(i) = 1 + \sum{dp(j})\)。
额外添加字符时,注意到新加一个字符 i 时只会改变 dp[i] 的值,要使增加量最大就应该选取最小的 dp[i]。又因为 i<j 时 dp[i]<dp[j],所以应该选取上次出现位置最远的字母。容易发现每次的选取会形成循环。
复杂度 O(n)。
#include <bits/stdc++.h>
using namespace std;
typedef unsigned u32;
template <u32 MOD>
struct ModInt {
ModInt() : v(0) {}
ModInt(u32 v) : v(v) {}
u32 get() { return v; }
ModInt& operator += (const ModInt &rhs) {
if ((v += rhs.v) >= MOD) v -= MOD;
return *this;
}
ModInt& operator -= (const ModInt &rhs) {
if ((v += MOD - rhs.v) >= MOD) v -= MOD;
return *this;
}
ModInt operator + (const ModInt &rhs) const { return ModInt(*this) += rhs; }
ModInt operator - (const ModInt& rhs) const { return ModInt(*this) -= rhs; }
u32 v;
};
typedef ModInt<1000000007> Mint;
const int N = 1000010;
int m, k, n;
char s[N];
Mint dp[26];
int prd[26], ls[26];
bool cmp(const int &x, const int &y) {
return ls[x] < ls[y];
}
int main() {
freopen("string.in", "r", stdin);
freopen("string.out", "w", stdout);
scanf("%d %d", &m, &k);
scanf("%s", s);
n = strlen(s);
fill(ls, ls + k, -1);
Mint sm = Mint(1);
for (int i = 0; i < n; i++) {
int c = s[i] - 'a';
Mint odp = dp[c];
dp[c] = sm;
sm += sm - odp;
ls[c] = i;
}
for (int i = 0; i < k; i++) prd[i] = i;
sort(prd, prd + k, cmp);
for (int i = 0; i < m; i++) {
int c = prd[i % k];
Mint odp = dp[c];
dp[c] = sm;
sm += sm - odp;
}
printf("%u\n", sm.get());
return 0;
}