链式反应

链式反应

题目描述

不想看题面,其实就是给定 $p(x)$,有 $f'(x)=f(x)^2p(x)+1$,求 $f$ 这个多项式的幂级数形式前 $n$ 项。

Solution

式子可以写成 $f_n=\sum_{i,j}[0\le i+j<n]f_if_jp_{n-i-j-1}$

显然能分治 $FFT$。
设分治区间为 $[l,r]$,考虑 $[l,mid]$ 对 $[mid+1,r]$ 的贡献。
$l=1$ 时,贡献为
$\sum^{mid}_{i=0}\sum^{mid}_{j=0}\sum^{r}_{k=0}f_if_jp_k$
$l>1$ 时,有 $2l>r$,因此贡献为
$2\sum^{mid}_{i=l}\sum^{r-l}_{j=0}\sum^{r-l}_{k=0}f_if_jp_k$
时间复杂度两只 $\lg$,即 $O(n\lg^2 n)$。
注意这里的 $f,p$ 始终是一个 $EGF$。

#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <string>
#include <cstring>
#include <ctime>
#include <cassert>
#include <string.h>
//#include <unordered_set>
//#include <unordered_map>
//#include <bits/stdc++.h>

// Shorthand macros.
// SIZE and LEN wrap their argument in parentheses before applying the member
// call, so expressions such as `*ptr` expand with the intended precedence.
#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)(A).size())
#define LEN(A) ((int)(A).length())
// Half-open counting loop: i runs over a, a+1, ..., b-1.
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se second

using namespace std;

// Lowers x to y when y is strictly smaller than x.
// Returns true exactly when x was overwritten.
template<typename T>
inline bool upmin(T &x,T y)
{
	if (!(y<x)) return false;
	x=y;
	return true;
}
// Raises x to y when y is strictly larger than x.
// Returns true exactly when x was overwritten.
template<typename T>
inline bool upmax(T &x,T y)
{
	if (!(x<y)) return false;
	x=y;
	return true;
}

// Short aliases for the integer, floating-point and container types.
typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef std::pair<int,int> PR;
typedef std::vector<int> VI;

