D-Double Strings 2021牛客多校5

D-Double Strings

题意

给两个字符串 A A A B B B ( 1 ≤ ∣ A ∣ , ∣ B ∣ ≤ 5000 ) (1\le |A|,|B| \le 5000) (1A,B5000)​ , a a a A A A 的子序列, b b b B B B 的子序列,求有多少个子序列组合满足两个子序列的长度相同并且 ∃ i ∈ { 1 , 2 , … , ∣ a ∣ } , A a i < B b i , ∀ j ∈ { 1 , 2 , … , i − 1 } , A a j = B b j \exists i \in \{1, 2, \dots, |a|\},A_{ai} < B_{bi},\forall j \in \{1, 2, \dots, i - 1\},A_{aj}=B_{bj} i{ 1,2,,a},Aai<Bbi,j{ 1,2,,i1},Aaj=Bbj​ 。

题解

  • 可以把子序列分成三段 ,第一段两个子序列完全相同,第二段长度为 1 1 1​ 满足 A i < B i A_i<B_i Ai<Bi​ ,第三段只需要长度相同即可。
  • 二重循环遍历两个字符串,如果满足 A i < B j A_i<B_j Ai<Bj ,那么 i i i 把字符串 A A A​ 分成了前后两部分, j j j 把字符串 B B B​ 分成了前后两部分,分别求出前面部分公共子序列的数量和后面部分相同的数量即可。
  • 前面部分可以通过二维dp O ( 1 ) O(1) O(1) 转移得到;
  • 后面部分可以dp, A A A 此时剩余长度为 x x x B B B 此时剩余长度为 y y y,不妨 x ≤ y x\le y xy ∑ i = 0 x C x i ⋅ C y i = ∑ i = 0 x C x x − i ⋅ C y i = C x + y x \sum_{i = 0}^{x} C_x^i \cdot C_y^i = \sum_{i = 0}^{x} C_x^{x - i} \cdot C_y^i = C_{x + y} ^ x i=0xCxiCyi=i=0xCxxiCyi=Cx+yx​ 。

代码

#include <bits/stdc++.h>
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
#ifdef LOCAL
#include "Print.h"
#define de(...) W('[', #__VA_ARGS__,"] =", __VA_ARGS__)
#else
#define de(...)
#endif
using namespace std;
typedef long long ll;
const int maxn = 5e3 + 5;
const int mod = 1e9 + 7;
char s[maxn], t[maxn];
int n, m;
ll dp[maxn][maxn];
ll fac[maxn * 2], inv[maxn * 2];
void add(ll &x, ll y) {
    
     if ((x += y) >= mod) x -= mod; }
void sub(ll &x, ll y) {
    
     if ((x -= y) < 0) x += mod; }
ll powmod(ll a, ll b) {
    
    
    ll ans = 1;
    while (b) {
    
    
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
void init() {
    
    
    fac[1] = fac[0] = 1;
    for (int i = 2; i < maxn * 2; ++i) fac[i] = fac[i - 1] * i % mod;
    inv[maxn * 2 - 1] = powmod(fac[maxn * 2 - 1], mod - 2);
    for (int i = 2 * maxn - 2; i >= 0; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
}
inline ll cal(ll a, ll b) {
    
    
    return fac[a] * inv[b] % mod * inv[a - b] % mod;
}
void DP() {
    
    
    rep(i, 0, maxn - 1) dp[i][0] = dp[0][i] = 1;
    rep(i, 1, n) rep(j, 1, m) {
    
    
        add(dp[i][j], dp[i - 1][j] + dp[i][j - 1]);
        if (s[i] != t[j]) sub(dp[i][j], dp[i - 1][j - 1]);
    }
}
int case_Test() {
    
    
    scanf("%s%s", s + 1, t + 1);
    n = strlen(s + 1), m = strlen(t + 1);
    init(), DP();
    ll ans = 0;
    rep(i, 1, n) rep(j, 1, m) if (s[i] < t[j])
        add(ans, dp[i - 1][j - 1] * cal(n + m - i - j, min(n - i, m - j)) % mod);
    printf("%lld\n", ans);
    return 0;
}
int main() {
    
    
#ifdef LOCAL
    freopen("in.in", "r", stdin);
    freopen("out.out", "w", stdout);
    clock_t start = clock();
#endif
    int _ = 1;
    // scanf("%d", &_);
    while (_--) case_Test();
#ifdef LOCAL
    printf("Time used: %.3lfs\n", (double)(clock() - start) / CLOCKS_PER_SEC);
#endif
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43860866/article/details/119281707