Introduction
lnkkerst's XCPC templates.
起手式
#pragma GCC optimize(2)
#include <algorithm>
#include <array>
#include <bitset>
#include <cmath>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
#define int long long
void solve() {}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int t = 1;
cin >> t;
while (t--) {
solve();
}
}
Vim
vimrc:
" 设置 leader 为空格键
let mapleader = " "
" 设置 UFT-8 编码
set enc=utf-8
set fenc=utf-8
set termencoding=utf-8
" 关闭 vi 兼容
set nocompatible
" 显示相对行号
set number
set relativenumber
" 启用语法高亮
set t_Co=256
syntax on
" 高亮当前行和当前列
"set cursorline
"set cursorcolumn
"highlight CursorLine ctermbg=darkgray guibg=lightgray
"highlight CursorColumn ctermbg=darkgray guibg=lightgray
" 自动缩进
set autoindent
set smartindent
" 设置合适的缩进宽度
set tabstop=2 " 设置 Tab 键宽度为 2
set shiftwidth=2 " 设置缩进宽度为 2
set expandtab " 用空格替代Tab
" 开启行内搜索时的高亮
set hlsearch
" 关闭错误的响铃提示
set noerrorbells
" 搜索时逐字符匹配
set incsearch
set ignorecase " 搜索忽略大小写
set smartcase " 如果包含大写字符,则区分大小写
" 设置颜色主题
set background=dark
colorscheme default
" 显示匹配的括号
set showmatch
" 开启剪切板支持
set clipboard=unnamedplus
" 设置取消回滚时最大操作数
set undolevels=1000
" 鼠标支持
set mouse=a
" 快速保存和退出命令
nnoremap <leader>w :w<CR> " 用 \w 快速保存
nnoremap <leader>q :q<CR> " 用 \q 快速退出
" 复制当前 buffer
nmap <leader>y ggVGy
"nmap <leader>y ggVG"+y''
" {} 括号补全
inoremap {<CR> {<CR>}<ESC>O
" 使用 x 删除时不自动复制
nnoremap x "_x
nnoremap X "_X
vnoremap x "_x
vnoremap X "_X
" 关闭搜索高亮
nnoremap <leader>l :noh<cr>
" 行内移动
nnoremap H ^
nnoremap L g_
vnoremap H ^
vnoremap L g_
常用操作:
" 在下方打开一个终端
:belowright terminal
# bash 指令
# X 下交换 esc 和 capslock,防止队友写代码
setxkbmap -option "caps:swapescape"
快读快写
本文所有代码来自 OI-wiki。
简单版本
int read() {
int x = 0, f = 1;
char ch = 0;
while (ch < '0' || ch > '9') {
if (ch == '-') {
f = -1;
}
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = x * 10 + (ch - '0');
ch = getchar();
}
return x * f;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
fread
, fwrite
版本
namespace IO {
const int MAXSIZE = 1 << 20;
char buf[MAXSIZE], *p1, *p2;
#define gc() \
(p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) \
? EOF \
: *p1++)
int rd() {
int x = 0, f = 1;
char c = gc();
while (!isdigit(c)) {
if (c == '-') {
f = -1;
}
c = gc();
}
while (isdigit(c)) {
x = x * 10 + (c ^ 48), c = gc();
}
return x * f;
}
char pbuf[1 << 20], *pp = pbuf;
void push(const char &c) {
if (pp - pbuf == 1 << 20) {
fwrite(pbuf, 1, 1 << 20, stdout), pp = pbuf;
}
*pp++ = c;
}
void write(int x) {
static int sta[35];
int top = 0;
do {
sta[top++] = x % 10, x /= 10;
} while (x);
while (top) {
push(sta[--top] + '0');
}
}
} // namespace IO
完整带调试版
// #define DEBUG 1 // 调试开关
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) (x >= '0' && x <= '9')
char buf[MAXSIZE], *p1, *p2;
char pbuf[MAXSIZE], *pp;
#if DEBUG
#else
IO(): p1(buf), p2(buf), pp(pbuf) {}
~IO() {
fwrite(pbuf, 1, pp - pbuf, stdout);
}
#endif
char gc() {
#if DEBUG // 调试,可显示字符
return getchar();
#endif
if (p1 == p2) {
p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
}
return p1 == p2 ? ' ' : *p1++;
}
bool blank(char ch) {
return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
}
template <class T>
void read(T &x) {
double tmp = 1;
bool sign = 0;
x = 0;
char ch = gc();
for (; !isdigit(ch); ch = gc()) {
if (ch == '-') {
sign = 1;
}
}
for (; isdigit(ch); ch = gc()) {
x = x * 10 + (ch - '0');
}
if (ch == '.') {
for (ch = gc(); isdigit(ch); ch = gc()) {
tmp /= 10.0, x += tmp * (ch - '0');
}
}
if (sign) {
x = -x;
}
}
void read(char *s) {
char ch = gc();
for (; blank(ch); ch = gc())
;
for (; !blank(ch); ch = gc()) {
*s++ = ch;
}
*s = 0;
}
void read(char &c) {
for (c = gc(); blank(c); c = gc())
;
}
void push(const char &c) {
#if DEBUG // 调试,可显示字符
putchar(c);
#else
if (pp - pbuf == MAXSIZE) {
fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
}
*pp++ = c;
#endif
}
template <class T>
void write(T x) {
if (x < 0) {
x = -x, push('-'); // 负数输出
}
static T sta[35];
T top = 0;
do {
sta[top++] = x % 10, x /= 10;
} while (x);
while (top) {
push(sta[--top] + '0');
}
}
template <class T>
void write(T x, char lastChar) {
write(x), push(lastChar);
}
} io;
黑魔法
创建多维 vector
需要 C++14 及以上
不可初始化值的版本:
template <typename T>
std::vector<T> create_nd_vector(size_t size) {
return std::vector<T>(size);
}
template <typename T, typename... Sizes>
auto create_nd_vector(size_t first, Sizes... sizes) {
return std::vector<decltype(create_nd_vector<T>(sizes...))>(
first, create_nd_vector<T>(sizes...));
}
void solve() {
int n;
cin >> n;
vector<string> a(n + 1);
auto b = create_nd_vector<int>(n + 1, n + 1, n + 1, n + 1);
cout << typeid(b).name() << endl;
}
带初始化值的版本,初始化值必须为第一个参数:
template <typename T>
std::vector<T> create_nd_vector(T value, size_t size) {
return std::vector<T>(size);
}
template <typename T, typename... Sizes>
auto create_nd_vector(T value, size_t first, Sizes... sizes) {
return std::vector<decltype(create_nd_vector<T>(value, sizes...))>(
first, create_nd_vector<T>(value, sizes...));
}
void solve() {
int n;
cin >> n;
vector<string> a(n + 1);
auto b = create_nd_vector<int>(n + 1, n + 1, n + 1, n + 1);
cout << typeid(b).name() << endl;
}
“动态”的 std::bitset
需要 C++14 及以上
来自 CF1856E2, 利用模板展开预定义大小为 的函数。
// bitset 优化可行性 01 背包(重量和价值相等的 01 背包)
template <int len = 1>
int knapsack01(const vector<int> &a, int target) {
if (target >= len) {
return knapsack01<min(len * 2, MAX_BITSET_SIZE)>(a, target);
}
bitset<len> dp;
dp[0] = 1;
for (auto x : a) {
dp |= dp << x;
}
for (int i = target; i >= 0; --i) {
if (dp[i]) {
return i;
}
}
return 0;
}
对拍
假设待验证代码为 1.cpp
,暴力程序为 bl.cpp
,数据生成器为 dm.py
。
bash
无限循环:
while true; do
python dm.py >in
./1 <in >out1
./bl <in >outbl
if ! diff outbl out1; then
break
fi
done
限制次数(ChatGPT 生成的):
#!/bin/bash
num_tests=100 # 设定测试次数
for ((i=1; i<=num_tests; i++)); do
./generator > input.txt
./brute < input.txt > brute_output.txt
./optimized < input.txt > optimized_output.txt
if ! diff brute_output.txt optimized_output.txt > /dev/null; then
echo "Mismatch found on test $i"
echo "Input:"
cat input.txt
echo "Brute output:"
cat brute_output.txt
echo "Optimized output:"
cat optimized_output.txt
exit 1
fi
echo "Test $i passed!"
done
echo "All tests passed!"
fish
while true
python dm.py >in
cat in | ./1 >out1
cat in | ./bl >outbl
diff out1 outbl
if test $status -ne 0
break
end
end
cmd
@echo off
:loop
python dm.py > in
1.exe < in > out1
bl.exe < in > outbl
fc out1 outbl
if not errorlevel 1 goto loop
powershell
ChatGPT 写的,没经过测试。
while ($true) {
# 生成测试数据
./generator.exe > input.txt
# 运行暴力算法
./brute.exe < input.txt > brute_output.txt
# 运行优化算法
./optimized.exe < input.txt > optimized_output.txt
# 比较输出
if (!(Compare-Object (Get-Content brute_output.txt) `
(Get-Content optimized_output.txt))) {
Write-Output "Test passed!"
} else {
Write-Output "Mismatch found!"
Write-Output "Input:"
Get-Content input.txt
Write-Output "Brute output:"
Get-Content brute_output.txt
Write-Output "Optimized output:"
Get-Content optimized_output.txt
break
}
}
Python 相关
输入/输出
对于大规模输入量,input()
可能过于缓慢。使用 sys.stdin
提高效率:
import sys
input = sys.stdin.read
data = input().split()
大规模输出,使用 sys.stdout
:
import sys
sys.stdout.write('\n'.join(map(str, results)) + '\n')
内置数据结构、方法
1. 基础数据结构
list
(动态数组)
- 支持动态扩展、插入、删除、切片操作。
- 常用方法:
append
、extend
、insert
、remove
、pop
、index
、count
、sort
、reverse
arr = [1, 2, 3]
arr.append(4) # [1, 2, 3, 4]
arr.pop() # [1, 2, 3]
arr[1:3] # [2, 3]
tuple
(不可变数组)
- 不可变,支持哈希操作,可作为字典键。
- 常用操作:解包、索引、切片。
tpl = (1, 2, 3)
x, y, z = tpl # 解包
str
(字符串)
- 不可变,支持切片操作。
- 常用方法:
split
、join
、replace
、strip
、find
、startswith
、endswith
s = "hello world"
words = s.split() # ['hello', 'world']
new_s = s.replace(" ", "-") # 'hello-world'
set
(集合)
- 无序、元素唯一。
- 常用方法:
add
、remove
、discard
、union
、intersection
、difference
s = {1, 2, 3}
s.add(4) # {1, 2, 3, 4}
s.remove(2) # {1, 3, 4}
dict
(哈希表/字典)
- 键值对存储,支持快速查找。
- 常用方法:
keys
、values
、items
、get
、pop
、update
d = {"a": 1, "b": 2}
d["c"] = 3
val = d.get("a", 0) # 返回 1,如果键不存在返回默认值 0
2. collections
模块中的数据结构
collections
模块扩展了 Python 的内置数据结构,提供了更多的功能。
deque
(双端队列)
- 双端操作高效,适合队列、栈操作。
- 常用方法:
append
、appendleft
、pop
、popleft
、rotate
、extend
from collections import deque
dq = deque([1, 2, 3])
dq.appendleft(0) # [0, 1, 2, 3]
dq.pop() # [0, 1, 2]
Counter
(计数器)
- 统计元素频率。
- 常用方法:
most_common
、elements
、subtract
from collections import Counter
cnt = Counter("aabbcc")
print(cnt) # {'a': 2, 'b': 2, 'c': 2}
print(cnt.most_common(1)) # [('a', 2)]
defaultdict
(带默认值的字典)
- 为未定义的键提供默认值。
- 常用初始化方法:
int
、list
、set
from collections import defaultdict
d = defaultdict(list)
d["a"].append(1)
print(d) # {'a': [1]}
OrderedDict
(有序字典)
- 记录插入顺序(Python 3.7+ 中
dict
也支持)。
from collections import OrderedDict
od = OrderedDict()
od["a"] = 1
od["b"] = 2
print(od) # {'a': 1, 'b': 2}
namedtuple
(命名元组)
- 类似类的不可变数据结构,字段可以通过名称访问。
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
print(p.x, p.y) # 1, 2
3. heapq
模块中的堆
heapq
(最小堆)
- 实现优先队列。
- 常用方法:
heappush
、heappop
、heapify
、nlargest
、nsmallest
import heapq
heap = []
heapq.heappush(heap, 3)
heapq.heappush(heap, 1)
heapq.heappop(heap) # 返回 1
4. itertools
模块中的迭代器
常用迭代器
product
:笛卡尔积。permutations
:排列。combinations
:组合。combinations_with_replacement
:可重复组合。accumulate
:累积求和。
from itertools import product, permutations, combinations, accumulate
print(list(product([1, 2], repeat=2))) # [(1, 1), (1, 2), (2, 1), (2, 2)]
print(list(accumulate([1, 2, 3]))) # [1, 3, 6]
5. array
模块中的数组
- 提供更高效的数值存储。
from array import array
arr = array('i', [1, 2, 3]) # 'i' 表示整数类型
arr.append(4)
6. queue
模块中的队列
- 提供线程安全的队列实现。
Queue
(FIFO 队列)
from queue import Queue
q = Queue()
q.put(1)
print(q.get()) # 1
LifoQueue
(栈)
from queue import LifoQueue
stack = LifoQueue()
stack.put(1)
stack.put(2)
print(stack.get()) # 2
PriorityQueue
(优先队列)
from queue import PriorityQueue
pq = PriorityQueue()
pq.put((1, "low"))
pq.put((0, "high"))
print(pq.get()) # (0, 'high')
二分查找
from bisect import bisect_left, bisect_right
arr = [1, 2, 4, 4, 5]
pos = bisect_left(arr, 4) # 第一个 4 的索引:2
pos2 = bisect_right(arr, 4) # 第一个大于 4 的索引:4
高精度浮点数
使用内置 decimal
库:
from decimal import Decimal, getcontext
# 初始化
a = Decimal('1.1') # 推荐用字符串,避免二进制误差
b = Decimal('2.2')
支持基本的加减乘除、整除、取模、比较等。
设置精度:
getcontext().prec = 10 # 设置精度为 10 位
舍入:
result = a.quantize(Decimal('0.01')) # 保留两位小数
result = a.quantize(Decimal('1E-3')) # 保留三位小数
数学运算:
result = a.sqrt() # 平方根
result = a.exp() # e^a
result = a.ln() # ln(a)
result = a.copy_abs() # 绝对值
result = a.copy_negate() # 取相反数
分数
使用内置 fractions
库:
from fractions import Fraction
# 创建分数
f1 = Fraction(3, 4) # 分子 3,分母 4,即 3/4
f2 = Fraction('0.75') # 支持字符串初始化
f3 = Fraction(1.5) # 浮点数也可初始化
支持基本的加减乘除、整除、取模、比较等。
分子分母:
f = Fraction(5, 8)
numerator = f.numerator # 分子:5
denominator = f.denominator # 分母:8
f = Fraction(22, 7).limit_denominator(100) # 限制分母不超过 100
归并排序
void solve() {
int n;
cin >> n;
vector<int> a(n);
for (auto &i : a) {
cin >> i;
}
// [l, r), 升序
function<void(int, int)> merge_sort = [&](int l, int r) {
if (r - l <= 1) {
return;
}
int mid = l + ((r - l) >> 1);
merge_sort(l, mid);
merge_sort(mid, r);
// merge
vector<int> na;
int i = l, j = mid;
while (i < mid && j < r) {
if (a[j] < a[i]) { // 先比较 a[j] < a[i],可以保证稳定性
na.emplace_back(a[j]);
++j;
} else {
na.emplace_back(a[i]);
++i;
}
}
while (i < mid) {
na.emplace_back(a[i]);
++i;
}
while (j < r) {
na.emplace_back(a[j]);
++j;
}
// swap
for (int i = l; i < r; ++i) {
a[i] = na[i - l];
}
};
merge_sort(0, a.size());
for (auto i : a) {
cout << i << ' ';
}
cout << endl;
}
堆
二叉堆
void solve() {
int n;
cin >> n;
vector<int> a(n + 10);
int cnt = 0;
function<void(int)> pushup = [&](int u) {
if (u == 1) {
return;
}
int fa = u >> 1;
if (a[u] < a[fa]) {
swap(a[u], a[fa]);
pushup(fa);
}
};
function<void(int)> pushdown = [&](int u) {
int v = u << 1;
if (v > cnt) {
return;
}
if ((v | 1) <= cnt && a[v | 1] < a[v]) {
v |= 1;
}
if (a[v] < a[u]) {
swap(a[v], a[u]);
pushdown(v);
}
};
auto push = [&](int x) {
a[++cnt] = x;
pushup(cnt);
};
auto pop = [&]() {
a[1] = a[cnt--];
pushdown(1);
};
for (int i = 1; i <= n; ++i) {
int q;
cin >> q;
if (q == 1) {
int x;
cin >> x;
push(x);
} else if (q == 2) {
cout << a[1] << endl;
} else {
pop();
}
}
}
对顶堆
多组数据,不断读入整数,读入到 时输出并删除当前序列中位数( 不插入),偶数个数时输出较小的中位数,遇到 结束。
数据范围 。
void solve() {
int x;
priority_queue<int> lq;
priority_queue<int, vector<int>, greater<int>> rq;
// 调整对顶堆
auto adjust = [&](int sz) {
while (lq.size() < sz && !rq.empty()) {
lq.push(rq.top());
rq.pop();
}
while (lq.size() > sz) {
rq.push(lq.top());
lq.pop();
}
};
// 插入新元素
auto push = [&](int x) {
if (lq.empty() || x < lq.top()) {
lq.push(x);
} else {
rq.push(x);
}
int mid = (lq.size() + rq.size() + 1) / 2;
adjust(mid);
};
auto pop = [&]() {
lq.pop();
int mid = (lq.size() + rq.size() + 1) / 2;
adjust(mid);
};
while (scanf("%d", &x) != EOF) {
if (x == 0) {
return;
}
if (x == -1) {
printf("%d\n", lq.top());
pop();
} else {
push(x);
}
}
}
ST 表
luogu 的评测,不用 fastio 容易超时。
这里维护的是最大值。
void solve() {
int n = read(), m = read();
vector<int> a(n);
vector<array<int, 20>> f(n);
for (int i = 0; i < n; ++i) {
f[i][0] = a[i] = read();
}
// init
for (int j = 1; j < 20; ++j) {
for (int i = 0; i < n - (1 << j) + 1; ++i) {
f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
while (m--) {
int l = read(), r = read();
--l, --r;
int l2 = log2(r - l + 1);
write(max(f[l][l2], f[r - (1 << l2) + 1][l2]));
putchar('\n');
}
}
并查集
简单版本
void solve() {
int n, m;
cin >> n >> m;
vector<int> fa(n + 1);
// init
iota(fa.begin(), fa.end(), 0);
// 找父亲
function<int(int)> find = [&](int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
};
// 合并
auto merge = [&](int x, int y) {
fa[find(x)] = find(y);
};
while (m--) {
int q, x, y;
cin >> q >> x >> y;
if (q == 1) {
merge(x, y);
} else {
cout << ((find(x) == find(y)) ? "Y" : "N") << endl;
}
}
}
启发式合并
void solve() {
int n, m;
cin >> n >> m;
vector<int> fa(n + 1), sz(n + 1, 1);
iota(fa.begin(), fa.end(), 0);
// 找父亲
function<int(int)> find = [&](int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
};
// 启发式合并
auto merge = [&](int x, int y) {
int fx = find(x), fy = find(y);
if (fx == fy) {
return;
}
if (sz[fx] < sz[fy]) {
swap(fx, fy);
}
fa[fy] = fx;
sz[x] += sz[y];
};
while (m--) {
int q, x, y;
cin >> q >> x >> y;
if (q == 1) {
merge(x, y);
} else {
cout << ((find(x) == find(y)) ? "Y" : "N") << endl;
}
}
}
删除与移动
删除:将父亲设为自己,为了保证删除的元素都是叶节点,设置副本并初始化父亲为副本。 移动:保重移动的元素都在叶子节点。
实现以下功能:
- 合并两个元素所处集合。
- 移动 到 集合。
- 查询元素所在集合大小和元素和。
void solve() {
int n, m;
if (!(cin >> n >> m)) {
exit(0);
}
vector<int> fa((n + 1) * 2), sz((n + 1) * 2, 1);
vector<int> su((n + 1) * 2); // 添加统计大小
// 初始化
iota(fa.begin(), fa.begin() + n + 1, n + 1);
iota(fa.begin() + n + 1, fa.end(), n + 1);
iota(su.begin() + n + 1, su.end(), 0);
// 找父亲
function<int(int)> find = [&](int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
};
// 启发式合并(按大小)
auto merge = [&](int x, int y) {
int fx = find(x), fy = find(y);
if (fx == fy) {
return;
}
if (sz[fx] < sz[fy]) {
swap(fx, fy);
}
fa[fy] = fx;
sz[fx] += sz[fy];
su[fx] += su[fy];
};
// 删除
auto remove = [&](int x) {
--sz[find(x)];
fa[x] = x;
};
// 移动
auto move = [&](int x, int y) {
auto fx = find(x), fy = find(y);
if (fx == fy) {
return;
}
fa[x] = fy;
--sz[fx], ++sz[fy];
su[fx] -= x, su[fy] += x;
};
while (m--) {
int q;
cin >> q;
if (q == 1) {
int x, y;
cin >> x >> y;
merge(x, y);
} else if (q == 2) {
int x, y;
cin >> x >> y;
move(x, y);
} else if (q == 3) {
int x;
cin >> x;
int fx = find(x);
cout << sz[fx] << ' ' << su[fx] << endl;
}
}
}
单调队列
滑动窗口最值(Luogu-P1886):
void solve() {
int n, k;
cin >> n >> k;
vector<int> a(n);
for (auto &i : a) {
cin >> i;
}
// 滑动窗口最值
auto calc = [&](function<bool(int, int)> cmp) {
deque<int> q;
for (int i = 0; i < k; ++i) {
while (!q.empty() && cmp(a[i], a[q.back()])) {
q.pop_back();
}
q.push_back(i);
}
cout << a[q.front()] << ' ';
for (int i = k; i < n; ++i) {
while (!q.empty() && cmp(a[i], a[q.back()])) {
q.pop_back();
}
q.push_back(i);
while (q.front() <= i - k) {
q.pop_front();
}
cout << a[q.front()] << ' ';
}
cout << endl;
};
calc(less_equal<>()); // 最小值
calc(greater_equal<>()); // 最大值
}
单调栈
求元素右侧第一个大于他的元素的下标(Luogu-P5788)。
void solve() {
int n;
cin >> n;
vector<int> a(n);
for (auto &i : a) {
cin >> i;
}
stack<int> s;
vector<int> ans(n);
for (int i = n - 1; i >= 0; --i) {
while (!s.empty() && a[s.top()] <= a[i]) {
s.pop();
}
ans[i] = s.empty() ? -1 : s.top();
s.push(i);
}
for (auto i : ans) {
cout << i + 1 << ' ';
}
}
树状数组
单点加,区间和。
需要运算满足结合律且可差分,如加法(和)、乘法(积)、异或等。
- 结合律: ,其中 是一个二元运算符。
- 可差分:具有逆运算的运算,即已知 和 可以求出 。
struct Fenwick {
int n;
vector<int> a;
Fenwick(int _n): n(_n), a(_n + 10) {}
Fenwick(const vector<int> &arr) {
n = arr.size();
a = vector<int>(n + 10);
// O(n) 建树
for (int i = 1; i <= n; ++i) {
a[i] += arr[i - 1];
int j = i + lowbit(i);
if (j <= n) {
a[j] += a[i];
}
}
}
static int lowbit(int x) {
return x & -x;
}
// 单点加
void add(int k, int x) {
while (k <= n) {
a[k] += x;
k += lowbit(k);
}
}
// 查询前缀和
int query(int k) {
int res = 0;
while (k > 0) {
res += a[k];
k -= lowbit(k);
}
return res;
}
// 区间查询
int query(int l, int r) {
return query(r) - query(l - 1);
}
};
void solve() {
int n, m;
cin >> n >> m;
vector<int> a(n);
for (auto &x : a) {
cin >> x;
}
Fenwick tree(a);
while (m--) {
int q;
cin >> q;
if (q == 1) {
int pos, x;
cin >> pos >> x;
tree.add(pos, x);
} else if (q == 2) {
int l, r;
cin >> l >> r;
cout << tree.query(l, r) << endl;
}
}
}
区间加,单点查询。
维护差分数组即可。
区间加,区间和。
struct Tree {
private:
vector<int> t1, t2;
int n;
void add_(int k, int x) {
int v1 = k * x;
while (k <= n) {
t1[k] += x, t2[k] += v1;
k += lowbit(k);
}
}
int query_(vector<int> &t, int k) {
int res = 0;
while (k) {
res += t[k];
k -= lowbit(k);
}
return res;
}
public:
Tree(int _n): t1(_n + 2), t2(_n + 2), n(_n) {}
static int lowbit(int x) {
return x & -x;
}
// 区间加
void add(int l, int r, int v) {
add_(l, v);
add_(r + 1, -v);
}
// 求区间和
int query(int l, int r) {
return (r + 1) * query_(t1, r) - l * query_(t1, l - 1)
- (query_(t2, r) - query_(t2, l - 1));
}
};
void solve() {
int n, m;
cin >> n >> m;
Tree tr(n);
for (int i = 1; i <= n; ++i) {
int x;
cin >> x;
tr.add(i, i, x);
}
while (m--) {
int q;
cin >> q;
if (q == 1) {
int l, r, x;
cin >> l >> r >> x;
tr.add(l, r, x);
} else if (q == 2) {
int l, r;
cin >> l >> r;
cout << tr.query(l, r) << endl;
}
}
}
二维,子矩阵加,单点查询
改一改就是单点修改,子矩阵查询了。
struct Tree {
int n, m;
vector<vector<int>> t;
Tree(int _n, int _m): n(_n), m(_m), t(_n + 2, vector<int>(_m + 2, 0)) {}
static int lowbit(int x) {
return x & -x;
}
// 单点加
void add(int x, int y, int v) {
for (int i = x; i <= n; i += lowbit(i)) {
for (int j = y; j <= m; j += lowbit(j)) {
t[i][j] += v;
}
}
}
// 查询前缀和
int query(int x, int y) {
int res = 0;
for (int i = x; i > 0; i -= lowbit(i)) {
for (int j = y; j > 0; j -= lowbit(j)) {
res += t[i][j];
}
}
return res;
}
};
void solve() {
int n, m;
cin >> n >> m;
Tree tree(n, m);
// 区间加,维护差分数组时使用
auto addRange = [&](int x1, int y1, int x2, int y2, int v) {
tree.add(x1, y1, v);
tree.add(x1, y2 + 1, -v);
tree.add(x2 + 1, y2 + 1, v);
tree.add(x2 + 1, y1, -v);
};
// 区间查询,维护原数组时使用
auto queryRange = [&](int x1, int y1, int x2, int y2) {
return tree.query(x2, y2) - tree.query(x2, y1 - 1) - tree.query(x1 - 1, y2)
+ tree.query(x1 - 1, y1 - 1);
};
int op;
while (cin >> op) {
if (op == 1) {
int x, y, k;
cin >> x >> y >> k;
tree.add(x, y, k);
} else if (op == 2) {
int x1, y1, x2, y2;
cin >> x1 >> y1 >> x2 >> y2;
cout << queryRange(x1, y1, x2, y2) << endl;
}
}
}
二维,子矩阵加,子矩阵查询
struct Tree {
private:
vector<vector<int>> t1, t2, t3, t4;
int n, m;
static int lowbit(int x) {
return x & -x;
}
// 单点加
void add(int x, int y, int v) {
for (int i = x; i <= n; i += lowbit(i)) {
for (int j = y; j <= m; j += lowbit(j)) {
t1[i][j] += v;
t2[i][j] += v * x;
t3[i][j] += v * y;
t4[i][j] += v * x * y;
}
}
}
// 查询前缀和
int query(int x, int y) {
int res = 0;
for (int i = x; i > 0; i -= lowbit(i)) {
for (int j = y; j > 0; j -= lowbit(j)) {
res += (x + 1) * (y + 1) * t1[i][j] - (y + 1) * t2[i][j]
- (x + 1) * t3[i][j] + t4[i][j];
}
}
return res;
}
public:
Tree(int _n, int _m): n(_n), m(_m) {
t1 = t2 = t3 = t4 = vector<vector<int>>(_n + 2, vector<int>(_m + 2));
}
// 子矩阵加
void addRange(int x1, int y1, int x2, int y2, int v) {
add(x1, y1, v);
add(x1, y2 + 1, -v);
add(x2 + 1, y1, -v);
add(x2 + 1, y2 + 1, v);
}
// 子矩阵查询
int queryRange(int x1, int y1, int x2, int y2) {
return query(x2, y2) - query(x2, y1 - 1) - query(x1 - 1, y2)
+ query(x1 - 1, y1 - 1);
}
};
void solve() {
int n, m;
cin >> n >> m;
Tree tree(n, m);
int op;
while (cin >> op) {
if (op == 1) {
int x1, y1, x2, y2, k;
cin >> x1 >> y1 >> x2 >> y2 >> k;
tree.addRange(x1, y1, x2, y2, k);
} else if (op == 2) {
int x1, y1, x2, y2;
cin >> x1 >> y1 >> x2 >> y2;
cout << tree.queryRange(x1, y1, x2, y2) << endl;
}
}
}
权值树状数组
TODO 完善代码
维护 为 在 中出现的次数,对 建树状数组。
单点修改,求 kth:
// 权值树状数组查询第 k 小
int kth(int k) {
int sum = 0, x = 0;
for (int i = log2(n); ~i; --i) {
x += 1 << i; // 尝试扩展
if (x >= n || sum + t[x] >= k) // 如果扩展失败
x -= 1 << i;
else
sum += t[x];
}
return x + 1;
}
线段树
区间加,区间求和
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
int n;
vector<int> tag, sum;
Tree(int _n): n(_n), tag((_n + 2) * 4), sum((_n + 2) * 4) {}
Tree(const vector<int> &a): Tree(a.size()) {
function<void(int, int, int)> build = [&](int l, int r, int u) {
if (l == r) {
sum[u] = a[l - 1];
return;
}
int mid = (r - l) / 2 + l;
build(l, mid, ls);
build(mid + 1, r, rs);
sum[u] = sum[ls] + sum[rs];
};
build(1, n, 1);
}
void pushdown(int u, int len) {
tag[ls] += tag[u];
tag[rs] += tag[u];
sum[ls] += tag[u] * ((len + 1) >> 1);
sum[rs] += tag[u] * (len >> 1);
tag[u] = 0;
}
void pushup(int u) {
sum[u] = sum[ls] + sum[rs];
}
void add(int l, int r, int x, int cl, int cr, int u) {
int len = cr - cl + 1;
if (cl >= l && cr <= r) {
tag[u] += x;
sum[u] += len * x;
return;
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((cr - cl) >> 1) + cl;
if (l <= mid) {
add(l, r, x, cl, mid, ls);
}
if (r > mid) {
add(l, r, x, mid + 1, cr, rs);
}
pushup(u);
}
int query(int l, int r, int cl, int cr, int u) {
int len = cr - cl + 1;
if (cl >= l && cr <= r) {
return sum[u];
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((cr - cl) >> 1) + cl, res = 0;
if (l <= mid) {
res += query(l, r, cl, mid, ls);
}
if (r > mid) {
res += query(l, r, mid + 1, cr, rs);
}
return res;
}
void add(int l, int r, int x) {
return add(l, r, x, 1, n, 1);
}
int query(int l, int r) {
return query(l, r, 1, n, 1);
}
#undef ls
#undef rs
};
void solve() {
int n, m;
cin >> n >> m;
vector<int> a(n);
for (auto &x : a) {
cin >> x;
}
Tree tree(a);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, x;
cin >> l >> r >> x;
tree.add(l, r, x);
} else if (op == 2) {
int l, r;
cin >> l >> r;
cout << tree.query(l, r) << endl;
}
}
}
区间加,区间最大(小)值
// 维护最大值
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
private:
int n;
vector<int> tag, ma;
void pushdown(int u, int len) {
tag[ls] += tag[u];
tag[rs] += tag[u];
ma[ls] += tag[u];
ma[rs] += tag[u];
tag[u] = 0;
}
void pushup(int u) {
ma[u] = max(ma[ls], ma[rs]);
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
tag[u] += x;
ma[u] += x;
return;
}
if (tag[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l;
if (L <= mid) {
add(l, mid, L, R, x, ls);
}
if (R > mid) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
return ma[u];
}
if (tag[u]) {
pushdown(u, len);
}
// 改成最小值的话别忘了改这里 res = 1e9
int mid = ((r - l) >> 1) + l, res = 0;
if (L <= mid) {
res = max(res, query(l, mid, L, R, ls));
}
if (R > mid) {
res = max(res, query(mid + 1, r, L, R, rs));
}
return res;
}
public:
Tree(int _n): n(_n), tag((_n + 2) * 4), ma((_n + 2) * 4) {}
void add(int l, int r, int x) {
return add(1, n, l, r, x, 1);
}
int query(int l, int r) {
return query(1, n, l, r, 1);
}
#undef ls
#undef rs
};
区间加,区间乘,区间求和
int mod;
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
private:
int n;
vector<int> sum, mu, tag;
void pushup(int u) {
sum[u] = (sum[ls] + sum[rs]) % mod;
}
void pushdown(int l, int r, int u) {
if (mu[u] != 1) {
mu[ls] = mu[ls] * mu[u] % mod;
mu[rs] = mu[rs] * mu[u] % mod;
tag[ls] = tag[ls] * mu[u] % mod;
tag[rs] = tag[rs] * mu[u] % mod;
sum[ls] = sum[ls] * mu[u] % mod;
sum[rs] = sum[rs] * mu[u] % mod;
mu[u] = 1;
}
int mid = ((r - l) >> 1) + l;
if (tag[u]) {
sum[ls] = (sum[ls] + tag[u] * (mid - l + 1)) % mod;
sum[rs] = (sum[rs] + tag[u] * (r - mid)) % mod;
tag[ls] = (tag[ls] + tag[u]) % mod;
tag[rs] = (tag[rs] + tag[u]) % mod;
tag[u] = 0;
}
}
void mul(int l, int r, int L, int R, int x, int u) {
if (l >= L && r <= R) {
mu[u] = mu[u] * x % mod;
tag[u] = tag[u] * x % mod;
sum[u] = sum[u] * x % mod;
return;
}
if (mu[u] != 1 || tag[u]) {
pushdown(l, r, u);
}
int mid = ((r - l) >> 1) + l;
if (mid >= L) {
mul(l, mid, L, R, x, ls);
}
if (mid < R) {
mul(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
sum[u] = (sum[u] + x * len % mod) % mod;
tag[u] = (tag[u] + x) % mod;
return;
}
int mid = ((r - l) >> 1) + l;
pushdown(l, r, u);
if (mid >= L) {
add(l, mid, L, R, x, ls);
}
if (mid < R) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
if (l >= L && r <= R) {
return sum[u];
}
int mid = ((r - l) >> 1) + l, res = 0;
pushdown(l, r, u);
if (mid >= L) {
res = (res + query(l, mid, L, R, ls)) % mod;
}
if (mid < R) {
res = (res + query(mid + 1, r, L, R, rs)) % mod;
}
return res;
}
void build(const vector<int> &a, int l, int r, int u) {
sum[u] = 0, tag[u] = 0, mu[u] = 1;
if (l == r) {
sum[u] = a[l];
return;
}
int mid = ((r - l) >> 1) + l;
build(a, l, mid, ls);
build(a, mid + 1, r, rs);
pushup(u);
}
public:
Tree(int _n): n(_n) {
sum = mu = tag = vector<int>((_n + 2) * 4);
}
void add(int l, int r, int x) {
return add(1, n, l, r, x, 1);
}
void mul(int l, int r, int x) {
return mul(1, n, l, r, x, 1);
}
int query(int l, int r) {
return query(1, n, l, r, 1);
}
void build(const vector<int> &a) {
build(a, 1, n, 1);
}
#undef ls
#undef rs
};
void solve() {
int n, m;
cin >> n >> m >> mod;
Tree tr(n);
vector<int> a(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
tr.build(a);
while (m--) {
int op;
cin >> op;
if (op == 1) {
int l, r, x;
cin >> l >> r >> x;
tr.mul(l, r, x);
} else if (op == 2) {
int l, r, x;
cin >> l >> r >> x;
tr.add(l, r, x);
} else if (op == 3) {
int l, r;
cin >> l >> r;
cout << tr.query(l, r) << endl;
}
}
}
区间加,区间求和,动态开点
常用于权值线段树。
普通平衡树
- 插入 数
- 删除 数(若有多个相同的数,应只删除一个)
- 查询 数的排名(排名定义为比当前数小的数的个数 )
- 查询排名为 的数
- 求 的前驱(前驱定义为小于 ,且最大的数)
- 求 的后继(后继定义为大于 ,且最小的数)
// 权值线段树
struct Tree {
#define ls get_ls(u)
#define rs get_rs(u)
private:
// n 是最多元素总数,[mi, ma] 是值域
int n;
int mi, ma;
vector<int> a, sum;
vector<int> lson, rson;
int cnt = 1;
int get_ls(int u) {
if (!lson[u]) {
lson[u] = ++cnt;
}
return lson[u];
}
int get_rs(int u) {
if (!rson[u]) {
rson[u] = ++cnt;
}
return rson[u];
}
void pushdown(int u, int len) {
a[ls] += a[u];
a[rs] += a[u];
sum[ls] += a[u] * ((len + 1) >> 1);
sum[rs] += a[u] * (len >> 1);
a[u] = 0;
}
void pushup(int u) {
sum[u] = sum[ls] + sum[rs];
}
void add(int l, int r, int L, int R, int x, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
a[u] += x;
sum[u] += len * x;
return;
}
if (a[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l;
if (L <= mid) {
add(l, mid, L, R, x, ls);
}
if (R > mid) {
add(mid + 1, r, L, R, x, rs);
}
pushup(u);
}
int query(int l, int r, int L, int R, int u) {
int len = r - l + 1;
if (l >= L && r <= R) {
return sum[u];
}
if (a[u]) {
pushdown(u, len);
}
int mid = ((r - l) >> 1) + l, res = 0;
if (L <= mid) {
res += query(l, mid, L, R, ls);
}
if (R > mid) {
res += query(mid + 1, r, L, R, rs);
}
return res;
}
public:
Tree(int _n, int _mi, int _ma): n(_n), mi(_mi), ma(_ma) {
a = sum = lson = rson = vector<int>((_n + 2) * 2);
}
void add(int l, int r, int x) {
return add(mi, ma, l, r, x, 1);
}
int query(int l, int r) {
return query(mi, ma, l, r, 1);
}
#undef ls
#undef rs
};
void solve() {
int n;
cin >> n;
int mi = -2e7, ma = 2e7;
Tree tr(n * 50, mi, ma);
// 前驱后继操作用 multiset 维护
multiset<int> b;
auto rank = [&](int x) {
int l = mi, r = ma;
while (l < r) {
int mid = ((r - l + 1) >> 1) + l;
int rk = tr.query(mi, mid - 1) + 1;
if (rk > x) {
r = mid - 1;
} else {
l = mid;
}
}
return *b.lower_bound(l);
};
while (n--) {
int op, x;
cin >> op >> x;
if (op == 1) {
tr.add(x, x, 1);
b.insert(x);
} else if (op == 2) {
tr.add(x, x, -1);
b.erase(b.find(x));
} else if (op == 3) {
cout << tr.query(mi, x - 1) + 1 << endl;
} else if (op == 4) {
cout << rank(x) << endl;
} else if (op == 5) {
cout << *(--b.lower_bound(x)) << endl;
} else if (op == 6) {
cout << *b.lower_bound(x + 1) << endl;
}
}
}
线段树合并
TODO 完善代码
常用于权值线段树,动态开点
int merge(int u, int v, int l, int r) {
if(!u) {
return v;
}
if(!v) {
return u;
}
if(u == v) {
sum[u] += sum[v];
return a;
}
int mid = ((r - l) >> 1) + l;
ls[u] = merge(ls[u], ls[v], l, mid);
rs[u] = merge(rs[u], rs[v], mid + 1, r);
pushup(u);
return u;
}
线段树分裂
TODO 完善代码
只能用于有序的序列,常用于动态开点的权值线段树
void split(int &u, int &v, int l, int r, int L, int R) {
if(l < L || r > R) {
return;
}
if(!u) {
return;
}
if(l >= L && r <= R) {
v = u;
u = 0;
return;
}
if(!q) {
q = newNode();
}
int mid = ((r - l) >> 1) + l;
if(L <= mid) {
split(ls[u], ls[v], l, mid, L, R);
}
if(R > mid) {
split(rs[u], rs[v], mid + 1, r, L, R);
}
pushup(u);
pushup(v);
}
平衡树
gnu pbds 中的 tree:
普通平衡树
- 插入 数
- 删除 数(若有多个相同的数,应只删除一个)
- 查询 数的排名(排名定义为比当前数小的数的个数 )
- 查询排名为 的数
- 求 的前驱(前驱定义为小于 ,且最大的数)
- 求 的后继(后继定义为大于 ,且最小的数)
__gnu_pbds::tree
不支持可重复集合,需要自己手动处理。
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <iostream>
using namespace std;
using namespace __gnu_pbds;
#define int long long
tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>
a;
signed main() {
int n;
cin >> n;
for (int i = 1; i <= n; ++i) {
int op, x;
cin >> op >> x;
if (op == 1) {
a.insert((x << 20) + i);
} else if (op == 2) {
a.erase(a.lower_bound(x << 20));
} else if (op == 3) {
cout << a.order_of_key(x << 20) + 1 << endl;
} else if (op == 4) {
cout << (*a.find_by_order(x - 1) >> 20) << endl;
} else if (op == 5) {
cout << (*(--a.lower_bound(x << 20)) >> 20) << endl;
} else if (op == 6) {
cout << (*a.lower_bound((x << 20) + n) >> 20) << endl;
}
}
return 0;
}
std::rope
实现区间翻转:
#pragma GCC optimize(2)
#include <ext/rope>
#include <iostream>
using namespace std;
using namespace __gnu_cxx;
void solve() {
rope<int> s, rs;
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
s.push_back(i);
rs.push_back(n - i + 1);
}
while (m--) {
int l, r;
cin >> l >> r;
int rl = n - r + 1;
int rr = n - l + 1;
--l, --r;
--rl, --rr;
auto tmp = s.substr(l, r - l + 1);
auto rtmp = rs.substr(rl, rr - rl + 1);
s = s.substr(0, l) + rtmp + s.substr(r + 1, n - r);
rs = rs.substr(0, rl) + tmp + rs.substr(rr + 1, n - rr);
}
for (auto i : s) {
cout << i << ' ';
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int t = 1;
while (t--) {
solve();
}
}
Treap
有旋
struct Treap {
vector<int> ls, rs, a, p, sz, cnt;
int tot = 0, rt = 0;
int notFound;
Treap(int _n, int _notFound = (1u << 31) - 1): notFound(_notFound) {
ls = rs = a = p = sz = cnt = vector<int>(_n + 10);
}
void pushup(int u) {
sz[u] = sz[ls[u]] + sz[rs[u]] + cnt[u];
}
void rotateL(int &u) {
int t = rs[u];
rs[u] = ls[t];
ls[t] = u;
sz[t] = sz[u];
pushup(u);
u = t;
}
void rotateR(int &u) {
int t = ls[u];
ls[u] = rs[t];
rs[t] = u;
sz[t] = sz[u];
pushup(u);
u = t;
}
int newNode(int x) {
int u = ++tot;
sz[u] = 1;
cnt[u] = 1;
a[u] = x;
p[u] = rand();
return u;
}
void insert(int &u, int x) {
if (!u) {
u = newNode(x);
return;
}
++sz[u];
if (a[u] == x) {
++cnt[u];
} else if (a[u] < x) {
insert(rs[u], x);
if (p[rs[u]] < p[u]) {
rotateL(u);
}
} else {
insert(ls[u], x);
if (p[ls[u]] < p[u]) {
rotateR(u);
}
}
}
void insert(int x) {
return insert(rt, x);
}
bool remove(int &u, int x) {
if (!u) {
return false;
}
if (a[u] == x) {
if (cnt[u] > 1) {
--cnt[u];
--sz[u];
return true;
}
if (ls[u] == 0 || rs[u] == 0) {
u = ls[u] + rs[u];
return true;
} else if (p[ls[u]] < p[rs[u]]) {
rotateR(u);
return remove(u, x);
} else {
rotateL(u);
return remove(u, x);
}
} else if (a[u] < x) {
bool res = remove(rs[u], x);
if (res) {
--sz[u];
}
return res;
} else {
bool res = remove(ls[u], x);
if (res) {
--sz[u];
}
return res;
}
}
bool remove(int x) {
return remove(rt, x);
}
int queryRank(int u, int x) {
if (!u) {
return 1;
}
if (a[u] == x) {
return sz[ls[u]] + 1;
} else if (a[u] < x) {
return sz[ls[u]] + cnt[u] + queryRank(rs[u], x);
} else {
return queryRank(ls[u], x);
}
}
int queryRank(int x) {
return queryRank(rt, x);
}
int queryByRank(int u, int x) {
if (!u) {
return 0;
}
if (x <= sz[ls[u]]) {
return queryByRank(ls[u], x);
} else if (x > sz[ls[u]] + cnt[u]) {
return queryByRank(rs[u], x - sz[ls[u]] - cnt[u]);
} else {
return a[u];
}
}
auto queryByRank(int x) {
return queryByRank(rt, x);
}
int queryPrev(int u, int x) {
int res = notFound;
while (u) {
if (a[u] < x) {
res = a[u];
u = rs[u];
} else {
u = ls[u];
}
}
return res;
}
auto queryPrev(int x) {
return queryPrev(rt, x);
}
int queryNext(int u, int x) {
int res = notFound;
while (u) {
if (a[u] > x) {
res = a[u];
u = ls[u];
} else {
u = rs[u];
}
}
return res;
}
auto queryNext(int x) {
return queryNext(rt, x);
}
};
void solve() {
int n;
cin >> n;
Treap tr(n);
while (n--) {
int op, x;
cin >> op >> x;
if (op == 1) {
tr.insert(x);
} else if (op == 2) {
tr.remove(x);
} else if (op == 3) {
cout << tr.queryRank(x) << endl;
} else if (op == 4) {
cout << tr.queryByRank(x) << endl;
} else if (op == 5) {
cout << tr.queryPrev(x) << endl;
} else if (op == 6) {
cout << tr.queryNext(x) << endl;
}
}
}
无旋
区间翻转
struct Treap {
vector<int> p, sz, a, tag, ls, rs;
int tot = 0, rt = 0;
Treap(int _n) {
p = sz = a = tag = ls = rs = vector<int>(_n + 2);
}
void pushup(int u) {
sz[u] = sz[ls[u]] + sz[rs[u]] + 1;
}
int newNode(int x) {
a[++tot] = x, sz[tot] = 1, p[tot] = rand();
return tot;
}
void pushdown(int u) {
swap(ls[u], rs[u]);
if (ls[u]) {
tag[ls[u]] ^= 1;
}
if (rs[u]) {
tag[rs[u]] ^= 1;
}
tag[u] = 0;
}
int merge(int u, int v) {
if (!u || !v) {
return u + v;
}
if (p[u] < p[v]) {
if (tag[u]) {
pushdown(u);
}
rs[u] = merge(rs[u], v);
pushup(u);
return u;
}
if (tag[v]) {
pushdown(v);
}
ls[v] = merge(u, ls[v]);
pushup(v);
return v;
}
pair<int, int> split(int u, int x) {
if (!u) {
return {0, 0};
}
if (tag[u]) {
pushdown(u);
}
pair<int, int> res;
if (sz[ls[u]] < x) {
auto tmp = split(rs[u], x - sz[ls[u]] - 1);
rs[u] = tmp.first;
res = {u, tmp.second};
} else {
auto tmp = split(ls[u], x);
ls[u] = tmp.second;
res = {tmp.first, u};
}
pushup(u);
return res;
}
void dfs(int u) {
if (!u) {
return;
}
if (tag[u]) {
pushdown(u);
}
dfs(ls[u]);
cout << a[u] << ' ';
dfs(rs[u]);
}
};
void solve() {
int n, m;
cin >> n >> m;
Treap tr(n);
for (int i = 1; i <= n; ++i) {
tr.rt = tr.merge(tr.rt, tr.newNode(i));
}
for (int i = 1; i <= m; ++i) {
int l, r;
cin >> l >> r;
auto t1 = tr.split(tr.rt, l - 1);
auto t2 = tr.split(t1.second, r - l + 1);
tr.tag[t2.first] ^= 1;
tr.rt = tr.merge(t1.first, tr.merge(t2.first, t2.second));
}
tr.dfs(tr.rt);
}
Splay
区间翻转
struct Splay {
vector<int> fa, ls, rs, sz, rev;
int rt;
Splay(int _n) {
fa = ls = rs = sz = rev = vector<int>(_n + 10);
}
void pushup(int u) {
sz[u] = sz[ls[u]] + sz[rs[u]] + 1;
}
void pushdown(int u) {
if (rev[u]) {
swap(ls[u], rs[u]);
rev[ls[u]] ^= 1;
rev[rs[u]] ^= 1;
rev[u] = 0;
}
}
void rotate(int u, int &v) {
int y = fa[u], z = fa[y];
int ca = ls[y] == u ? 1 : 0;
if (y == v) {
v = u;
} else {
if (ls[z] == y) {
ls[z] = u;
} else {
rs[z] = u;
}
}
if (ca == 0) {
rs[y] = ls[u];
fa[rs[y]] = y;
ls[u] = y;
fa[y] = u;
fa[u] = z;
} else {
ls[y] = rs[u];
fa[ls[y]] = y;
rs[u] = y;
fa[y] = u;
fa[u] = z;
}
pushup(u);
pushup(y);
}
void splay(int u, int &v) {
while (u != v) {
int y = fa[u], z = fa[y];
if (y != v) {
if ((ls[y] == u) ^ (ls[z] == y)) {
rotate(u, v);
} else {
rotate(y, v);
}
}
rotate(u, v);
}
}
void build(int l, int r, int u) {
if (l > r) {
return;
}
int mid = (l + r) / 2;
if (mid < u) {
ls[u] = mid;
} else {
rs[u] = mid;
}
fa[mid] = u;
sz[mid] = 1;
if (l == r) {
return;
}
build(l, mid - 1, mid);
build(mid + 1, r, mid);
pushup(mid);
}
auto build(int l, int r) {
return build(l, r, rt);
}
int query(int u, int x) {
pushdown(u);
int s = sz[ls[u]];
if (x == s + 1) {
return u;
}
if (x <= s) {
return query(ls[u], x);
} else {
return query(rs[u], x - s - 1);
}
}
auto query(int x) {
return query(rt, x);
}
void reverse(int l, int r) {
int x = query(rt, l), y = query(rt, r + 2);
splay(x, rt);
splay(y, rs[x]);
int z = ls[y];
rev[z] ^= 1;
}
};
void solve() {
int n, m;
cin >> n >> m;
Splay tr(n);
tr.rt = (n + 3) / 2;
tr.build(1, n + 2);
for (int i = 1; i <= m; ++i) {
int l, r;
cin >> l >> r;
tr.reverse(l, r);
}
for (int i = 2; i <= n + 1; ++i) {
cout << tr.query(i) - 1 << ' ';
}
}
TODO 普通平衡树
主席树
静态区间第 小
struct Tree {
int cnt;
vector<int> su, ls, rs;
Tree(int n): cnt(0) {
n = (n + 1) << 5;
su = ls = rs = vector<int>(n);
};
int build(int l, int r) {
++cnt;
su[cnt] = 0;
int mid = (l + r) >> 1;
if (l < r) {
ls[cnt] = build(l, mid), rs[cnt] = build(mid + 1, r);
}
return cnt;
}
int update(int pre, int l, int r, int x) {
int rt = ++cnt;
ls[rt] = ls[pre];
rs[rt] = rs[pre];
su[rt] = su[pre] + 1;
if (l < r) {
int mid = (l + r) >> 1;
if (x <= mid) {
ls[rt] = update(ls[pre], l, mid, x);
} else {
rs[rt] = update(rs[pre], mid + 1, r, x);
}
}
return rt;
}
int query(int a, int b, int l, int r, int k) {
if (l >= r) {
return l;
}
int t = su[ls[b]] - su[ls[a]];
int mid = (l + r) >> 1;
if (t >= k) {
return query(ls[a], ls[b], l, mid, k);
} else {
return query(rs[a], rs[b], mid + 1, r, k - t);
}
}
};
void solve() {
int n, q;
cin >> n >> q;
vector<int> a(n), rt(n + 1);
Tree tr(n);
for (auto &i : a) {
cin >> i;
}
auto b = a;
sort(b.begin(), b.end());
b.erase(unique(b.begin(), b.end()), b.end());
rt[0] = tr.build(1, b.size());
auto getId = [&](int x) {
return lower_bound(b.begin(), b.end(), x) - b.begin();
};
for (int i = 1; i <= n; ++i) {
rt[i] = tr.update(rt[i - 1], 1, b.size(), getId(a[i - 1]) + 1);
}
while (q--) {
int l, r, k;
cin >> l >> r >> k;
int pos = tr.query(rt[l - 1], rt[r], 1, b.size(), k);
cout << b[pos - 1] << endl;
}
}
可持久化数组
struct Tree {
int cnt;
vector<int> su, ls, rs, a;
Tree(int n): cnt(0) {
n = (n + 1) << 5;
su = ls = rs = a = vector<int>(n);
};
int build(int l, int r) {
int rt = ++cnt;
if (l == r) {
su[rt] = a[l];
return rt;
}
int mid = (l + r) >> 1;
ls[rt] = build(l, mid);
rs[rt] = build(mid + 1, r);
return rt;
}
int modify(int pre, int l, int r, int x, int u) {
int rt = ++cnt;
ls[rt] = ls[pre];
rs[rt] = rs[pre];
su[rt] = su[pre];
if (l == r) {
su[rt] = x;
return rt;
}
int mid = (l + r) >> 1;
if (u <= mid) {
ls[rt] = modify(ls[pre], l, mid, x, u);
}
if (u > mid) {
rs[rt] = modify(rs[pre], mid + 1, r, x, u);
}
return rt;
}
int query(int rt, int l, int r, int u) {
if (l == r) {
return su[rt];
}
int mid = (l + r) >> 1;
if (u <= mid) {
return query(ls[rt], l, mid, u);
} else {
return query(rs[rt], mid + 1, r, u);
}
}
};
void solve() {
int n, q;
cin >> n >> q;
vector<int> rt(q + 1);
Tree tr(n);
for (int i = 1; i <= n; ++i) {
cin >> tr.a[i];
}
rt[0] = tr.build(1, n);
for (int i = 1; i <= q; ++i) {
int pre, op;
cin >> pre >> op;
if (op == 1) {
int pos, x;
cin >> pos >> x;
rt[i] = tr.modify(rt[pre], 1, n, x, pos);
} else if (op == 2) {
int pos;
cin >> pos;
cout << tr.query(rt[pre], 1, n, pos) << endl;
rt[i] = rt[pre];
}
}
}
可持久化并查集
TODO
多项式
FFT
// 看不懂,当黑盒
struct FFT {
vector<complex<double>> f;
vector<int> rev;
int limit = 1;
int l = -1;
int n, m, t;
auto read(const vector<int> &a, const vector<int> &b) {
n = a.size() - 1;
m = b.size() - 1;
t = n + m;
n = max(n, m);
limit = 1, l = -1;
while (limit <= (n << 1)) {
limit <<= 1;
++l;
}
rev.clear();
rev.resize(limit + 1);
for (int i = 1; i <= limit; ++i) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
}
f.clear();
f.resize(limit + 1);
for (int i = 0; i < a.size(); ++i) {
f[i].real(a[i]);
}
for (int i = 0; i < b.size(); ++i) {
f[i].imag(b[i]);
}
}
void fft(int type, int limit) {
for (int i = 1; i <= limit; ++i) {
if (i >= rev[i]) {
continue;
}
swap(f[i], f[rev[i]]);
}
complex<double> rt, w, x, y;
double pi = acos(-1);
for (int mid = 1; mid < limit; mid <<= 1) {
int r = mid << 1;
rt = complex<double>(cos(pi / mid), type * sin(pi / mid));
for (int j = 0; j < limit; j += r) {
w = complex<double>(1, 0);
for (int k = 0; k < mid; ++k) {
x = f[j | k];
y = w * f[j | k | mid];
f[j | k] = x + y;
f[j | k | mid] = x - y;
w = w * rt;
}
}
}
if (type == 1) {
return;
}
for (int i = 0; i <= limit; ++i) {
f[i].imag(f[i].imag() / limit);
f[i].real(f[i].real() / limit);
}
}
auto mul() {
fft(1, limit);
for (int i = 0; i <= limit; ++i) {
f[i] = f[i] * f[i];
}
fft(-1, limit);
vector<int> c(t + 1);
for (int i = 0; i <= t; ++i) {
c[i] = f[i].imag() / 2 + 0.5;
}
return c;
}
};
void solve() {
int n, m;
cin >> n >> m;
vector<int> a(n + 1), b(m + 1);
for (auto &i : a) {
cin >> i;
}
for (auto &i : b) {
cin >> i;
}
auto fft = FFT();
fft.read(a, b);
auto c = fft.mul();
for (int i = 0; i <= n + m; ++i) {
cout << c[i] << ' ';
}
cout << endl;
}
背包 DP
如无特殊说明,默认 v
为价值(value),w
为重量(weight)。
01 背包
void solve() {
int n, m;
cin >> n >> m;
vector<int> v(n + 1), w(n + 1);
vector<int> dp(m + 1);
for (int i = 1; i <= n; ++i) {
cin >> w[i] >> v[i];
}
for (int i = 1; i <= n; ++i) {
for (int j = m; j >= w[i]; --j) {
dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
}
}
cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
}
完全背包
void solve() {
int n, m;
cin >> m >> n;
vector<int> v(n + 1), w(n + 1);
vector<int> dp(m + 1);
for (int i = 1; i <= n; ++i) {
cin >> w[i] >> v[i];
}
for (int i = 1; i <= n; ++i) {
for (int j = w[i]; j <= m; ++j) {
dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
}
}
cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}
多重背包
void solve() {
int n, m;
cin >> n >> m;
vector<int> v(n + 1), w(n + 1), s(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> w[i] >> v[i] >> s[i];
}
vector<int> dp(m + 1);
auto process_01 = [&](int v, int w) {
for (int i = m; i >= w; --i) {
dp[i] = max(dp[i], dp[i - w] + v);
}
};
for (int i = 1; i <= n; ++i) {
int base = 1;
int cs = s[i], cw = w[i], cv = v[i];
while (cs > base) {
cs -= base;
process_01(cv * base, cw * base);
base *= 2;
}
process_01(cv * cs, cw * cs);
}
cout << dp[m] << endl;
}
混合背包
Luogu-P1833 按类型分别套用上面三种背包的代码即可。
void solve() {
string s, e;
int n;
cin >> s >> e >> n;
auto process_time = [&](string s) {
int h = 0, m = 0;
int pos = 0;
while (pos < s.size() && s[pos] != ':') {
h *= 10;
h += s[pos] - '0';
++pos;
}
++pos;
while (pos < s.size()) {
m *= 10;
m += s[pos] - '0';
++pos;
}
return h * 60 + m;
};
int m = process_time(e) - process_time(s);
vector<int> w(n + 1), v(n + 1), a(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> w[i] >> v[i] >> a[i];
}
vector<int> dp(m + 1);
for (int i = 1; i <= n; ++i) {
if (!a[i]) {
for (int j = w[i]; j <= m; ++j) {
dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
}
} else {
int base = 1;
int k = a[i];
while (k > base) {
int cw = w[i] * base;
int cv = v[i] * base;
for (int j = m; j >= cw; --j) {
dp[j] = max(dp[j], dp[j - cw] + cv);
}
k -= base;
base <<= 1;
}
int cw = w[i] * k;
int cv = v[i] * k;
for (int j = m; j >= cw; --j) {
dp[j] = max(dp[j], dp[j - cw] + cv);
}
}
}
cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}
二维费用背包
void solve() {
int n, m, t;
cin >> n >> m >> t;
vector<int> w1(n + 1), w2(n + 1);
for (int i = 1; i <= n; ++i) {
cin >> w1[i] >> w2[i];
}
vector<vector<int>> dp(m + 1, vector<int>(t + 1));
for (int i = 1; i <= n; ++i) {
for (int j = m; j >= w1[i]; --j) {
for (int k = t; k >= w2[i]; --k) {
dp[j][k] = max(dp[j][k], dp[j - w1[i]][k - w2[i]] + 1);
}
}
}
cout << dp[m][t] << endl;
}
分组背包
有 件物品和一个大小为 的背包,第 个物品的价值为 , 体积为 。同时,每个物品属于一个组,同组内最多只能选择一个物品。 求背包能装载物品的最大总价值。
void solve() {
int m, n;
cin >> m >> n;
unordered_map<int, vector<pair<int, int>>> a;
for (int i = 1; i <= n; ++i) {
int w, v, k;
cin >> w >> v >> k;
a[k].emplace_back(w, v);
}
vector<int> dp(m + 1);
for (auto &[_, vs] : a) {
for (int i = m; i >= 0; --i) {
for (auto &[w, v] : vs) {
if (i >= w) {
dp[i] = max(dp[i], dp[i - w] + v);
}
}
}
}
cout << *max_element(dp.begin() + 1, dp.end());
}
可以转化成分组背包:有依赖的背包,把物品依赖的选择方案分到同一组。
背包问题变种
输出方案
用 表示第 件物品占用空间为 的时候是否被选择,转移时记录选或不选,输出:
int cur_w = m;
vector<int> selected;
for (int i = n; i >= 1; --i) {
if (g[i][cur_w]) {
selected.emplace_back(i);
cur_w -= w[i];
}
}
求方案数量
把转移中求最大值变为求和。
01 背包:
求最优方案数量
TODO 重写并测试
for (int i = 0; i < N; i++) {
for (int j = V; j >= v[i]; j--) {
int tmp = std::max(dp[j], dp[j - v[i]] + w[i]);
int c = 0;
if (tmp == dp[j]) {
c += cnt[j]; // 如果从dp[j]转移
}
if (tmp == dp[j - v[i]] + w[i]) {
c += cnt[j - v[i]]; // 如果从dp[j-v[i]]转移
}
dp[j] = tmp;
cnt[j] = c;
}
}
int max = 0; // 寻找最优解
for (int i = 0; i <= V; i++) {
max = std::max(max, dp[i]);
}
int res = 0;
for (int i = 0; i <= V; i++) {
if (dp[i] == max) {
res += cnt[i]; // 求和最优解方案数
}
}
求第 k 优解
TODO 重写并测试
memset(dp, 0, sizeof(dp));
int i, j, p, x, y, z;
scanf("%d%d%d", &n, &m, &K);
for (i = 0; i < n; i++) {
scanf("%d", &w[i]);
}
for (i = 0; i < n; i++) {
scanf("%d", &c[i]);
}
for (i = 0; i < n; i++) {
for (j = m; j >= c[i]; j--) {
for (p = 1; p <= K; p++) {
a[p] = dp[j - c[i]][p] + w[i];
b[p] = dp[j][p];
}
a[p] = b[p] = -1;
x = y = z = 1;
while (z <= K && (a[x] != -1 || b[y] != -1)) {
if (a[x] > b[y]) {
dp[j][z] = a[x++];
} else {
dp[j][z] = b[y++];
}
if (dp[j][z] != dp[j][z - 1]) {
z++;
}
}
}
}
printf("%d\n", dp[m][K]);
状压 DP
枚举子集的子集,时间复杂度 :
for (int i = 0; i < (1 << n); ++i) {
for (int j = i; j; j = (j - 1) & i) {
// j 为 i 的子集
}
}
DP 优化
四边形不等式优化
形如以下转移方程:
需满足的条件:
- 区间包含单调性:若 ,则 。
- 四边形不等式:若 ,则
结论:
- 也满足四边形不等式。
- 假设 为 的最优决策点,那么 。
只需要记录 和 ,可以优化掉一维循环。
考场上可以直接打表记录 后验证单调性
满足四边形不等式的函数类:
性质 1:若函数 和 均满足四边形不等式(或区间包含单调性),则对于任意 ,函数 也满足四边形不等式(或区间包含单调性)。
性质 2:若存在函数 和 使得 ,则函数 满足四边形恒等式。当函数 和 单调增加时,函数 还满足区间包含单调性。
性质 3:设 是一个单调增加的凸函数,若函数 满足四边形不等式并且对区间包含关系具有单调性,则复合函数 也满足四边形不等式和区间包含单调性。
性质 4:设 是一个凸函数,若函数 满足四边形恒等式并且对区间包含关系具有单调性,则复合函数 也满足四边形不等式。
首先需要澄清一点,凸函数(Convex Function)的定义在国内教材中有分歧,此处的凸函数指的是下凸函数,即(可微时)一阶导数单调增加的函数。
void solve() {
int n, m;
io.read(n);
io.read(m);
vector<vector<int>> a(n + 10, vector<int>(n + 10));
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
io.read(a[i][j]);
}
}
auto p = a;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
p[i][j] += p[i - 1][j] + p[i][j - 1] - p[i - 1][j - 1];
}
}
auto get_sum = [&](int l, int r) {
l = max(l, 1ll);
if (r < l) {
return 0ll;
}
int res = p[r][r] - p[l - 1][r] - p[r][l - 1] + p[l - 1][l - 1];
return res;
};
vector<vector<int>> dp(n + 10, vector<int>(m + 10, 1e12));
vector<vector<int>> bt(n + 10, vector<int>(m + 10));
for (int i = 0; i <= m; ++i) {
dp[i][i] = 0;
bt[n + 1][i] = n;
}
for (int i = 0; i <= n; ++i) {
bt[i][0] = 1;
}
for (int i = 0; i <= n; ++i) {}
for (int j = 1; j <= m; ++j) {
for (int i = n; i >= j + 1; --i) {
for (int k = bt[i][j - 1]; k <= bt[i + 1][j]; ++k) {
int cur = dp[k - 1][j - 1] + get_sum(k, i) / 2;
if (cur < dp[i][j]) {
dp[i][j] = cur;
bt[i][j] = k;
}
}
}
}
io.write(dp[n][m]);
}
虚树
给一个有边权的有根树,每次询问给一堆点,求割掉一些边使得根节点无法到达。
点数量级 ,询问点之和数量级 。
对每次询问的点建虚树,跑个 dp。核心是建树。
这份代码常数极大。
void solve() {
int n;
cin >> n;
vector<vector<pair<int, int>>> e(n + 1);
for (int i = 1; i <= n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
e[u].emplace_back(v, w);
e[v].emplace_back(u, w);
}
int cdfn = 0;
vector<int> dfn(n + 1);
vector<vector<int>> fa(n + 1, vector<int>(23));
auto fad = fa;
vector<int> dep(n + 1);
function<void(int, int, int)> dfs_init = [&](int u, int pre, int w) {
// 记录 dfs 序
dfn[u] = ++cdfn;
// 记录倍增父节点
fa[u][0] = pre;
// 记录到倍增父节点的边的最小值
fad[u][0] = w;
// 记录深度
dep[u] = dep[pre] + 1;
// 倍增处理
for (int i = 1; i <= 20; ++i) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
fad[u][i] = min(fad[u][i - 1], fad[fa[u][i - 1]][i - 1]);
}
for (auto [v, w] : e[u]) {
if (v == pre) {
continue;
}
dfs_init(v, u, w);
}
};
dfs_init(1, 0, 0);
// 求 LCA
auto lca = [&](int u, int v) {
if (dep[u] > dep[v]) {
swap(u, v);
}
int k = dep[v] - dep[u];
for (int j = 0; k; ++j, k >>= 1) {
if (k & 1) {
v = fa[v][j];
}
}
if (v == u) {
return u;
}
for (int j = 20; j >= 0 && v != u; --j) {
if (fa[u][j] != fa[v][j]) {
u = fa[u][j];
v = fa[v][j];
}
}
return fa[u][0];
};
// 获取从节点 u 到他的第 k 级祖先路径上权值最小的边
auto get_fad = [&](int u, int k) {
int ans = 1e9;
for (int i = 0; k; ++i, k >>= 1) {
if (k & 1) {
ans = min(ans, fad[u][i]);
u = fa[u][i];
}
}
return ans;
};
// 获取从 u 到 v 路径上权值最小的边
auto get_dis = [&](int u, int v) {
int lc = lca(u, v);
return min(get_fad(u, dep[u] - dep[lc]), get_fad(v, dep[v] - dep[lc]));
};
int q;
cin >> q;
while (q--) {
int k;
cin >> k;
vector<int> h(k + 1);
// 把根节点插进去,便于后续处理
h[0] = 1;
for (int i = 1; i <= k; ++i) {
cin >> h[i];
}
// 第一次按 dfs 序排序
sort(h.begin() + 1, h.begin() + k + 1, [&](int a, int b) {
return dfn[a] < dfn[b];
});
vector<int> a = {0};
// 求 dfs 序相邻的节点的 lca,加入虚树点集中
for (int i = 0; i < k; ++i) {
a.emplace_back(h[i]);
a.emplace_back(lca(h[i], h[i + 1]));
}
a.emplace_back(h[k]);
// 第二次按 dfs 序排序
sort(a.begin() + 1, a.end(), [&](int a, int b) {
return dfn[a] < dfn[b];
});
// 去重
a.erase(unique(a.begin() + 1, a.end()), a.end());
int m = a.size() - 1;
vector<vector<pair<int, int>>> ce(m + 1);
// 重新分配 id,防止后面数组开大了复杂度退化
map<int, int> id;
int cid = 1;
// 保证根节点还在 1 号
id[1] = cid;
auto get_id = [&](int u) {
if (id.find(u) != id.end()) {
return id[u];
}
return id[u] = ++cid;
};
// 虚树连边
for (int i = 1; i < m; ++i) {
int u = lca(a[i], a[i + 1]);
int v = a[i + 1];
ce[get_id(u)].emplace_back(get_id(v), get_dis(u, v));
ce[get_id(v)].emplace_back(get_id(u), get_dis(v, u));
}
// 标记关键点
vector<int> b(m + 1);
for (int i = 1; i <= k; ++i) {
b[get_id(h[i])] = 1;
}
// 问题求解的 DP
vector<int> dp(m + 1);
function<void(int, int)> dfs = [&](int u, int pre) {
for (auto [v, w] : ce[u]) {
if (v == pre) {
continue;
}
dfs(v, u);
if (!b[v]) {
dp[u] += min(dp[v], w);
} else {
dp[u] += w;
}
}
};
dfs(1, 0);
cout << dp[1] << endl;
}
}
Manacher
计算出所有的回文子串。
代码来自 OI-Wiki
计算 d1
,即长度为奇数的回文子串的中心:
vector<int> d1(n);
for (int i = 0, l = 0, r = -1; i < n; i++) {
int k = (i > r) ? 1 : min(d1[l + r - i], r - i + 1);
while (0 <= i - k && i + k < n && s[i - k] == s[i + k]) {
k++;
}
d1[i] = k--;
if (i + k > r) {
l = i - k;
r = i + k;
}
}
计算 d2
,即长度为偶数的回文子串的中心:
vector<int> d2(n);
for (int i = 0, l = 0, r = -1; i < n; i++) {
int k = (i > r) ? 0 : min(d2[l + r - i + 1], r - i + 1);
while (0 <= i - k - 1 && i + k < n && s[i - k - 1] == s[i + k]) {
k++;
}
d2[i] = k--;
if (i + k > r) {
l = i - k - 1;
r = i + k;
}
}
自带取模类型
自带取模 int
template <int mod>
struct ModInt {
int x;
ModInt(): x(0) {}
ModInt(int y): x(y >= 0 ? y : y + mod) {}
ModInt(ll y): x(y >= 0 ? y % mod : (mod - (-y) % mod) % mod) {}
inline int inc(const int &v) {
return v >= mod ? v - mod : v;
}
inline int dec(const int &v) {
return v < 0 ? v + mod : v;
}
inline ModInt &operator+=(const ModInt &p) {
x = inc(x + p.x);
return *this;
}
inline ModInt &operator-=(const ModInt &p) {
x = dec(x - p.x);
return *this;
}
inline ModInt &operator*=(const ModInt &p) {
x = (int)((ll)x * p.x % mod);
return *this;
}
inline ModInt inverse() const {
int a = x, b = mod, u = 1, v = 0, t;
while (b > 0) {
t = a / b, std::swap(a -= t * b, b), std::swap(u -= t * v, v);
}
return u;
}
inline ModInt &operator/=(const ModInt &p) {
*this *= p.inverse();
return *this;
}
inline ModInt operator-() const {
return -x;
}
inline friend ModInt operator+(const ModInt &lhs, const ModInt &rhs) {
return ModInt(lhs) += rhs;
}
inline friend ModInt operator-(const ModInt &lhs, const ModInt &rhs) {
return ModInt(lhs) -= rhs;
}
inline friend ModInt operator*(const ModInt &lhs, const ModInt &rhs) {
return ModInt(lhs) *= rhs;
}
inline friend ModInt operator/(const ModInt &lhs, const ModInt &rhs) {
return ModInt(lhs) /= rhs;
}
inline bool operator==(const ModInt &p) const {
return x == p.x;
}
inline bool operator!=(const ModInt &p) const {
return x != p.x;
}
inline ModInt qpow(ll n) const {
ModInt ret(1), mul(x);
while (n > 0) {
if (n & 1) {
ret *= mul;
}
mul *= mul, n >>= 1;
}
return ret;
}
inline friend std::ostream &operator<<(std::ostream &os, const ModInt &p) {
return os << p.x;
}
inline friend std::istream &operator>>(std::istream &is, ModInt &a) {
ll t;
is >> t, a = ModInt<mod>(t);
return is;
}
static int get_mod() {
return mod;
}
inline bool operator<(const ModInt &A) const {
return x < A.x;
}
inline bool operator>(const ModInt &A) const {
return x > A.x;
}
};