1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define Mp make_pair
#define pb push_back
#define lc(k) (k << 1)
#define rc(k) (lc(k) | 1)
#define SZ(a) (int(a.size()))
typedef long long ll;
typedef double db;
typedef std::pair<int, int> pii;
typedef std::vector<int> vi;

// Debug output goes to stderr, keeping it separate from the answer printed on stdout.
// A preprocessor directive must begin its own line, otherwise the '#' is a stray token.
#define debug(...) std::fprintf(stderr, __VA_ARGS__)

// 64-bit Mersenne Twister seeded from the system clock, so each run gets a different sequence.
std::mt19937_64 gen(std::chrono::system_clock::now().time_since_epoch().count());

// Returns a uniformly distributed integer in the closed range [l, r] drawn from `gen`.
// Precondition: l <= r (std::uniform_int_distribution is undefined otherwise).
ll get(ll l, ll r) {
    std::uniform_int_distribution<ll> dist(l, r);
    return dist(gen);
}
// Two segment trees over the value range [1, MX], stored 1-indexed (children of k are 2k, 2k+1):
//   mn[] -- point assign / range minimum; a leaf holds the smallest array index inserted
//           with that value, or the sentinel N when none is present.
//   su[] -- point add / range sum; a leaf holds how many inserted elements have that value.
// NOTE(review): values are assumed to lie in [1, MX]; nothing bounds-checks a[i] against that.
const int N = 500100, MX = 100000;
static_assert(N >= 4 * MX, "segment-tree arrays must hold 4 * MX nodes");
int n, lim, mn[N], su[N], a[N];

// Sets the min-tree leaf for value x to v and refreshes the minima on the path to the root.
void mmin(int k, int l, int r, int x, int v) {
    if (l == r) {
        mn[k] = v;
        return;
    }
    const int m = (l + r) >> 1;
    if (x <= m)
        mmin(k << 1, l, m, x, v);
    else
        mmin(k << 1 | 1, m + 1, r, x, v);
    mn[k] = std::min(mn[k << 1], mn[k << 1 | 1]);
}

// Minimum stored over values [L, R]; returns N when the range is empty (L > R) or has no entry.
int qmin(int k, int l, int r, int L, int R) {
    if (L <= l && r <= R) return mn[k];
    const int m = (l + r) >> 1;
    int res = N;
    if (L <= m) res = std::min(res, qmin(k << 1, l, m, L, R));
    if (R > m) res = std::min(res, qmin(k << 1 | 1, m + 1, r, L, R));
    return res;
}

// Adds v to the count-tree leaf for value x and refreshes the sums on the path to the root.
void msum(int k, int l, int r, int x, int v) {
    if (l == r) {
        su[k] += v;
        return;
    }
    const int m = (l + r) >> 1;
    if (x <= m)
        msum(k << 1, l, m, x, v);
    else
        msum(k << 1 | 1, m + 1, r, x, v);
    su[k] = su[k << 1] + su[k << 1 | 1];
}

// Number of inserted elements whose value lies in [L, R] (0 for an empty range).
int qsum(int k, int l, int r, int L, int R) {
    if (L <= l && r <= R) return su[k];
    const int m = (l + r) >> 1;
    int res = 0;
    if (L <= m) res += qsum(k << 1, l, m, L, R);
    if (R > m) res += qsum(k << 1 | 1, m + 1, r, L, R);
    return res;
}

// Scans a[1..n] right to left. At step i the trees contain exactly the indices i+1..n, and:
//   j    = smallest such index whose value lies in [a[i] + 1, min(a[i] + lim, MX)];
//   f[i] = f[j] + #{k > i : a[i] < a[k] <= a[j]}   if such a j exists, else 0.
// Returns the sum of all f[i]. On exit both trees are restored to their empty state.
long long calc() {
    // By induction over the recurrence, f[i] counts a subset of the indices k > i with
    // a[k] > a[i], so each entry is at most n and fits in int.
    std::vector<int> f(n + 1, 0);
    for (int i = n; i >= 1; --i) {
        const int lo = a[i] + 1;
        // Widen before adding: a[i] + lim in int overflows (UB) when lim is near INT_MAX.
        const int hi = static_cast<int>(
            std::min<long long>(static_cast<long long>(a[i]) + lim, MX));
        const int j = qmin(1, 1, MX, lo, hi);
        if (j <= n) {  // j == N (> n) means no index qualified
            f[i] += f[j];
            f[i] += qsum(1, 1, MX, lo, a[j]);
        }
        // Inserting in decreasing i makes each value's min-tree leaf hold its smallest index.
        mmin(1, 1, MX, a[i], i);
        msum(1, 1, MX, a[i], 1);
    }
    // Undo every insertion so the next call (and the next test case) starts from empty trees.
    for (int i = 1; i <= n; ++i) {
        mmin(1, 1, MX, a[i], N);
        msum(1, 1, MX, a[i], -1);
    }
    return std::accumulate(f.begin(), f.end(), 0LL);
}

// Reads one test case (n, lim, then a[1..n]) and prints
//   sum over distinct values occurring c times of c*(c-1)/2 + c,
//   plus calc() on the array, plus calc() on the reversed array.
void solve() {
    std::scanf("%d %d", &n, &lim);
    std::map<int, long long> occurrences;
    for (int i = 1; i <= n; ++i) {
        std::scanf("%d", &a[i]);
        ++occurrences[a[i]];
    }
    long long ans = 0;
    for (const auto& kv : occurrences) ans += kv.second * (kv.second - 1) / 2 + kv.second;
    ans += calc();
    std::reverse(a + 1, a + 1 + n);
    ans += calc();
    std::printf("%lld\n", ans);
}

signed main() {
    // Empty min-tree: every node holds the "not found" sentinel N.
    std::fill(mn, mn + N, N);
    int tests = 0;  // stays 0 (no iterations) if the count cannot be read
    std::scanf("%d", &tests);
    for (int cas = 1; cas <= tests; ++cas) solve();
    std::fprintf(stderr, "time=%.4lfs\n", static_cast<double>(std::clock()) / CLOCKS_PER_SEC);
    return 0;
}
|