「LibreOJ NOI Round #2」不等关系 (dp+NTT分治)

description

戳我看题目哦

solution

有一道非常相似的题目

一棵树,每条边限制两个端点的大小关系(限制 a [ u ] > a [ v ] a[u]>a[v] a[u]>a[v] a [ u ] < a [ v ] a[u]<a[v] a[u]<a[v]
求有多少种符合要求的排列 a a a满足整棵树的限制。 n < = 5000 n<=5000 n<=5000

考虑如果所有边都是朝一个方向的话很好做
答案就是 n ! n! n!除以每个子树的大小
如果存在反向边的话,暴力枚举断开若干个反向边,剩下的边改为正向,然后计算答案
容斥即可。这样暴力做的复杂度是 O ( 2 n ∗ n ) O(2^n*n) O(2nn)
考虑 d p dp dp f ( i , j , k ) f(i,j,k) f(i,j,k) 表示以 i i i 为根的子树,当前 i i i 所在连通块内有 j j j 个点,总共反向 k k k 条边的方案数
合并两棵子树时,如果边是正向的,那么直接合并;
否则要么断开,要么让 k + 1 k+1 k+1 并且按照正向合并
复杂度 n n n 的若干次方
考虑最后的容斥只需要关注 k k k 的奇偶性,因此第三维完全可以省掉
即合并两棵子树时,如果边是正向则直接合并,否则值就是断开的方案减掉把边正向的方案
因此就是一个简单的树背包,复杂度 O ( n 2 ) O(n^2) O(n2)

此题只是需要将二维 d p dp dp再次优化即可
d p [ i ] dp[i] dp[i]表示前缀 i i i的合法方案数, c n t [ i ] cnt[i] cnt[i]表示前缀 i i i > > >的个数
d p [ i ] i ! = ∑ j = 0 i − 1 [ s [ j ] = ′ > ′ ] ( i − j ) ! ( − 1 ) c n t [ i − 1 ] − c n t [ j ] × d p [ j ] j ! \frac{dp[i]}{i!}=\sum_{j=0}^{i-1}\frac{[s[j]='>']}{(i-j)!}(-1)^{cnt[i-1]-cnt[j]}\times \frac{dp[j]}{j!} i!dp[i]=j=0i1(ij)![s[j]=>](1)cnt[i1]cnt[j]×j!dp[j]
( − 1 ) c n t [ i − 1 ] (-1)^{cnt[i-1]} (1)cnt[i1]提出来,剩余部分用 N T T NTT NTT分治完成
有难度
在这里插入图片描述

code

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define mod 998244353
#define int long long 
#define maxn 400005
int len, inv;
char s[maxn];
int cnt[maxn];
int fac[maxn], ifac[maxn], r[maxn];
int f[maxn], g[maxn], dp[maxn];

int qkpow( int x, int y ) {
    
    
	int ans = 1;
	while( y ) {
    
    
		if( y & 1 ) ans = ans * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return ans;
}

void NTT( int *c, int opt ) {
    
    
	for( int i = 0;i < len;i ++ )
		if( i < r[i] ) swap( c[i], c[r[i]] );
	for( int i = 1;i < len;i <<= 1 ) {
    
    
		int omega = qkpow( opt == 1 ? 3 : mod / 3 + 1, ( mod - 1 ) / ( i << 1 ) );
		for( int j = 0;j < len;j += ( i << 1 ) ) {
    
    
			int w = 1;
			for( int k = 0;k < i;k ++, w = w * omega % mod ) {
    
    
				int x = c[j + k], y = w * c[j + k + i] % mod;
				c[j + k] = ( x + y ) % mod;
				c[j + k + i] = ( x - y + mod ) % mod;
			}
		}
	}
	if( opt == -1 ) {
    
    
		int inv = qkpow( len, mod - 2 );
		for( int i = 0;i < len;i ++ )
			c[i] = c[i] * inv % mod;
	}
}

void solve( int L, int R ) {
    
    
	if( L == R ) {
    
    
		if( ! L ) dp[L] = 1;
		else dp[L] = cnt[L] & 1 ? mod - dp[L] : dp[L];//单独提出来 
		return;
	}
	int mid = ( L + R ) >> 1;
	solve( L, mid );
	len = 1; int l = 0;
	while( len <= R - L + 1 + mid - L ) len <<= 1, l ++;
	for( int i = 0;i < len;i ++ )
		r[i] = ( r[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );
	for( int i = 0;i <= mid - L;i ++ )
		if( s[i + L] == '<' && i + L != 0 ) f[i] = 0;
		else f[i] = cnt[i + L] & 1 ? dp[i + L] : mod - dp[i + L];//注意奇偶转换 
	for( int i = mid - L + 1;i < len;i ++ ) f[i] = 0;
	for( int i = 0;i <= R - L + 1;i ++ ) g[i] = ifac[i];
	for( int i = R - L + 2;i < len;i ++ ) g[i] = 0;
	NTT( f, 1 );
	NTT( g, 1 );
	for( int i = 0;i < len;i ++ ) f[i] = f[i] * g[i] % mod;
	NTT( f, -1 );
	for( int i = mid + 1;i <= R;i ++ ) dp[i] = ( dp[i] + f[i - L] ) % mod;
	solve( mid + 1, R );
}

signed main() {
    
    
	scanf( "%s", s + 1 );	
	int n = strlen( s + 1 );
	s[++ n] = '>';
	fac[0] = 1;
	for( int i = 1;i <= n;i ++ )
		fac[i] = fac[i - 1] * i % mod;
	ifac[n] = qkpow( fac[n], mod - 2 );
	for( int i = n - 1;~ i;i -- )
		ifac[i] = ifac[i + 1] * ( i + 1 ) % mod;
	for( int i = 1;i <= n;i ++ )
		cnt[i] = cnt[i - 1] + ( s[i] == '>' );
	solve( 0, n );
	printf( "%lld\n", dp[n] * fac[n] % mod );
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Emm_Titan/article/details/113818488