【BZOJ4503】两个串

【题目链接】

【思路要点】

  • 补档博客，无题解。

【代码】

#include<bits/stdc++.h>
using namespace std;
// Capacity of every global array in this file (input strings and FFT buffers).
#define MAXN	400005
// Pi; expands to acos(-1) at each point of use.
#define P	acos(-1)
/**
 * Reads one decimal integer from stdin into x.
 *
 * Every non-digit character before the number is skipped; each '-' among the
 * skipped characters flips the sign of the result. The character that ends
 * the digit run is consumed as well.
 *
 * If end of input is reached before any digit is seen, x is set to 0 and the
 * function returns instead of waiting for more input.
 */
template <typename T> void read(T &x) {
	x = 0;
	bool negative = false;
	// Keep the result of getchar() as int: EOF must stay distinguishable from
	// real characters, and isdigit() is only defined for unsigned-char values
	// and EOF.
	int c = getchar();
	while (c != EOF && !isdigit(c)) {
		if (c == '-') negative = !negative;
		c = getchar();
	}
	while (c != EOF && isdigit(c)) {
		x = x * 10 + (c - '0');
		c = getchar();
	}
	if (negative) x = -x;
}
/** Complex number in Cartesian form, used as the FFT element type. */
struct point {
	double r;	// real part
	double i;	// imaginary part
};
/** Complex addition. */
point operator + (point a, point b) {
	point sum = {a.r + b.r, a.i + b.i};
	return sum;
}
/** Complex subtraction. */
point operator - (point a, point b) {
	point diff = {a.r - b.r, a.i - b.i};
	return diff;
}
/** Complex multiplication: (a.r + a.i*i) * (b.r + b.i*i). */
point operator * (point a, point b) {
	point prod = {a.r * b.r - a.i * b.i, a.r * b.i + b.r * a.i};
	return prod;
}
// FFT state shared by BRCinit(), BRC() and FFT():
//   N   - transform length; main() sets it to the smallest power of two
//         strictly greater than ls + lt
//   Log - log2(N), the number of bits BRCinit() reverses
//   pr  - bit-reversal permutation of 0..N-1, filled in by BRCinit()
int N, Log, pr[MAXN];
// Work buffers: a and b hold the two sequences being convolved, res
// accumulates their point-wise products in the frequency domain.
point a[MAXN], b[MAXN], res[MAXN];
/**
 * Precomputes the bit-reversal table: for every idx in [0, N), pr[idx] becomes
 * the value obtained by reversing the lowest Log bits of idx.
 * Reads the globals N and Log; writes pr[0..N-1].
 */
void BRCinit() {
	for (int idx = 0; idx < N; ++idx) {
		int reversed = 0;
		// Peel bits off the low end of idx and push them onto the low end of
		// reversed, so the first bit taken ends up as the most significant one.
		for (int bit = 0, rest = idx; bit < Log; ++bit, rest >>= 1)
			reversed = (reversed << 1) | (rest & 1);
		pr[idx] = reversed;
	}
}
/**
 * Permutes a[0..N-1] in place into bit-reversed index order, using the table
 * pr prepared by BRCinit().
 */
void BRC(point *a) {
	for (int i = 0; i < N; ++i) {
		const int partner = pr[i];
		// Swap each pair only once, when visiting its smaller index.
		if (i < partner) std::swap(a[i], a[partner]);
	}
}
/**
 * In-place iterative radix-2 FFT over a[0..N-1].
 *
 * N and the bit-reversal table pr must have been set up beforehand (see
 * BRCinit()); N is expected to be a power of two.
 *
 * @param a     buffer of N complex values, overwritten with the transform
 * @param type  1 for the forward transform, -1 for the inverse; the inverse
 *              also divides every element by N
 */
void FFT(point *a, int type) {
	BRC(a);
	for (int len = 2, half = 1; len <= N; len <<= 1, half <<= 1) {
		// Primitive len-th root of unity; the sign of type selects direction.
		const double angle = type * 2 * P / len;
		const point delta = {cos(angle), sin(angle)};
		for (int start = 0; start < N; start += len) {
			point now = {1, 0};
			// Butterfly: combine element i of the lower half with its partner
			// j in the upper half, rotated by the current twiddle factor.
			for (int i = start, j = start + half; i < start + half; i++, j++) {
				point tmp = a[i];
				point tnp = a[j] * now;
				a[i] = tmp + tnp;
				a[j] = tmp - tnp;
				now = now * delta;
			}
		}
	}
	if (type == -1) {
		// Normalise both components so the output is a true inverse DFT, not
		// only its real part.
		for (int i = 0; i < N; i++) {
			a[i].r /= N;
			a[i].i /= N;
		}
	}
}
/**
 * Numeric weight of a character for the matching polynomial:
 * '?' maps to 0, every other character c maps to c - 'a' + 1
 * (so 'a' is 1 and 'z' is 26).
 */
double index(char c) {
	return c == '?' ? 0.0 : static_cast<double>(c - 'a' + 1);
}
char s[MAXN], t[MAXN];
int main() {
	scanf("\n%s", s);
	scanf("\n%s", t);
	int ls = strlen(s);
	int lt = strlen(t);
	reverse(t, t + lt);
	N = 1, Log = 0;
	while (N <= ls + lt) {
		N <<= 1;
		Log++;
	}
	BRCinit();
	for (int i = 0; i < ls; i++)
		a[i] = (point) {index(s[i]) * index(s[i]), 0};
	for (int i = ls; i < N; i++)
		a[i] = (point) {0, 0};
	for (int i = 0; i < lt; i++)
		b[i] = (point) {index(t[i]), 0};
	for (int i = lt; i < N; i++)
		b[i] = (point) {0, 0};
	FFT(a, 1); FFT(b, 1);
	for (int i = 0; i < N; i++)
		res[i] = res[i] + a[i] * b[i];
	
	for (int i = 0; i < ls; i++)
		a[i] = (point) {2 * index(s[i]), 0};
	for (int i = ls; i < N; i++)
		a[i] = (point) {0, 0};
	for (int i = 0; i < lt; i++)
		b[i] = (point) {index(t[i]) * index(t[i]), 0};
	for (int i = lt; i < N; i++)
		b[i] = (point) {0, 0};
	FFT(a, 1); FFT(b, 1);
	for (int i = 0; i < N; i++)
		res[i] = res[i] - a[i] * b[i];
	
	for (int i = 0; i < ls; i++)
		a[i] = (point) {1, 0};
	for (int i = ls; i < N; i++)
		a[i] = (point) {0, 0};
	for (int i = 0; i < lt; i++)
		b[i] = (point) {index(t[i]) * index(t[i]) * index(t[i]), 0};
	for (int i = lt; i < N; i++)
		b[i] = (point) {0, 0};
	FFT(a, 1); FFT(b, 1);
	for (int i = 0; i < N; i++)
		res[i] = res[i] + a[i] * b[i];
	FFT(res, -1);
	int ans = 0;
	for (int i = lt - 1; i < ls; i++)
		if ((int)(res[i].r + 0.5) == 0) ans++;
	printf("%d\n", ans);
	for (int i = lt - 1; i < ls; i++)
		if ((int)(res[i].r + 0.5) == 0) printf("%d\n", i - lt + 1);
	return 0;
}


猜你喜欢

转载自blog.csdn.net/qq_39972971/article/details/80384293