简要题意
给定一个数列 \(a\) ,求 \(\sum_\limits{i=1}^{n}\sum_\limits{j=i}^{n}f(i,j)\) 其中 \(f(i,j)\) 表示对 \(a_i\sim a_j\) 依次做如下操作:
数据范围 \(n\in[1,3\cdot 10^5], a_i\in[0,3]\) 。
题解
我们首先对于单一的序列考虑,我们首先可以发现一个性质,所有的 \(0\) 都是单独成序列的,这条性质启示我们可以把所有的 \(0\) 都拉出来单独算贡献,如果 \(a_i=0\) 那么这个 \(0\) 的贡献显然就是包含这个 \(0\) 的段数也就是 \(i(n-i+1)\) 。
那么现在这个序列里面没有 \(0\) 了,我们再手玩几组数据或者写个 \(\text{std}\) 打表可得,最多只会有 \(3\) 个序列存在,分别为 \(\{\cdots\},\{1,1,\cdots\},\{2,2,\cdots\}\) 我们考察什么时候会产生一个新的序列,手玩后发现每次碰到 \(2,1\) 或者 \(1,2\) 就会产生一个新的序列,但是需要注意的是,类似于 \(2,1,2\) 的这种输入只能产生 \(\{1,1,\cdots\}\) 这一种序列,后面的 \(1,2\) 是无效输入。而 \(3\) 的地位是特殊的, \(3\) 可以重置当前的状态,也就是说 \(2,1,2,3,1,2,1\) 这种输入可以产生 \(2\) 种序列。然后我们发现我们只需要关注外层序列大小变化的点就可以算出答案,那么问题就变成了找到 \(2\) 个外层序列大小变化的点,也即第一次出现有效 \(1,\cdots,2\) 和 \(2,\cdots,1\) 的位置,那么我们就可以记录一下每个点后的第一个有效 \(1,2\) 与 \(2,1\) 的位置,就可以算出答案,总时间复杂度为 \(O(N)\) ;
具体处理有较多细节,在写代码的时候要注意一下,但是这题有一个好处是很好写对拍,降低了调代码的难度。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
| #include <bits/stdc++.h> using namespace std; #define fi first #define se second #define Mp make_pair #define pb push_back #define SZ(a) (int(a.size()))
typedef long long ll; typedef double db; typedef pair<int, int> pii; typedef vector<int> vi; #define debug(...) mt19937_64 gen(std::chrono::system_clock::now().time_since_epoch().count()); ll get(ll l, ll r) { uniform_int_distribution<ll> dist(l, r); return dist(gen); }
const int N = 300100; int n, a[N], nxt[N], nxt3[N], nxt12[N], nxt21[N], seg[N], nxt12seg[N], nxt21seg[N]; ll ans; signed main() { scanf("%d", &n); for(int i = 1; i <= n; i++) { scanf("%d", &a[i]); if(a[i] == 0) ans += 1ll * i * (n - i + 1); } nxt[n + 1] = nxt3[n + 1] = n + 1; for(int i = n, nx = n + 1, nx3 = n + 1; i; i--) { if(a[i]) nx = i; if(a[i] == 3) nx3 = i; nxt[i] = nx, nxt3[i] = nx3; } debug(" nxt: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt[i]); debug("\n"); debug(" nxt3: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt3[i]); debug("\n"); nxt12[n + 1] = nxt21[n + 1] = n + 1; for(int i = n; i; i--) if(a[i]) { nxt12[i] = nxt12[nxt[i + 1]]; nxt21[i] = nxt21[nxt[i + 1]]; if(a[i] == 1 && a[nxt[i + 1]] == 2) nxt12[i] = i; if(a[i] == 2 && a[nxt[i + 1]] == 1) nxt21[i] = i; } for(int i = n; i; i--) if(a[i] == 0) { nxt12[i] = nxt12[i + 1]; nxt21[i] = nxt21[i + 1]; } debug(" nxt12: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt12[i]); debug("\n"); debug(" nxt21: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt21[i]); debug("\n"); for(int i = n; i; i--) if(a[i] == 3) { if(nxt12[i] < nxt21[i]) { if(nxt12[i] < nxt3[i + 1]) seg[i] = 12; } else { if(nxt21[i] < nxt3[i + 1]) seg[i] = 21; } } debug(" seg: "); for(int i = 1; i <= n + 1; i++) debug("%d ", seg[i]); debug("\n"); nxt12seg[n + 1] = nxt21seg[n + 1] = n + 1; for(int i = n; i; i--) if(a[i] == 3) { nxt12seg[i] = nxt12seg[nxt3[i + 1]]; nxt21seg[i] = nxt21seg[nxt3[i + 1]]; if(seg[i] == 12) nxt12seg[i] = i; if(seg[i] == 21) nxt21seg[i] = i; } debug("nxt12seg: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt12seg[i]); debug("\n"); debug("nxt21seg: "); for(int i = 1; i <= n + 1; i++) debug("%d ", nxt21seg[i]); debug("\n"); for(int l = 1; l <= n; l++) { int i = l; int fl12 = 0, fl21 = 0; if(min(nxt12[i], nxt21[i]) < nxt3[i]) { if(nxt12[i] < nxt21[i]) fl12 = nxt12[i]; else fl21 = nxt21[i]; } i = nxt3[i]; if(!fl12 && !fl21) { if(nxt12seg[i] < nxt21seg[i]) { i = nxt12seg[i]; i = nxt12[i]; fl12 = i; i = nxt3[i]; } else { i = nxt21seg[i]; i = nxt21[i]; fl21 = i; i = nxt3[i]; } } if(!fl12) { i = nxt12seg[i]; i = nxt12[i]; fl12 = i; i = nxt3[i]; } else if(!fl21) { i = nxt21seg[i]; i = nxt21[i]; fl21 = i; i = nxt3[i]; } if(fl12 < n + 1) fl12 = nxt[fl12 + 1]; if(fl21 < n + 1) fl21 = nxt[fl21 + 1]; int c = fl12, d = fl21; debug("%d: %d %d ", l, fl12, fl21); if(c > d) swap(c, d); debug("1: %d 2: %d 3: %d\n", c - nxt[l], d - c, n - d + 1); ans += 1 * (c - nxt[l]) + 2 * (d - c) + 3 * (n - d + 1); } printf("%lld\n", ans); debug("time=%.4lfs\n", (db)clock()/CLOCKS_PER_SEC); return 0; }
|