题意:
Formally, he defines a sequence \(a_1,a_2,...,a_n\) as ''wavel'' if and only if \(a_1<a_2>a_3<a_4>a_5<a_6\)...
Now given two sequences \(a_1,a_2,...,a_n\) and \(b_1,b_2,...,b_m\), Little Q wants to find two sequences \(f_1,f_2,...,f_k(1≤f_i≤n,f_i<f_{i+1})\) and \(g_1,g_2,...,g_k(1≤g_i≤m,g_i<g_i+1)\), where \(a_{f_i}=b_{g_i}\) always holds and sequence \(a_{f_1},a_{f_2},...,a_{f_k}\) is ''wavel''.\(1<=n,m<=2000\)
\(1<=a_i,b_i<=2000\)题解:
设\(f_{i,j,k}\)
表示仅考虑\(a[1..i]\)与\(b[1..j]\),选择的两个子序列结尾分别是\(a_i\)和\(b_j\),且上升下降状态是\(k\) 时的方案数, 则\(f_{i,j,k}=\sum f_{x,y,1-k}\) ,其中\(x<i,y<j\)具体点说
定义\(f[i][j][0/1]\)为选择的两个子序列结尾分别是\(a_i\)和\(b_j\),当前为下降/上升状态的方案数 则当\(a[i] = b[j]\)的时候有\(f[i][j][0] = \sum f[x][y][1]\),其中\(x < i,y < j 且a[x] < a[i]\)\(f[i][j][1] = \sum f[x][y][0] + 1\),其中\(x < i,y < j 且a[x] > a[i]\) 暴力枚举是O(n^4)的,可以用二维树状数组去优化两维变成\(O(n^{2}log{^2}n)\) 顺序枚举i,保证了第一维递增的,只需要用树状数组去维护第二维的下标和值#include#define LL long long#define P pair using namespace std;const int N = 2e3 + 10;const int mod = 998244353;int read(){ int x = 0; char c = getchar(); while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar(); return x;}int s[2][N][N];int a[N],b[N];int n,m;void add(int &x,int y){ x += y; if(x >= mod) x -= mod;}int lowbit(int x){ return x & (-x);}int sum(int o,int i,int j){ int ans = 0; while(i){ int y = j; while(y){ add(ans,s[o][i][y]); y -= lowbit(y); } i -= lowbit(i); } return ans;}void update(int o,int i,int j,int val){ while(i <= n){ int y = j; while(y <= 2000){ add(s[o][i][y],val); y += lowbit(y); } i += lowbit(i); }}int main(){ int T; T = read(); while(T--){ n = read(),m = read(); for(int k = 0;k < 2;k++) for(int i = 1;i <= n;i++) for(int j = 1;j <= 2000;j++) s[k][i][j] = 0; for(int i = 1;i <= n;i++) a[i] = read(); for(int i = 1;i <= m;i++) b[i] = read(); int ans = 0; for(int i = 1;i <= m;i++){ for(int j = 1;j <= n;j++){ if(b[i] == a[j]){ int tmp1 = sum(1,j-1,a[j]-1),tmp2 = (mod + sum(0,j-1,2000)-sum(0,j-1,a[j]))%mod; update(0,j,a[j],tmp1);/// 0 下降 1 上升 update(1,j,a[j],(tmp2 + 1)%mod); add(ans,tmp1); add(ans,tmp2); add(ans,1); } } } printf("%d\n",ans); } return 0;}
题解的\(O(n^2)\)的做法
用\(s[i][j][0/1]\)表示\(a\)和\(b\)分别在\(1\)~\(i\)和\(1\)~\(j\)的结尾的子序列的方案 那么\(dp[i][j][k] = s[i-1][j-1][1 - k] + k==1?1:0\)\(i,j\)顺序枚举,遇到\(a[i] = b[j]\)的时候,前面可以顺便计算大于和小于它的方案,然后更新即可#include#define LL long long#define P pair using namespace std;const int N = 2e3 + 10;const int mod = 998244353;int read(){ int x = 0; char c = getchar(); while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar(); return x;}int s[2][N][N];int dp[2][N][N];int a[N],b[N];int n,m;void add(int &x,int y){ x += y; if(x >= mod) x -= mod;}int main(){ int T; T = read(); while(T--){ n = read(),m = read(); for(int i = 1;i <= n;i++) a[i] = read(); for(int i = 1;i <= m;i++) b[i] = read(); for(int k = 0;k < 2;k++) for(int i = 1;i <= n;i++) for(int j = 1;j <= m;j++) dp[k][i][j] = s[k][i][j] = 0; int ans = 0; for(int i = 1;i <= n;i++){ int tmp0 = 0,tmp1 = 0;///0 下降 1 上升 for(int j = 1;j <= m;j++){ if(a[i] == b[j]){ add(dp[0][i][j],tmp0); add(dp[1][i][j],(tmp1+1)%mod); add(ans,(dp[0][i][j]+dp[1][i][j])%mod); } else if(a[i] > b[j]){ add(tmp0, s[1][i-1][j]); }else{ add(tmp1,s[0][i-1][j]); } } for(int j = 1;j <= m;j++){ s[0][i][j] = s[0][i-1][j]; s[1][i][j] = s[1][i-1][j]; if(a[i] == b[j]){ add(s[0][i][j],dp[0][i][j]); add(s[1][i][j],dp[1][i][j]); } } } printf("%d\n",ans); } return 0;}