【省内训练2018-12-23】String

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/85224610

【思路要点】

  • 考虑无问号的情况,分为两种:
    1 1 S = T S=T ,那么 A A B B 取任意字符串均可,贡献为 a r b i t r a r y = i = 1 N 2 i j = 1 N 2 j arbitrary=\sum_{i=1}^{N}2^i\sum_{j=1}^{N}2^j
    2 2 S T S\ne T ,那么要求 A A B B 均具有一个长度为 g c d ( A , B ) gcd(|A|,|B|) 的周期,此周期不需要为整周期,也不需要是最小周期,并且要求填入 A A B B S = T |S|=|T| 。可以证明,上述条件是 ( A , B ) (A,B) 合法的充要条件。
    假设 S S 中有 a a A A b b B B T T 中有 c c A A d d B B ,则 S = T a A + b B = c A + d B |S|=|T|\Leftrightarrow a|A|+b|B|=c|A|+d|B|
    因此,该情况的贡献为 i = 1 N j = 1 N [ a i + b j = c i + d j ] 2 g c d ( i , j ) \sum_{i=1}^{N}\sum_{j=1}^{N}[ai+bj=ci+dj]2^{gcd(i,j)}
  • 分几类讨论一下:
    1 1 a = c , b = d a=c,b=d ,那么贡献为 v a l u e g = i = 1 N j = 1 N 2 g c d ( i , j ) g = 1 N 2 g d = 1 N g μ ( d ) N g d 2 valueg=\sum_{i=1}^{N}\sum_{j=1}^{N}2^{gcd(i,j)}\sum_{g=1}^{N}2^g\sum_{d=1}^{\lfloor\frac{N}{g}\rfloor}\mu(d)*\lfloor\frac{N}{gd}\rfloor^2
    2 2 ( a c ) ( b d ) 0 (a-c)*(b-d)≥0 ,那么 a i + b j = c i + d j ai+bj=ci+dj 无正整数解,贡献为 0 0
    3 3 、否则我们可以得到一个形如 A = x y B   ( g c d ( x , y ) = 1 , x > y ) |A|=\frac{x}{y}|B|\ (gcd(x,y)=1,x>y) 的关系,贡献为 l i m x = i = 1 N x 2 i lim_x=\sum_{i=1}^{\lfloor\frac{N}{x}\rfloor}2^i
  • 因此,适当地预处理后,我们可以 O ( 1 ) O(1) 处理无问号的情况。
  • 考虑枚举 S , T S,T 中各有多少问号变为了 A A ,用组合数计算系数,可以得到一个 O ( S T ) O(|S|*|T|) 的解法。
  • 注意到计算贡献时,我们只关心 a c , b d a-c,b-d ,因此我们也可以只枚举 S S 中变为 A A 的问号比 T T 中变为 A A 的问号多几个。记 S S 中问号个数为 s s T T 中问号个数为 t t ,则 S S 中变为 A A 的问号比 T T 中变为 A A 的问号多 i i 个的方案数为 ( s + t i + t ) \binom{s+t}{i+t}
  • 注意考虑文章开头提到的 S = T S=T 的情况。
  • 时间复杂度 O ( N L o g N + S + T ) O(NLogN+|S|+|T|)

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 6e5 + 5;
const int P = 1e9 + 7;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); } 
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
int fac[MAXN], inv[MAXN], bit[MAXN];
int n, sa, sb, sq, ta, tb, tq, ans;
int arbitrary, valueg, lim[MAXN], miu[MAXN];
char s[MAXN], t[MAXN];
int power(int x, int y) {
	if (y == 0) return 1;
	int tmp = power(x, y / 2);
	if (y % 2 == 0) return 1ll * tmp * tmp % P;
	else return 1ll * tmp * tmp % P * x % P;
}
void update(int &x, int y) {
	x += y;
	if (x >= P) x -= P;
}
int getc(int x, int y) {
	if (y > x) return 0;
	else return 1ll * fac[x] * inv[y] % P * inv[x - y] % P;
}
void gets(char *s, int &a, int &b, int &q) {
	scanf("%s", s + 1);
	int len = strlen(s + 1);
	for (int i = 1; i <= len; i++)
		if (s[i] == 'A') a++;
		else if (s[i] == 'B') b++;
		else q++;
}
void init(int n) {
	fac[0] = bit[0] = 1;
	for (int i = 1; i <= n; i++) {
		fac[i] = 1ll * fac[i - 1] * i % P;
		bit[i] = 2ll * bit[i - 1] % P;
	}
	inv[n] = power(fac[n], P - 2);
	for (int i = n - 1; i >= 0; i--)
		inv[i] = inv[i + 1] * (i + 1ll) % P;
}
void solve(int x, int y, int coef) {
	if (x == 0 && y == 0) update(ans, 1ll * coef * valueg % P);
	if (1ll * x * y >= 0) return;
	x = abs(x), y = abs(y); int g = __gcd(x, y);
	x /= g, y /= g; if (x > y) swap(x, y);
	update(ans, 1ll * lim[y] * coef % P);
}
void equal(int len) {
	int coef = 1;
	for (int i = 1; i <= len; i++) {
		if (s[i] != '?' && t[i] != '?' && s[i] != t[i]) return;
		if (s[i] == '?' && t[i] == '?') coef = 2ll * coef % P;
	}
	update(ans, 1ll * coef * arbitrary % P);
	update(ans, P - 1ll * coef * valueg % P);
}
int getmiu(int x) {
	int i = 2, ans = 1;
	while (i * i <= x) {
		if (x % i == 0) {
			x /= i;
			if (x % i == 0) return 0;
			ans = P - ans;
		}
		i++;
	}
	if (x != 1) ans = P - ans;
	return ans;
}
void calcconsts() {
	for (int i = 1; i <= n; i++)
	for (int j = 1; i * j <= n; j++)
		update(lim[i], bit[j]);
	int sum = 0;
	for (int i = 1; i <= n; i++)
		update(sum, bit[i]);
	for (int i = 1; i <= n; i++)
		update(arbitrary, 1ll * sum * bit[i] % P);
	for (int i = 1; i <= n; i++)
		miu[i] = getmiu(i);
	for (int i = 1; i <= n; i++)
	for (int j = 1; i * j <= n; j++)
		update(valueg, 1ll * bit[i] * miu[j] % P * (n / i / j) % P * (n / i / j) % P);
}
int main() {
	freopen("string.in", "r", stdin);
	freopen("string.out", "w", stdout);
	init(6e5);
	gets(s, sa, sb, sq);
	gets(t, ta, tb, tq);
	read(n), calcconsts();
	for (int i = -tq; i <= sq; i++)
		solve(sa - ta + i, sb + sq - i - tb - tq, getc(tq + sq, i + tq));
	if (sa + sb + sq == ta + tb + tq) equal(sa + sb + sq);
	writeln(ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39972971/article/details/85224610
今日推荐