Gym101667 H. Rock Paper Scissors

将第二个字符串中的每个字符替换为它能战胜的对方手势（R→S，S→P，P→R），再将其倒序。这样，统计“从每个起始位置开始能赢多少局”的字符串匹配问题就变成了卷积运算：卷积结果的第 p + m - 1 项恰好是从第 p 局开始时的获胜次数。

因此只需对 R、S、P 三种字符各做一次卷积（共三次），把结果按位置累加，再在所有合法起点中取最大值即可。

#include <bits/stdc++.h>

// Minimal complex number used by the FFT below.
// `r` is the real part and `i` the imaginary part.
struct Complex {
    double r;
    double i;

    // Leaves both parts uninitialized (kept for cheap array construction).
    Complex() {}
    Complex(double r, double i) : r(r), i(i) {}

    Complex operator+(const Complex &o) const {
        const double re = r + o.r;
        const double im = i + o.i;
        return Complex(re, im);
    }

    Complex operator-(const Complex &o) const {
        const double re = r - o.r;
        const double im = i - o.i;
        return Complex(re, im);
    }

    // (r + i*j)(o.r + o.i*j) = (r*o.r - i*o.i) + (r*o.i + i*o.r)*j
    Complex operator*(const Complex &o) const {
        const double re = r * o.r - i * o.i;
        const double im = r * o.i + i * o.r;
        return Complex(re, im);
    }
};

// Capacity of every global buffer (strings, FFT arrays, result sums).
const int N = 400007;
// pi computed as acos(-1) to full double precision.
const double pi = std::acos(-1.0);
// Bit-reversal permutation table consumed by FFT(); filled in main().
int r[N];

// In-place iterative radix-2 FFT over a[0..n-1].
//   a  - data, overwritten with its transform
//   n  - transform length; assumed to be a power of two (main uses `limit`)
//   pd - +1 for the forward transform, -1 for the inverse; the inverse also
//        divides every entry by n.
// Precondition: the global r[] must already hold the bit-reversal permutation
// for this n (main() builds it from `limit` before any call).
void FFT(Complex a[], int n, int pd) {
    // Reorder the input into bit-reversed index order; swap each pair once.
    for (int idx = 0; idx < n; ++idx) {
        if (idx >= r[idx])
            continue;
        std::swap(a[idx], a[r[idx]]);
    }

    // Merge blocks of size `half` into blocks of size 2*half.
    for (int half = 1; half < n; half <<= 1) {
        const Complex step(cos(pi / half), pd * sin(pi / half));
        const int span = half << 1;
        for (int base = 0; base < n; base += span) {
            Complex w(1.0, 0.0);  // current twiddle factor, advanced per k
            for (int k = 0; k < half; ++k) {
                Complex &lo = a[base + k];
                Complex &hi = a[base + k + half];
                const Complex u = lo;
                const Complex v = w * hi;
                lo = u + v;
                hi = u - v;
                w = w * step;
            }
        }
    }

    if (pd != -1)
        return;
    // Inverse transform: normalise by the length.
    for (int idx = 0; idx < n; ++idx)
        a[idx] = Complex(a[idx].r / n, a[idx].i / n);
}

// FFT work buffers, reused by every solve() call.
Complex A[N];  // indicator of s == ch; then its transform, product, inverse
Complex B[N];  // indicator of (transformed, reversed) t == ch; then its transform
int n;         // length of the machine's string s (read from input)
int m;         // length of our string t (read from input)
int limit;     // FFT length: smallest power of two >= n + m (set in main)
int sum[N];    // sum[i]: rounded convolution values accumulated over R, S, P
char s[N];     // machine's hands
char t[N];     // our hands (rewritten and reversed in main)

void solve(char ch) {
    for (int i = 0; i < n; i++)
        A[i] = Complex(s[i] == ch ? 1 : 0, 0.0);
    for (int i = n; i < limit; i++)
        A[i] = Complex(0.0, 0.0);
    for (int i = 0; i < m; i++)
        B[i] = Complex(t[i] == ch ? 1.0 : 0.0, 0.0);
    for (int i = m; i < limit; i++)
        B[i] = Complex(0.0, 0.0);
    FFT(A, limit, 1); FFT(B, limit, 1);
    for (int i = 0; i < limit; i++)
        A[i] = A[i] * B[i];
    FFT(A, limit, -1);
    for (int i = 0; i < limit; i++)
        sum[i] += (int)(A[i].r + 0.5);
}

// Reads n, m, the machine's hands s and our hands t, and prints the maximum
// number of rounds we can win over all starting positions in s.
int main() {
    // Stop quietly on malformed input instead of working on garbage values.
    if (scanf("%d%d", &n, &m) != 2)
        return 0;
    if (scanf("%s", s) != 1 || scanf("%s", t) != 1)
        return 0;

    // Replace each of our hands by the machine hand it beats:
    // R beats S, S beats P, P beats R. A win is then an exact character match.
    for (int i = 0; i < m; i++) {
        if (t[i] == 'R') t[i] = 'S';
        else if (t[i] == 'S') t[i] = 'P';
        else t[i] = 'R';
    }
    // Reversing t turns "match s[p + j] with t[j]" into convolution index p + m - 1.
    std::reverse(t, t + m);

    int len = n + m;
    int l = 0;
    limit = 1;
    while (limit < len) limit <<= 1, l++;
    // Bit-reversal table for FFT(). When l == 0 (limit == 1) the shift count
    // l - 1 would be negative (undefined behaviour), so contribute 0 instead.
    for (int i = 0; i < limit; i++)
        r[i] = (r[i >> 1] >> 1) | (l > 0 ? ((i & 1) << (l - 1)) : 0);

    solve('R');
    solve('S');
    solve('P');

    // Start position p = i - (m - 1) ranges over 0..n-1; the convolution has
    // no support beyond index n + m - 2, so scanning further cannot change ans.
    // std::max guards against indexing sum[-1] when m <= 0.
    int ans = 0;
    for (int i = std::max(m - 1, 0); i <= n + m - 2 && i < limit; i++)
        ans = std::max(ans, sum[i]);
    printf("%d\n", ans);
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Mrzdtz220/p/11910377.html