fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  #define finish(x) return cout << x << endl, 0;
  #define ll long long

  // Array bounds.
  // Positions are 1-based (a[1..n]), and the DP reads the sentinel row
  // dp[n + 1][*] (in solve() and in main's k = 1 pass). So N must be at least
  // max_n + 2. The previous value 35001 put dp[n + 1] out of bounds for
  // n = 35000 (presumably the problem's upper bound — confirm).
  const int N = 35002;
  // Maximum k is K - 1: dp is indexed dp[i][1..k].
  const int K = 51;
  // Node pool for the persistent segment tree.
  // The initial build uses 2n - 1 nodes. Each update copies one
  // root-to-leaf path of ceil(log2 n) + 1 nodes (at most 17 for n <= 35002),
  // so 22 * N leaves headroom.
  const int SZ = 22 * N;

  int n, k, x, cn;               // n: length, k: number of segments, x: unused, cn: next free tree node
  int a[N];                      // input sequence, 1-based
  int las[N];                    // last index at which each value was seen (reused by two passes in main)
  int seg[SZ], lc[SZ], rc[SZ];   // per-node subtree sum, left child, right child
  int root[N];                   // root[i]: tree version after processing a[1..i]
  int dp[N][K];                  // dp[i][j]: best total for suffix a[i..n] using j segments; row n + 1 stays 0
  11.  
  12. void build(int l, int r, int pos){
  13. if(l == r){
  14. seg[pos] = 1;
  15. return;
  16. }
  17. int mid = (l + r) / 2;
  18. lc[pos] = cn++;
  19. rc[pos] = cn++;
  20. build(l, mid, lc[pos]);
  21. build(mid + 1, r, rc[pos]);
  22. seg[pos] = seg[lc[pos]] + seg[rc[pos]];
  23. }
  24. void update(int l, int r, int pos, int pos2, int ind, int val){
  25. seg[pos] = seg[pos2];
  26. if(l == r){
  27. seg[pos] = val;
  28. return;
  29. }
  30. int mid = (l + r) / 2;
  31. if(ind <= mid){
  32. rc[pos] = rc[pos2];
  33. lc[pos] = cn++;
  34. update(l, mid, lc[pos], lc[pos2], ind, val);
  35. }
  36. else{
  37. lc[pos] = lc[pos2];
  38. rc[pos] = cn++;
  39. update(mid + 1, r, rc[pos], rc[pos2], ind, val);
  40. }
  41. seg[pos] = seg[lc[pos]] + seg[rc[pos]];
  42. }
  43. int query(int l, int r, int pos, int s, int e){
  44. if(r < s || l > e) return 0;
  45. if(s <= l && r <= e) return seg[pos];
  46. int mid = (l + r) / 2;
  47. return query(l, mid, lc[pos], s, e) + query(mid + 1, r, rc[pos], s, e);
  48. }
  49. int distinct(int l, int r){
  50. return query(1, n, root[r], l, r);
  51. }
  52. void solve(int k, int l, int r, int s, int e){
  53. if(r < l) return;
  54. int mid = (l + r) / 2;
  55. int opt = -1;
  56. for(int i = max(s, mid) ; i <= e ; i++){
  57. int cur = distinct(mid, i) + dp[i + 1][k - 1];
  58. if(cur > dp[mid][k]){
  59. dp[mid][k] = cur;
  60. opt = i;
  61. }
  62. }
  63. solve(k, l, mid - 1, s, opt);
  64. solve(k, mid + 1, r, opt, e);
  65. }
  66. int main(){
  67. scanf("%d%d", &n, &k);
  68. for(int i = 1 ; i <= n ; i++)
  69. scanf("%d", &a[i]);
  70. root[0] = cn++;
  71. build(1, n, root[0]);
  72. for(int i = 1 ; i <= n ; i++){
  73. root[i] = cn++;
  74. if(las[a[i]] == 0) root[i] = root[i - 1];
  75. else update(1, n, root[i], root[i - 1], las[a[i]], 0);
  76. las[a[i]] = i;
  77. }
  78. memset(las, 0, sizeof las);
  79. for(int i = n ; i >= 1 ; i--){
  80. dp[i][1] = dp[i + 1][1];
  81. if(las[a[i]] == 0) dp[i][1]++;
  82. las[a[i]] = i;
  83. }
  84. for(int i = 2 ; i <= k ; i++)
  85. solve(i, 1, n, 1, n);
  86. printf("%d\n", dp[1][k]);
  87. return 0;
  88. }
Runtime error #stdin #stdout 0.03s 49976KB
stdin
Standard input is empty
stdout
Standard output is empty