CodeForces - 1051D(dp)

题解:

第 i 列的状态有四种:(黑,黑),(黑,白),(白,黑),(白,白),设为0(0,0), 1(0,1), 2(1,0), 3(1,1)。

dp[i][k][j]:i 表示第 i 列,k 表示有 k 种,j 表示第 i 列的状态。

那么我们可以得到:

dp[i][k][0] = dp[i-1][k][0] + dp[i-1][k][1] + dp[i-1][k][2] + dp[i-1][k-1][3]。

同理:dp[i][k][3] = dp[i-1][k][3] + dp[i-1][k][1] + dp[i-1][k][2] + dp[i-1][k-1][0]。

因为第 i 列的状态是纯色的,只要上一列中有和这一列一样的颜色就可以叠加并且 k 不变,如果和上一个完全不一样,则种类会多加一个,所以要取 k-1 。

dp[i][k][1] = dp[i-1][k][1] + dp[i-1][k-2][2] + dp[i-1][k-1][0] + dp[i-1][k-1][3]。

同理:dp[i][k][2] = dp[i-1][k][2] + dp[i-1][k-2][1] + dp[i-1][k-1][0] + dp[i-1][k-1][3]。

因为第 i 列的状态是杂色的,所以要考虑如果上一个和当前状态一样则加上,如果完全不一样则要找 k-2 的状态加上,如果是纯色会多一种,则找 k-1 的状态加上。

#include <algorithm>
#include  <iostream>
#include   <cstdlib>
#include   <cstring>
#include    <cstdio>
#include    <string>
#include    <vector>
#include    <bitset>
#include     <stack>
#include     <cmath>
#include     <deque>
#include     <queue>
#include      <list>
#include       <set>
#include       <map>
#define line printf("---------------------------\n")
#define mem(a, b) memset(a, b, sizeof(a))
#define pi acos(-1)
using namespace std;
typedef long long ll;
const double eps = 1e-9;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
const int maxn = 2000+10;

ll dp[1000+10][2000+10][4];
/**
dp[1][1][0,0] = 1;
dp[1][1][0,1] = 0;
dp[1][1][1,0] = 0;
dp[1][1][1,1] = 1;
dp[1][2][0,0] = 0;
dp[1][2][0,1] = 1;
dp[1][2][1,0] = 1;
dp[1][2][1,1] = 0;
dp[2][1][0,0] = 1;
dp[2][1][0,1] = 0;
dp[2][1][1,0] = 0;
dp[2][1][1,1] = 1;
dp[2][2][0,0] = 1+1+1+0;
dp[2][2][0,1] = 1+1+1+0;
dp[2][2][1,0] = 1+1+1+0;
dp[2][2][1,1] = 1+1+1+0;
dp[i][k][0] = dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][0]+dp[i-1][k-1][3];
dp[i][k][3] = dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][3]+dp[i-1][k-1][0];
dp[i][k][1] = dp[i-1][k][1]+dp[i-1][k-2][2]+dp[i-1][k-1][0]+dp[i-1][k-1][3];
dp[i][k][2] = dp[i-1][k][2]+dp[i-1][k-2][1]+dp[i-1][k-1][0]+dp[i-1][k-1][3];
*/
int main(){
	int n, K;
	while(~scanf("%d %d", &n, &K)){
		mem(dp, 0LL);
		dp[1][1][0] = 1;
		dp[1][1][3] = 1;
		dp[1][2][1] = 1;
		dp[1][2][2] = 1;
		for(int i = 2; i <= n; i++){
			dp[i][1][0] = dp[i-1][1][0];
			dp[i][1][1] = dp[i-1][1][1];
			dp[i][1][2] = dp[i-1][1][2];
			dp[i][1][3] = dp[i-1][1][3];
			for(int k = 2; k <= K; k++){
				dp[i][k][0] = (dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][0]+dp[i-1][k-1][3]) % mod;
				dp[i][k][3] = (dp[i-1][k][1]+dp[i-1][k][2]+dp[i-1][k][3]+dp[i-1][k-1][0]) % mod;
				dp[i][k][1] = (dp[i-1][k][1]+dp[i-1][k-2][2]+dp[i-1][k-1][0]+dp[i-1][k-1][3]) % mod;
				dp[i][k][2] = (dp[i-1][k][2]+dp[i-1][k-2][1]+dp[i-1][k-1][0]+dp[i-1][k-1][3]) % mod;
			}
		}
		printf("%lld\n", (dp[n][K][0]+dp[n][K][1]+dp[n][K][2]+dp[n][K][3])%mod);
	}
}

猜你喜欢

转载自blog.csdn.net/yanhu6955/article/details/82956672