题目链接
https://atcoder.jp/contests/agc035/tasks/agc035_e
题解
没想出来最后一步DP宛如智障……
考虑一个数\(x\notin S\)的条件是\(x\)被删除了且在\(x\)最后一次被删除之后不能再对\(x+2\)和\(x-K\)进行删除操作。也就是说\(x+2\)和\(x-K\)的最晚删除时间要比\(x\)晚。那么我们从\(x\)往\(x+2\)和\(x-K\)连边,形成的图如果有环那么这个方案就不合法,否则合法。
如果\(K\)是偶数,显然很好算;如果\(K\)是奇数,有环的充要条件是存在一个环从某点\(a\)出发先后经过\(1\)条\(+K\)边、若干条\(-2\)边、\(1\)条\(+K\)边、若干条\(-2\)边,且经过的\(-2\)边的总数是\(K\). 这也就意味着我们需要让选出的点中不存在这样的一个环。
如果我们把图换一种方式看,把图排成一个每行\(2\)列的形状,让左边的奇数\(x\)和右边的偶数\(x+K\)在一行,从上往下每一层的两个(或一个)数是它上面一层对应的数\(+2\), 那么刚才提到的环就相当于若干条向上边、\(1\)条向右边、若干条向上边,向上边的总数是\(K\),最终从一个向左下的边下来。于是我们的目标就是要让“若干条向上边、\(1\)条向右边、若干条向上边”构成的最长的链长度不超过\((K+1)\).
这个可以从上往下DP: 设\(f[i][j][k]\)表示前\(i\)层,“从该层左边开始经过若干条向上边、\(1\)条向右边、若干条向上边的最长链”长度为\(j\),“从该层右边开始经过若干条向上边的最长链”长度为\(k\). 转移讨论一下两边都不选、选左、选右、左右都选即可。
时间复杂度\(O(n^3)\).
代码
#include<bits/stdc++.h>
#define llong long long
#define pii pair<int,int>
#define riterator reverse_iterator
using namespace std;inline int read()
{int x = 0,f = 1; char ch = getchar();for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}return x*f;
}const int N = 150;
llong P;
int n,m;void updsum(llong &x,llong y) {x = x+y>=P?x+y-P:x+y;}namespace Solve1
{llong f[N+3];void solve(){m>>=1; f[0] = 1ll;for(int i=1; i<=n; i++){for(int j=max(0,i-m-1); j<i; j++){updsum(f[i],f[j]);}}llong ans1 = 0ll,ans2 = 0ll;for(int i=max(0,(n>>1)-m); i<=(n>>1); i++) updsum(ans1,f[i]);for(int i=max(0,(n+1>>1)-m); i<=(n+1>>1); i++) updsum(ans2,f[i]);
// printf("ans1=%lld ans2=%lld\n",ans1,ans2);printf("%lld\n",ans1*ans2%P);}
}namespace Solve2
{llong f[N+3][N+3][N+3];void solve(){f[0][0][0] = 1ll;for(int i=1; i+i-m<=n; i++){for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++){updsum(f[i][0][0],f[i-1][j][k]);}if(i+i<=n){for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++){updsum(f[i][0][k+1],f[i-1][j][k]);}}if(i+i-m>=1){for(int j=0; j<=m; j++) for(int k=0; k<=n; k++){updsum(f[i][j+(j>0)][0],f[i-1][j][k]);}}if(i+i<=n&&i+i-m>=1){for(int j=0; j<=m; j++) for(int k=0; k<=n; k++){int jj = max(j+1,k+2);if(jj<=m+1){updsum(f[i][jj][k+1],f[i-1][j][k]);}}}}llong ans = 0ll;for(int j=0; j<=m+1; j++) for(int k=0; k<=n; k++){updsum(ans,f[(n+m)>>1][j][k]);}printf("%lld\n",ans);}
}int main()
{scanf("%d%d%lld",&n,&m,&P);if(!(m&1)) {Solve1::solve();}else {Solve2::solve();}return 0;
}