【AtCoder】ARC100 F - Colorful Sequences

题解

我不会数数啊QAQ

先求出所有的序列里M这一段出现的次数的总和
答案是\((N - M + 1)K^{N - M}\)
然后求M这一段出现在不多彩的序列里次数的总和

如果M已经是多彩的了,那么答案是0

如果M不是多彩的且没有重复的数字
那么求所有N长的序列里M长含有不同数字的连续子段有多少个,答案除上\(\frac{K!}{(K - M)!}\)
那么记录dp[i][j]作为第i个,然后前j个数都是互不相同的数,j+1开始出现重复
更新的时候从dp[i - 1][h]更新
\(\left\{\begin{matrix} 1 & h \geq j \\ K - h & h = j - 1\\ 0 & h < j - 1 \end{matrix}\right.\)
然后用前缀和处理可以做到\(O(NK)\)
用cnt[i][j]表示i长的序列里,倒数j个数都是互不相同的数,M长含有不同数字的连续子段有多少个
在每遇到一个dp[i][j]的j>=M就累加上
剩下转移方式类似

如果M是多彩的且有重复数字
那么就记录F为前缀最多到哪是互不相同的,B为后缀最多到哪是互不相同的
枚举M所在位置用类似的dp转移

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
//#define ivorysi
#define fi first
#define se second
#define MAXN 25005
#define enter putchar('\n')
#define space putchar(' ')
typedef long long ll;
using namespace std;
template <class T>
void read(T &res) {
    res = 0;char c = getchar();T f = 1;
    while(c < '0' || c > '9') {
        c = getchar();
        if(c == '-') f = -1;
    }
    while(c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = getchar();
    }
    res *= f;
}
template <class T>
void out(T x) {
    if(x < 0) {x = -x;}
    if(x >= 10) {
        out(x / 10);
    }
    putchar('0' + x % 10);
}
const int MOD = 1000000007;
int N,K,M;
int A[MAXN],fac[MAXN],invfac[MAXN],inv[MAXN];
int F,B,L,vis[405];
int dp[MAXN][405],cnt[MAXN][405],sum[405],sum_cnt[405],f[MAXN],b[MAXN];
int mul(int a,int b) {
    return 1LL * a * b % MOD;
}
int inc(int a,int b) {
    return a + b >= MOD ? a + b - MOD : a + b;
}
int fpow(int x,int c) {
    int res = 1,t = x;
    while(c) {
        if(c & 1) res = mul(res,t);
        t = mul(t,t);
        c >>= 1;
    }
    return res;
}
void Init() {
    read(N);read(K);read(M);
    for(int i = 1 ; i <= M ; ++i) read(A[i]);
    inv[1] = 1;
    for(int i = 2 ; i <= K ; ++i) inv[i] = mul(inv[MOD % i],MOD - MOD / i);
    fac[0] = invfac[0] = 1;
    for(int i = 1 ; i <= K ; ++i) {
        fac[i] = mul(fac[i - 1],i);
        invfac[i] = mul(invfac[i - 1],inv[i]);
    }
    F = 0;B = 0;
    memset(vis,0,sizeof(vis));
    for(int i = 1 ; i <= M ; ++i) {
        if(!vis[A[i]]) {
            ++F;
            vis[A[i]] = 1;
        }
        else break;
    }
    memset(vis,0,sizeof(vis));
    for(int i = M ; i >= 1 ; --i) {
        if(!vis[A[i]]) {
            ++B;
            vis[A[i]] = 1;
        }
        else break;
    }
    memset(vis,0,sizeof(vis));
    int l = 0;
    for(int i = 1 ; i <= M ; ++i) {
        l = max(l,vis[A[i]]);
        L = max(L,i - l);
        vis[A[i]] = i;
    }
}
void Process(int st,int *a) {
    memset(dp,0,sizeof(dp));
    dp[0][st] = 1;
    memset(sum,0,sizeof(sum));
    for(int i = st ; i <= K ; ++i) sum[i] = 1;
    a[0] = 1;
    for(int i = 1 ; i <= N ; ++i) {
        for(int j = 1 ; j < K ; ++j) {
            dp[i][j] = inc(dp[i][j],mul(dp[i - 1][j - 1],(K - j + 1)));
            dp[i][j] = inc(dp[i][j],inc(sum[K],MOD - sum[j - 1]));
        }
        for(int j = 1 ; j <= K ; ++j) {
            sum[j] = inc(sum[j - 1],dp[i][j]);
        }
        a[i] = sum[K - 1];
    }
}
void Solve() {
    if(L == K) {
        out(mul(N - M + 1,fpow(K,N - M)));enter;
    }
    else if(F == M) {
        dp[0][0] = 1;
        int ans = mul(N - M + 1,fpow(K,N - M)),tmp = 0;
        for(int i = 1 ; i <= N ; ++i) {
            for(int j = 1 ; j < K ; ++j) {
                dp[i][j] = inc(dp[i][j],mul(dp[i - 1][j - 1],(K - j + 1)));
                dp[i][j] = inc(dp[i][j],inc(sum[K],MOD - sum[j - 1]));
                cnt[i][j] = inc(cnt[i][j],mul(cnt[i - 1][j - 1],(K - j + 1)));
                cnt[i][j] = inc(cnt[i][j],inc(sum_cnt[K],MOD - sum_cnt[j - 1]));
                if(j >= M) cnt[i][j] = inc(cnt[i][j],dp[i][j]);
            }
            for(int j = 1 ; j <= K ; ++j) {
                sum[j] = inc(sum[j - 1],dp[i][j]);
                sum_cnt[j] = inc(sum_cnt[j - 1],cnt[i][j]);
            }
        }
        for(int j = 0 ; j < K ; ++j) {
            tmp = inc(tmp,cnt[N][j]);
        }
        tmp = mul(tmp,fpow(mul(fac[K],invfac[K - M]),MOD - 2));
        ans = inc(ans,MOD - tmp);
        out(ans);enter;
    }
    else {
        Process(F,f);Process(B,b);
        int ans = mul(N - M + 1,fpow(K,N - M));
        for(int i = 1 ; i <= N - M + 1 ; ++i) {
            int j = i + M - 1;
            ans = inc(ans,MOD - mul(f[i - 1],b[N - j]));
        }
        out(ans);enter;
    }
}
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Init();
    Solve();
}

猜你喜欢

转载自www.cnblogs.com/ivorysi/p/9439117.html