const lod eps=1e-11;
// Use the long double overload explicitly; acos(-1) with an int argument is
// evaluated in double and would leave pi with only double precision.
const lod pi=std::acos(-1.0L);
const int oo=1<<30;
const ll loo=1ll<<62;
// Prime modulus for all modular arithmetic and for the number-theoretic transform.
const int mods=998244353;
// Base from which the transform derives its roots of unity
// (3 is a primitive root modulo 998244353).
const int g=3;
// Modular inverse of g: 3*gi == mods+1.
const int gi=(mods+1)/3;
// Capacity of every global array.
const int MAXN=2000005;
const int INF=0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; if any skipped character
// is '-', the result is negated. The character that ends the digit run is
// consumed as well.
// Returns 0 when the input ends before a digit is found, instead of spinning
// forever on EOF.
inline int read()
{
	int f=1,x=0;
	int c=getchar(); // kept as int so EOF stays distinct from every real character
	while (c!=EOF&&(c<'0'||c>'9')) { if (c=='-') f=-1; c=getchar(); }
	while (c>='0'&&c<='9') { x=(x<<3)+(x<<1)+(c^48); c=getchar(); }
	return x*f;
}
// Length of the transform currently in use; solve() grows it by doubling from 1.
int Limit;
// Number of doublings applied to Limit, i.e. Limit == 1<<L.
int L;
// Raw digit string read by main().
char st[MAXN];
// Number of terms requested; read by main().
int n;
// fac[i] = i! modulo mods and inv[i] = modular inverse of fac[i]; filled by Init().
int fac[MAXN];
int inv[MAXN];
// Bit-reversal permutation for the current Limit; rebuilt by solve() before each transform.
int rev[MAXN];
// Input coefficients: main() stores digit(st[i-1])*inv[i-1] in P[i].
int P[MAXN];
// Coefficients accumulated and finalised by solve().
int F[MAXN];
// Scratch buffers that solve() fills, transforms and multiplies.
int G[MAXN];
int H[MAXN];
int Q[MAXN];
// Modular addition: returns x+y reduced by at most one subtraction of mods,
// so the result is fully reduced only when the sum is below 2*mods.
inline int upd(int x,int y)
{
	int sum=x+y;
	if (sum>=mods) sum-=mods;
	return sum;
}
// Binary exponentiation: returns x raised to the power y, modulo mods.
// The exponent is consumed one bit at a time from the least significant end
// while x is repeatedly squared.
inline int quick_pow(int x,int y)
{
	int acc=1;
	while (y)
	{
		if (y&1) acc=1ll*acc*x%mods;
		x=1ll*x*x%mods;
		y>>=1;
	}
	return acc;
}
// Fills fac[0..n] with factorials modulo mods and inv[0..n] with their modular inverses.
// Only inv[n] is obtained by exponentiation (exponent mods-2); the remaining
// inverses follow downwards from inv[k-1] = inv[k]*k.
inline void Init(int n)
{
	fac[0]=1;
	for (int k=1;k<=n;++k)
		fac[k]=1ll*fac[k-1]*k%mods;

	inv[n]=quick_pow(fac[n],mods-2);
	for (int k=n;k>0;--k)
		inv[k-1]=1ll*inv[k]*k%mods;
}
void Number_Theoretic_Transform(int *A,int opt)
{
	for (int i=0;i<Limit;i++) if (i<rev[i]) swap(A[i],A[rev[i]]);
	for (int i=1;i<Limit;i<<=1)
	{
		int Wn=quick_pow(opt==1?g:gi,(mods-1)/(i<<1));
		for (int j=0;j<Limit;j+=(i<<1))
			for (int k=j,w=1;k<j+i;k++,w=1ll*w*Wn%mods)
			{
				int x=A[k],y=1ll*A[k+i]*w%mods;
				A[k]=upd(x,y),A[k+i]=upd(x,mods-y);
			}
	}
	if (opt==-1)
	{
		int invlim=quick_pow(Limit,mods-2);
		for (int i=0;i<Limit;i++) A[i]=1ll*A[i]*invlim%mods;
	}
}
// Divide-and-conquer over the index range [l,r].
// Finalises F[l..r] from left to right and prints one line per index, in
// increasing index order: the value 2*F[i]*fac[i] modulo mods.
// After the left half [l,mid] is finished, its contribution to the right half
// [mid+1,r] is added with one triple product computed through the transform,
// then the right half is processed recursively.
// Uses the global scratch buffers G, H, Q and overwrites Limit, L and rev.
void solve(int l,int r)
{
	if (l==r)
	{
		// Leaf: F[1] is seeded with (mods+1)/2, the modular inverse of 2.
		// Any other index multiplies the accumulated sum by inv[l]*fac[l-1],
		// which is the modular inverse of l.
		if (l==1) F[l]=(mods+1)>>1;
		else F[l]=1ll*F[l]*inv[l]%mods*fac[l-1]%mods;
		// The product is a long long; convert it to int so it matches "%d".
		printf("%d\n",static_cast<int>(2ll*F[l]*fac[l]%mods));
		return;
	}
	int mid=(l+r)>>1;
	solve(l,mid);
	if (l==1)
	{
		// Left-most segment: both F factors come from the prefix [0,mid].
		Limit=1,L=0;
		int len=mid*2+r; // highest degree of F[0..mid]^2 * P[0..r]
		while (Limit<=len) Limit<<=1,L++;
		for (int i=0;i<Limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
		for (int i=0;i<Limit;i++) H[i]=G[i]=0;
		for (int i=0;i<=r;i++) H[i]=P[i];
		for (int i=0;i<=mid;i++) G[i]=F[i];

		Number_Theoretic_Transform(G,1);
		Number_Theoretic_Transform(H,1);
		for (int i=0;i<Limit;i++) G[i]=1ll*G[i]*G[i]%mods*H[i]%mods;
		Number_Theoretic_Transform(G,-1);
		for (int i=mid+1;i<=r;i++) F[i]=upd(F[i],G[i]);
	}
	else
	{
		// Other segments: one F factor is taken from [l,mid] (stored shifted down
		// by l), the other from the prefix [0,r-l]; the product is added twice.
		Limit=1,L=0;
		int len=(mid-l)+(r-l)*2; // highest degree of the triple product
		while (Limit<=len) Limit<<=1,L++;
		for (int i=0;i<Limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
		for (int i=0;i<Limit;i++) H[i]=G[i]=Q[i]=0;
		for (int i=0;i<=r-l;i++) H[i]=P[i];
		for (int i=0;i<=mid-l;i++) G[i]=F[i+l];
		for (int i=0;i<=r-l;i++) Q[i]=F[i];

		Number_Theoretic_Transform(G,1);
		Number_Theoretic_Transform(Q,1);
		Number_Theoretic_Transform(H,1);
		for (int i=0;i<Limit;i++) G[i]=1ll*G[i]*Q[i]%mods*H[i]%mods;
		Number_Theoretic_Transform(G,-1);
		// G is indexed relative to l because of the shift applied above.
		for (int i=mid+1;i<=r;i++) F[i]=upd(F[i],upd(G[i-l],G[i-l]));
	}
	solve(mid+1,r);
}
int main()
{
	n=read();
	Init(n);
	scanf("%s",st);
	for (int i=1;i<=n;i++) P[i]=(st[i-1]-'0')*inv[i-1];
	solve(1,n);
	return 0;
}
发布了94 篇原创文章 · 获赞 6 · 访问量 8517

猜你喜欢

转载自blog.csdn.net/xmr_pursue_dreams/article/details/104319044