线段树
区间加,区间求和
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
int n;
vector<int> tag, sum;
Tree(int _n): n(_n), tag((_n + 2) * 4), sum((_n + 2) * 4) {}
Tree(const vector<int> &a): Tree(a.size()) {
function<void(int, int, int)> build = [&](int l, int r, int u) {
if (l == r) {
sum[u] = a[l - 1];
return;
}
int mid = (r - l) / 2 + l;
build(l, mid, ls);
build(mid + 1, r, rs);
sum[u] = sum[ls] + sum[rs];
};
build(1, n, 1);
}
void pushdown(int u, int len) {
tag[ls] += tag[u];
tag[rs] += tag[u];
sum[ls] += tag[u] * ((len + 1) >> 1);
sum[rs] += tag[u] * (len >> 1);
tag[u] = 0;
}
void pushup(int u) {
sum[u] = sum[ls] + sum[rs];
}
void add(int l, int r, int x, int cl, int cr, int u) {
int len = cr - cl + 1;
if (cl >= l && cr <= r) {
tag[u] += x;
sum[u] += len * x;
return;
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((cr - cl) >> 1) + cl;
if (l <= mid) {
add(l, r, x, cl, mid, ls);
}
if (r > mid) {
add(l, r, x, mid + 1, cr, rs);
}
pushup(u);
}
int query(int l, int r, int cl, int cr, int u) {
int len = cr - cl + 1;
if (cl >= l && cr <= r) {
return sum[u];
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((cr - cl) >> 1) + cl, res = 0;
if (l <= mid) {
res += query(l, r, cl, mid, ls);
}
if (r > mid) {
res += query(l, r, mid + 1, cr, rs);
}
return res;
}
void add(int l, int r, int x) {
return add(l, r, x, 1, n, 1);
}
int query(int l, int r) {
return query(l, r, 1, n, 1);
}
#undef ls
#undef rs
};
void solve() {
int n, m;
cin >> n >> m;
vector<int> a(n);
for (auto &x : a) {
cin >> x;
}
Tree tree(a);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, x;
cin >> l >> r >> x;
tree.add(l, r, x);
} else if (op == 2) {
int l, r;
cin >> l >> r;
cout << tree.query(l, r) << endl;
}
}
}
区间加,区间最大(小)值
// 维护最大值
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
private:
int n;
vector<int> tag, ma;
void pushdown(int u, int len) {
tag[ls] += tag[u];
tag[rs] += tag[u];
ma[ls] += tag[u];
ma[rs] += tag[u];
tag[u] = 0;
}
void pushup(int u) {
ma[u] = max(ma[ls], ma[rs]);
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
tag[u] += x;
ma[u] += x;
return;
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l;
if (L <= mid) {
add(l, mid, L, R, x, ls);
}
if (R > mid) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
return ma[u];
}
if (tag[u]) {
pushdown(u, len);
}
// 改成最小值的话别忘了改这里 res = 1e9
int mid = ((r - l) >> 1) + l, res = 0;
if (L <= mid) {
res = max(res, query(l, mid, L, R, ls));
}
if (R > mid) {
res = max(res, query(mid + 1, r, L, R, rs));
}
return res;
}
public:
Tree(int _n): n(_n), tag((_n + 2) * 4), ma((_n + 2) * 4) {}
void add(int l, int r, int x) {
return add(1, n, l, r, x, 1);
}
int query(int l, int r) {
return query(1, n, l, r, 1);
}
#undef ls
#undef rs
};
区间加,区间乘,区间求和
int mod;
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
private:
int n;
vector<int> sum, mu, tag;
void pushup(int u) {
sum[u] = (sum[ls] + sum[rs]) % mod;
}
void pushdown(int l, int r, int u) {
if (mu[u] != 1) {
mu[ls] = mu[ls] * mu[u] % mod;
mu[rs] = mu[rs] * mu[u] % mod;
tag[ls] = tag[ls] * mu[u] % mod;
tag[rs] = tag[rs] * mu[u] % mod;
sum[ls] = sum[ls] * mu[u] % mod;
sum[rs] = sum[rs] * mu[u] % mod;
mu[u] = 1;
}
int mid = ((r - l) >> 1) + l;
if (tag[u]) {
sum[ls] = (sum[ls] + tag[u] * (mid - l + 1)) % mod;
sum[rs] = (sum[rs] + tag[u] * (r - mid)) % mod;
tag[ls] = (tag[ls] + tag[u]) % mod;
tag[rs] = (tag[rs] + tag[u]) % mod;
tag[u] = 0;
}
}
void mul(int l, int r, int L, int R, int x, int u) {
if (l >= L && r <= R) {
mu[u] = mu[u] * x % mod;
tag[u] = tag[u] * x % mod;
sum[u] = sum[u] * x % mod;
return;
}
if (mu[u] != 1 || tag[u]) {
pushdown(l, r, u);
}
int mid = ((r - l) >> 1) + l;
if (mid >= L) {
mul(l, mid, L, R, x, ls);
}
if (mid < R) {
mul(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
sum[u] = (sum[u] + x * len % mod) % mod;
tag[u] = (tag[u] + x) % mod;
return;
}
int mid = ((r - l) >> 1) + l;
pushdown(l, r, u);
if (mid >= L) {
add(l, mid, L, R, x, ls);
}
if (mid < R) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
if (l >= L && r <= R) {
return sum[u];
}
int mid = ((r - l) >> 1) + l, res = 0;
pushdown(l, r, u);
if (mid >= L) {
res = (res + query(l, mid, L, R, ls)) % mod;
}
if (mid < R) {
res = (res + query(mid + 1, r, L, R, rs)) % mod;
}
return res;
}
void build(const vector<int> &a, int l, int r, int u) {
sum[u] = 0, tag[u] = 0, mu[u] = 1;
if (l == r) {
sum[u] = a[l];
return;
}
int mid = ((r - l) >> 1) + l;
build(a, l, mid, ls);
build(a, mid + 1, r, rs);
pushup(u);
}
public:
Tree(int _n): n(_n) {
sum = mu = tag = vector<int>((_n + 2) * 4);
}
void add(int l, int r, int x) {
return add(1, n, l, r, x, 1);
}
void mul(int l, int r, int x) {
return mul(1, n, l, r, x, 1);
}
int query(int l, int r) {
return query(1, n, l, r, 1);
}
void build(const vector<int> &a) {
build(a, 1, n, 1);
}
#undef ls
#undef rs
};
void solve() {
int n, m;
cin >> n >> m >> mod;
Tree tr(n);
vector<int> a(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
tr.build(a);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, x;
cin >> l >> r >> x;
tr.mul(l, r, x);
} else if (op == 2) {
int l, r, x;
cin >> l >> r >> x;
tr.add(l, r, x);
} else if (op == 3) {
int l, r;
cin >> l >> r;
cout << tr.query(l, r) << endl;
}
}
}
区间加,区间求和,动态开点
常用于权值线段树。
普通平衡树
- 插入 数
- 删除 数(若有多个相同的数,应只删除一个)
- 查询 数的排名(排名定义为比当前数小的数的个数 )
- 查询排名为 的数
- 求 的前驱(前驱定义为小于 ,且最大的数)
- 求 的后继(后继定义为大于 ,且最小的数)
// 权值线段树
struct Tree {
#define ls get_ls(u)
#define rs get_rs(u)
private:
// n 是最多元素总数,[mi, ma] 是值域
int n;
int mi, ma;
vector<int> a, sum;
vector<int> lson, rson;
int cnt = 1;
int get_ls(int u) {
if (!lson[u]) {
lson[u] = ++cnt;
}
return lson[u];
}
int get_rs(int u) {
if (!rson[u]) {
rson[u] = ++cnt;
}
return rson[u];
}
void pushdown(int u, int len) {
a[ls] += a[u];
a[rs] += a[u];
sum[ls] += a[u] * ((len + 1) >> 1);
sum[rs] += a[u] * (len >> 1);
a[u] = 0;
}
void pushup(int u) {
sum[u] = sum[ls] + sum[rs];
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
a[u] += x;
sum[u] += len * x;
return;
}
if (a[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l;
if (L <= mid) {
add(l, mid, L, R, x, ls);
}
if (R > mid) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
return sum[u];
}
if (a[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l, res = 0;
if (L <= mid) {
res += query(l, mid, L, R, ls);
}
if (R > mid) {
res += query(mid + 1, r, L, R, rs);
}
return res;
}
public:
Tree(int _n, int _mi, int _ma): n(_n), mi(_mi), ma(_ma) {
a = sum = lson = rson = vector<int>((_n + 2) * 2);
}
void add(int l, int r, int x) {
return add(mi, ma, l, r, x, 1);
}
int query(int l, int r) {
return query(mi, ma, l, r, 1);
}
#undef ls
#undef rs
};
void solve() {
int n;
cin >> n;
int mi = -2e7, ma = 2e7;
Tree tr(n * 50, mi, ma);
// 前驱后继操作用 multiset 维护
multiset<int> b;
auto rank = [&](int x) {
int l = mi, r = ma;
while (l < r) {
int mid = ((r - l + 1) >> 1) + l;
int rk = tr.query(mi, mid - 1) + 1;
if (rk > x) {
r = mid - 1;
} else {
l = mid;
}
}
return *b.lower_bound(l);
};
while (n--) {
int op, x;
cin >> op >> x;
if (op == 1) {
tr.add(x, x, 1);
b.insert(x);
} else if (op == 2) {
tr.add(x, x, -1);
b.erase(b.find(x));
} else if (op == 3) {
cout << tr.query(mi, x - 1) + 1 << endl;
} else if (op == 4) {
cout << rank(x) << endl;
} else if (op == 5) {
cout << *(--b.lower_bound(x)) << endl;
} else if (op == 6) {
cout << *b.lower_bound(x + 1) << endl;
}
}
}
线段树合并
TODO 完善代码
常用于权值线段树,动态开点
int merge(int u, int v, int l, int r) {
if(!u) {
return v;
}
if(!v) {
return u;
}
if(u == v) {
sum[u] += sum[v];
return a;
}
int mid = ((r - l) >> 1) + l;
ls[u] = merge(ls[u], ls[v], l, mid);
rs[u] = merge(rs[u], rs[v], mid + 1, r);
pushup(u);
return u;
}
线段树分裂
TODO 完善代码
只能用于有序的序列,常用于动态开点的权值线段树
void split(int &u, int &v, int l, int r, int L, int R) {
if(l < L || r > R) {
return;
}
if(!u) {
return;
}
if(l >= L && r <= R) {
v = u;
u = 0;
return;
}
if(!q) {
q = newNode();
}
int mid = ((r - l) >> 1) + l;
if(L <= mid) {
split(ls[u], ls[v], l, mid, L, R);
}
if(R > mid) {
split(rs[u], rs[v], mid + 1, r, L, R);
}
pushup(u);
pushup(v);
}