0%

CF1762F.Good Pairs

简要题意

给定一个 \(n\) 个数的序列 \(a\) 和一个数 \(k\) ,定义一个子区间 \([l,r]\) 是好的,当且仅当这个区间内存在一个子序列满足:

  • 包含 \(a_l\)\(a_r\)
  • 任意两个相邻数的差的绝对值不超过 \(k\)

请求出所有合法的子区间数量。多组数据,\(1\le \sum n\le5\times10^5,0\le k\le10^5,1\le a_i\le10^5\)

题解

我们发现最后的走的路径一定是单调的,否则不优,故问题变为了讨论 \(a_l<a_r,a_l=a_r,a_l>a_r\) 这三种情况,期中 \(a_l<a_r\)\(a_l>a_r\) 对称,可以只考虑一种, \(a_l=a_r\) 很好统计。

考虑 \(a_l<a_r\)\(f_i\) 表示 \(a_i\) 开始走上升的路能够到达的点数,从后往前做有转移 \(f_i=f_j+\text{calc}(a_i+1,a_j)\) 其中 \(j\) 为第一个 \(a_j>a_i\) 的位置, \(\text{calc}(a_i+1,a_j)\) 代表 \(a_i\)\(a_n\) 之间在 \([a_i+1,a_j]\) 之间的数的数量,用线段树简单维护即可,时间复杂度 \(O(n\log A)\)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define Mp make_pair
#define pb push_back
#define lc(k) (k << 1)
#define rc(k) (lc(k) | 1)
#define SZ(a) (int(a.size()))

typedef long long ll;
typedef double db;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define debug(...) fprintf(stderr, __VA_ARGS__)
mt19937_64 gen(std::chrono::system_clock::now().time_since_epoch().count());
ll get(ll l, ll r) { uniform_int_distribution<ll> dist(l, r); return dist(gen); }

const int N = 500100, MX = 1e5;
int n, lim, mn[N], su[N], a[N];
void mmin(int k, int l, int r, int x, int v) {
if(l == r) return mn[k] = v, void();
int m = l + r >> 1;
if(x <= m) mmin(lc(k), l, m, x, v);
else mmin(rc(k), m + 1, r, x, v);
mn[k] = min(mn[lc(k)], mn[rc(k)]);
}
int qmin(int k, int l, int r, int L, int R) {
if(L <= l && r <= R) return mn[k];
int m = l + r >> 1, res = N;
if(L <= m) res = min(qmin(lc(k), l, m, L, R), res);
if(R > m) res = min(qmin(rc(k), m + 1, r, L, R), res);
return res;
}
void msum(int k, int l, int r, int x, int v) {
if(l == r) return su[k] += v, void();
int m = l + r >> 1;
if(x <= m) msum(lc(k), l, m, x, v);
else msum(rc(k), m + 1, r, x, v);
su[k] = su[lc(k)] + su[rc(k)];
}
int qsum(int k, int l, int r, int L, int R) {
if(L <= l && r <= R) return su[k];
int m = l + r >> 1, res = 0;
if(L <= m) res += qsum(lc(k), l, m, L, R);
if(R > m) res += qsum(rc(k), m + 1, r, L, R);
return res;
}
ll calc() {
vi f(n + 1, 0);
for(int i = n; i; i--) {
int j = qmin(1, 1, MX, a[i] + 1, min(a[i] + lim, MX));
if(j <= n) {
f[i] += f[j];
f[i] += qsum(1, 1, MX, a[i] + 1, a[j]);
}
mmin(1, 1, MX, a[i], i);
msum(1, 1, MX, a[i], 1);
}
for(int i = 1; i <= n; i++) {
mmin(1, 1, MX, a[i], N);
msum(1, 1, MX, a[i], -1);
}
return accumulate(f.begin(), f.end(), 0ll);
}
void solve() {
scanf("%d %d", &n, &lim); map<int, ll> mp;
for(int i = 1; i <= n; i++) scanf("%d", &a[i]), mp[a[i]]++;
ll ans = 0;
for(auto i : mp) ans += i.se * (i.se - 1) / 2 + i.se;
ans += calc();
reverse(a + 1, a + 1 + n);
ans += calc();
printf("%lld\n", ans);
}
signed main() {
for(int i = 0; i < N; i++) mn[i] = N;
int _; scanf("%d", &_); for(int cas = 1; cas <= _; cas++) solve();
debug("time=%.4lfs\n", (db)clock()/CLOCKS_PER_SEC);
return 0;
}