Introduction

lnkkerst's XCPC templates.

起手式

#pragma GCC optimize(2)
#include <algorithm>
#include <array>
#include <bitset>
#include <cmath>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;

#define int long long

void solve() {}

signed main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int t = 1;
  cin >> t;
  while (t--) {
    solve();
  }
}

Vim

vimrc:

" 设置 leader 为空格键
let mapleader = " "

" 设置 UFT-8 编码
set enc=utf-8
set fenc=utf-8
set termencoding=utf-8

" 关闭 vi 兼容
set nocompatible

" 显示相对行号
set number
set relativenumber

" 启用语法高亮
set t_Co=256
syntax on

" 高亮当前行和当前列
"set cursorline
"set cursorcolumn
"highlight CursorLine ctermbg=darkgray guibg=lightgray
"highlight CursorColumn ctermbg=darkgray guibg=lightgray

" 自动缩进
set autoindent
set smartindent

" 设置合适的缩进宽度
set tabstop=2        " 设置 Tab 键宽度为 2
set shiftwidth=2     " 设置缩进宽度为 2
set expandtab        " 用空格替代Tab

" 开启行内搜索时的高亮
set hlsearch

" 关闭错误的响铃提示
set noerrorbells

" 搜索时逐字符匹配
set incsearch
set ignorecase       " 搜索忽略大小写
set smartcase        " 如果包含大写字符,则区分大小写

" 设置颜色主题
set background=dark
colorscheme default

" 显示匹配的括号
set showmatch

" 开启剪切板支持
set clipboard=unnamedplus

" 设置取消回滚时最大操作数
set undolevels=1000

" 鼠标支持
set mouse=a

" 快速保存和退出命令
nnoremap <leader>w :w<CR>       " 用 \w 快速保存
nnoremap <leader>q :q<CR>       " 用 \q 快速退出

" 复制当前 buffer
nmap <leader>y ggVGy
"nmap <leader>y ggVG"+y''

" {} 括号补全
inoremap {<CR> {<CR>}<ESC>O

" 使用 x 删除时不自动复制
nnoremap x "_x
nnoremap X "_X
vnoremap x "_x
vnoremap X "_X

" 关闭搜索高亮
nnoremap <leader>l :noh<cr>

" 行内移动
nnoremap H ^
nnoremap L g_
vnoremap H ^
vnoremap L g_

常用操作:

" 在下方打开一个终端
:belowright terminal
# bash 指令

# X 下交换 esc 和 capslock,防止队友写代码
setxkbmap -option "caps:swapescape"

快读快写

本文所有代码来自 OI-wiki

简单版本

int read() {
  int x = 0, f = 1;
  char ch = 0;
  while (ch < '0' || ch > '9') {
    if (ch == '-') {
      f = -1;
    }
    ch = getchar();
  }
  while (ch >= '0' && ch <= '9') {
    x = x * 10 + (ch - '0');
    ch = getchar();
  }
  return x * f;
}

void write(int x) {
  if (x < 0) {
    x = -x;
    putchar('-');
  }
  if (x > 9) {
    write(x / 10);
  }
  putchar(x % 10 + '0');
}

fread, fwrite 版本

namespace IO {
const int MAXSIZE = 1 << 20;
char buf[MAXSIZE], *p1, *p2;
#define gc()                                                               \
  (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) \
     ? EOF                                                                 \
     : *p1++)

int rd() {
  int x = 0, f = 1;
  char c = gc();
  while (!isdigit(c)) {
    if (c == '-') {
      f = -1;
    }
    c = gc();
  }
  while (isdigit(c)) {
    x = x * 10 + (c ^ 48), c = gc();
  }
  return x * f;
}

char pbuf[1 << 20], *pp = pbuf;

void push(const char &c) {
  if (pp - pbuf == 1 << 20) {
    fwrite(pbuf, 1, 1 << 20, stdout), pp = pbuf;
  }
  *pp++ = c;
}

void write(int x) {
  static int sta[35];
  int top = 0;
  do {
    sta[top++] = x % 10, x /= 10;
  } while (x);
  while (top) {
    push(sta[--top] + '0');
  }
}
} // namespace IO

完整带调试版

// #define DEBUG 1  // 调试开关
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) (x >= '0' && x <= '9')
  char buf[MAXSIZE], *p1, *p2;
  char pbuf[MAXSIZE], *pp;
#if DEBUG
#else
  IO(): p1(buf), p2(buf), pp(pbuf) {}

  ~IO() {
    fwrite(pbuf, 1, pp - pbuf, stdout);
  }
#endif
  char gc() {
#if DEBUG // 调试,可显示字符
    return getchar();
#endif
    if (p1 == p2) {
      p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
    }
    return p1 == p2 ? ' ' : *p1++;
  }

  bool blank(char ch) {
    return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
  }

  template <class T>
  void read(T &x) {
    double tmp = 1;
    bool sign = 0;
    x = 0;
    char ch = gc();
    for (; !isdigit(ch); ch = gc()) {
      if (ch == '-') {
        sign = 1;
      }
    }
    for (; isdigit(ch); ch = gc()) {
      x = x * 10 + (ch - '0');
    }
    if (ch == '.') {
      for (ch = gc(); isdigit(ch); ch = gc()) {
        tmp /= 10.0, x += tmp * (ch - '0');
      }
    }
    if (sign) {
      x = -x;
    }
  }

  void read(char *s) {
    char ch = gc();
    for (; blank(ch); ch = gc())
      ;
    for (; !blank(ch); ch = gc()) {
      *s++ = ch;
    }
    *s = 0;
  }

  void read(char &c) {
    for (c = gc(); blank(c); c = gc())
      ;
  }

  void push(const char &c) {
#if DEBUG // 调试,可显示字符
    putchar(c);
#else
    if (pp - pbuf == MAXSIZE) {
      fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
    }
    *pp++ = c;
#endif
  }

  template <class T>
  void write(T x) {
    if (x < 0) {
      x = -x, push('-'); // 负数输出
    }
    static T sta[35];
    T top = 0;
    do {
      sta[top++] = x % 10, x /= 10;
    } while (x);
    while (top) {
      push(sta[--top] + '0');
    }
  }

  template <class T>
  void write(T x, char lastChar) {
    write(x), push(lastChar);
  }
} io;

黑魔法

创建多维 vector

需要 C++14 及以上

不可初始化值的版本:

template <typename T>
std::vector<T> create_nd_vector(size_t size) {
  return std::vector<T>(size);
}

template <typename T, typename... Sizes>
auto create_nd_vector(size_t first, Sizes... sizes) {
  return std::vector<decltype(create_nd_vector<T>(sizes...))>(
    first, create_nd_vector<T>(sizes...));
}

void solve() {
  int n;
  cin >> n;
  vector<string> a(n + 1);
  auto b = create_nd_vector<int>(n + 1, n + 1, n + 1, n + 1);
  cout << typeid(b).name() << endl;
}

带初始化值的版本,初始化值必须为第一个参数:

template <typename T>
std::vector<T> create_nd_vector(T value, size_t size) {
  return std::vector<T>(size);
}

template <typename T, typename... Sizes>
auto create_nd_vector(T value, size_t first, Sizes... sizes) {
  return std::vector<decltype(create_nd_vector<T>(value, sizes...))>(
    first, create_nd_vector<T>(value, sizes...));
}

void solve() {
  int n;
  cin >> n;
  vector<string> a(n + 1);
  auto b = create_nd_vector<int>(n + 1, n + 1, n + 1, n + 1);
  cout << typeid(b).name() << endl;
}

“动态”的 std::bitset

需要 C++14 及以上

来自 CF1856E2, 利用模板展开预定义大小为 的函数。

// bitset 优化可行性 01 背包(重量和价值相等的 01 背包)
template <int len = 1>
int knapsack01(const vector<int> &a, int target) {
  if (target >= len) {
    return knapsack01<min(len * 2, MAX_BITSET_SIZE)>(a, target);
  }
  bitset<len> dp;
  dp[0] = 1;
  for (auto x : a) {
    dp |= dp << x;
  }
  for (int i = target; i >= 0; --i) {
    if (dp[i]) {
      return i;
    }
  }
  return 0;
}

对拍

假设待验证代码为 1.cpp,暴力程序为 bl.cpp,数据生成器为 dm.py

bash

无限循环:

while true; do
  python dm.py >in
  ./1 <in >out1
  ./bl <in >outbl
  if ! diff outbl out1; then
    break
  fi
done

限制次数(ChatGPT 生成的):

#!/bin/bash
num_tests=100  # 设定测试次数
for ((i=1; i<=num_tests; i++)); do
    ./generator > input.txt
    ./brute < input.txt > brute_output.txt
    ./optimized < input.txt > optimized_output.txt

    if ! diff brute_output.txt optimized_output.txt > /dev/null; then
        echo "Mismatch found on test $i"
        echo "Input:"
        cat input.txt
        echo "Brute output:"
        cat brute_output.txt
        echo "Optimized output:"
        cat optimized_output.txt
        exit 1
    fi
    echo "Test $i passed!"
done
echo "All tests passed!"

fish

while true
    python dm.py >in
    cat in | ./1 >out1
    cat in | ./bl >outbl
    diff out1 outbl
    if test $status -ne 0
        break
    end
end

cmd

@echo off
:loop
  python dm.py > in
  1.exe < in > out1
  bl.exe < in > outbl
  fc out1 outbl
if not errorlevel 1 goto loop

powershell

ChatGPT 写的,没经过测试。

while ($true) {
  # 生成测试数据
  ./generator.exe > input.txt

  # 运行暴力算法
  ./brute.exe < input.txt > brute_output.txt

  # 运行优化算法
  ./optimized.exe < input.txt > optimized_output.txt

  # 比较输出
  if (!(Compare-Object (Get-Content brute_output.txt) `
                      (Get-Content optimized_output.txt))) {
    Write-Output "Test passed!"
  } else {
    Write-Output "Mismatch found!"
    Write-Output "Input:"
    Get-Content input.txt
    Write-Output "Brute output:"
    Get-Content brute_output.txt
    Write-Output "Optimized output:"
    Get-Content optimized_output.txt
    break
  }
}

Python 相关

输入/输出

对于大规模输入量,input() 可能过于缓慢。使用 sys.stdin 提高效率:

import sys
input = sys.stdin.read
data = input().split()

大规模输出,使用 sys.stdout

import sys
sys.stdout.write('\n'.join(map(str, results)) + '\n')

内置数据结构、方法

1. 基础数据结构

list(动态数组)

  • 支持动态扩展、插入、删除、切片操作。
  • 常用方法:appendextendinsertremovepopindexcountsortreverse
arr = [1, 2, 3]
arr.append(4)       # [1, 2, 3, 4]
arr.pop()           # [1, 2, 3]
arr[1:3]            # [2, 3]

tuple(不可变数组)

  • 不可变,支持哈希操作,可作为字典键。
  • 常用操作:解包、索引、切片。
tpl = (1, 2, 3)
x, y, z = tpl      # 解包

str(字符串)

  • 不可变,支持切片操作。
  • 常用方法:splitjoinreplacestripfindstartswithendswith
s = "hello world"
words = s.split()        # ['hello', 'world']
new_s = s.replace(" ", "-")  # 'hello-world'

set(集合)

  • 无序、元素唯一。
  • 常用方法:addremovediscardunionintersectiondifference
s = {1, 2, 3}
s.add(4)                # {1, 2, 3, 4}
s.remove(2)             # {1, 3, 4}

dict(哈希表/字典)

  • 键值对存储,支持快速查找。
  • 常用方法:keysvaluesitemsgetpopupdate
d = {"a": 1, "b": 2}
d["c"] = 3
val = d.get("a", 0)    # 返回 1,如果键不存在返回默认值 0

2. collections 模块中的数据结构

collections 模块扩展了 Python 的内置数据结构,提供了更多的功能。

deque(双端队列)

  • 双端操作高效,适合队列、栈操作。
  • 常用方法:appendappendleftpoppopleftrotateextend
from collections import deque
dq = deque([1, 2, 3])
dq.appendleft(0)       # [0, 1, 2, 3]
dq.pop()               # [0, 1, 2]

Counter(计数器)

  • 统计元素频率。
  • 常用方法:most_commonelementssubtract
from collections import Counter
cnt = Counter("aabbcc")
print(cnt)             # {'a': 2, 'b': 2, 'c': 2}
print(cnt.most_common(1))  # [('a', 2)]

defaultdict(带默认值的字典)

  • 为未定义的键提供默认值。
  • 常用初始化方法:intlistset
from collections import defaultdict
d = defaultdict(list)
d["a"].append(1)
print(d)  # {'a': [1]}

OrderedDict(有序字典)

  • 记录插入顺序(Python 3.7+ 中 dict 也支持)。
from collections import OrderedDict
od = OrderedDict()
od["a"] = 1
od["b"] = 2
print(od)  # {'a': 1, 'b': 2}

namedtuple(命名元组)

  • 类似类的不可变数据结构,字段可以通过名称访问。
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
print(p.x, p.y)  # 1, 2

3. heapq 模块中的堆

heapq(最小堆)

  • 实现优先队列。
  • 常用方法:heappushheappopheapifynlargestnsmallest
import heapq
heap = []
heapq.heappush(heap, 3)
heapq.heappush(heap, 1)
heapq.heappop(heap)  # 返回 1

4. itertools 模块中的迭代器

常用迭代器

  • product:笛卡尔积。
  • permutations:排列。
  • combinations:组合。
  • combinations_with_replacement:可重复组合。
  • accumulate:累积求和。
from itertools import product, permutations, combinations, accumulate
print(list(product([1, 2], repeat=2)))  # [(1, 1), (1, 2), (2, 1), (2, 2)]
print(list(accumulate([1, 2, 3])))     # [1, 3, 6]

5. array 模块中的数组

  • 提供更高效的数值存储。
from array import array
arr = array('i', [1, 2, 3])  # 'i' 表示整数类型
arr.append(4)

6. queue 模块中的队列

  • 提供线程安全的队列实现。

Queue(FIFO 队列)

from queue import Queue
q = Queue()
q.put(1)
print(q.get())  # 1

LifoQueue(栈)

from queue import LifoQueue
stack = LifoQueue()
stack.put(1)
stack.put(2)
print(stack.get())  # 2

PriorityQueue(优先队列)

from queue import PriorityQueue
pq = PriorityQueue()
pq.put((1, "low"))
pq.put((0, "high"))
print(pq.get())  # (0, 'high')

二分查找

from bisect import bisect_left, bisect_right
arr = [1, 2, 4, 4, 5]
pos = bisect_left(arr, 4)   # 第一个 4 的索引:2
pos2 = bisect_right(arr, 4)  # 第一个大于 4 的索引:4

高精度浮点数

使用内置 decimal 库:

from decimal import Decimal, getcontext

# 初始化
a = Decimal('1.1')   # 推荐用字符串,避免二进制误差
b = Decimal('2.2')

支持基本的加减乘除、整除、取模、比较等。

设置精度:

getcontext().prec = 10  # 设置精度为 10 位

舍入:

result = a.quantize(Decimal('0.01'))     # 保留两位小数
result = a.quantize(Decimal('1E-3'))    # 保留三位小数

数学运算:

result = a.sqrt()    # 平方根
result = a.exp()     # e^a
result = a.ln()      # ln(a)
result = a.copy_abs()   # 绝对值
result = a.copy_negate()  # 取相反数

分数

使用内置 fractions 库:

from fractions import Fraction

# 创建分数
f1 = Fraction(3, 4)          # 分子 3,分母 4,即 3/4
f2 = Fraction('0.75')        # 支持字符串初始化
f3 = Fraction(1.5)           # 浮点数也可初始化

支持基本的加减乘除、整除、取模、比较等。

分子分母:

f = Fraction(5, 8)
numerator = f.numerator   # 分子:5
denominator = f.denominator  # 分母:8
f = Fraction(22, 7).limit_denominator(100)  # 限制分母不超过 100

归并排序

void solve() {
  int n;
  cin >> n;
  vector<int> a(n);
  for (auto &i : a) {
    cin >> i;
  }
  // [l, r), 升序
  function<void(int, int)> merge_sort = [&](int l, int r) {
    if (r - l <= 1) {
      return;
    }
    int mid = l + ((r - l) >> 1);
    merge_sort(l, mid);
    merge_sort(mid, r);

    // merge
    vector<int> na;
    int i = l, j = mid;
    while (i < mid && j < r) {
      if (a[j] < a[i]) { // 先比较 a[j] < a[i],可以保证稳定性
        na.emplace_back(a[j]);
        ++j;
      } else {
        na.emplace_back(a[i]);
        ++i;
      }
    }
    while (i < mid) {
      na.emplace_back(a[i]);
      ++i;
    }
    while (j < r) {
      na.emplace_back(a[j]);
      ++j;
    }
    // swap
    for (int i = l; i < r; ++i) {
      a[i] = na[i - l];
    }
  };
  merge_sort(0, a.size());
  for (auto i : a) {
    cout << i << ' ';
  }
  cout << endl;
}

二叉堆

void solve() {
  int n;
  cin >> n;
  vector<int> a(n + 10);
  int cnt = 0;
  function<void(int)> pushup = [&](int u) {
    if (u == 1) {
      return;
    }
    int fa = u >> 1;
    if (a[u] < a[fa]) {
      swap(a[u], a[fa]);
      pushup(fa);
    }
  };
  function<void(int)> pushdown = [&](int u) {
    int v = u << 1;
    if (v > cnt) {
      return;
    }
    if ((v | 1) <= cnt && a[v | 1] < a[v]) {
      v |= 1;
    }
    if (a[v] < a[u]) {
      swap(a[v], a[u]);
      pushdown(v);
    }
  };
  auto push = [&](int x) {
    a[++cnt] = x;
    pushup(cnt);
  };
  auto pop = [&]() {
    a[1] = a[cnt--];
    pushdown(1);
  };
  for (int i = 1; i <= n; ++i) {
    int q;
    cin >> q;
    if (q == 1) {
      int x;
      cin >> x;
      push(x);
    } else if (q == 2) {
      cout << a[1] << endl;
    } else {
      pop();
    }
  }
}

对顶堆

SPOJ-RMID2 Luogu-SP16254

多组数据,不断读入整数,读入到 时输出并删除当前序列中位数( 不插入),偶数个数时输出较小的中位数,遇到 结束。

数据范围

void solve() {
  int x;
  priority_queue<int> lq;
  priority_queue<int, vector<int>, greater<int>> rq;
  // 调整对顶堆
  auto adjust = [&](int sz) {
    while (lq.size() < sz && !rq.empty()) {
      lq.push(rq.top());
      rq.pop();
    }
    while (lq.size() > sz) {
      rq.push(lq.top());
      lq.pop();
    }
  };
  // 插入新元素
  auto push = [&](int x) {
    if (lq.empty() || x < lq.top()) {
      lq.push(x);
    } else {
      rq.push(x);
    }
    int mid = (lq.size() + rq.size() + 1) / 2;
    adjust(mid);
  };
  auto pop = [&]() {
    lq.pop();
    int mid = (lq.size() + rq.size() + 1) / 2;
    adjust(mid);
  };
  while (scanf("%d", &x) != EOF) {
    if (x == 0) {
      return;
    }
    if (x == -1) {
      printf("%d\n", lq.top());
      pop();
    } else {
      push(x);
    }
  }
}

ST 表

luogu 的评测,不用 fastio 容易超时。

这里维护的是最大值。

void solve() {
  int n = read(), m = read();
  vector<int> a(n);
  vector<array<int, 20>> f(n);
  for (int i = 0; i < n; ++i) {
    f[i][0] = a[i] = read();
  }
  // init
  for (int j = 1; j < 20; ++j) {
    for (int i = 0; i < n - (1 << j) + 1; ++i) {
      f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
    }
  }
  while (m--) {
    int l = read(), r = read();
    --l, --r;
    int l2 = log2(r - l + 1);
    write(max(f[l][l2], f[r - (1 << l2) + 1][l2]));
    putchar('\n');
  }
}

并查集

简单版本

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> fa(n + 1);

  // init
  iota(fa.begin(), fa.end(), 0);

  // 找父亲
  function<int(int)> find = [&](int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
  };

  // 合并
  auto merge = [&](int x, int y) {
    fa[find(x)] = find(y);
  };

  while (m--) {
    int q, x, y;
    cin >> q >> x >> y;
    if (q == 1) {
      merge(x, y);
    } else {
      cout << ((find(x) == find(y)) ? "Y" : "N") << endl;
    }
  }
}

启发式合并

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> fa(n + 1), sz(n + 1, 1);
  iota(fa.begin(), fa.end(), 0);

  // 找父亲
  function<int(int)> find = [&](int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
  };

  // 启发式合并
  auto merge = [&](int x, int y) {
    int fx = find(x), fy = find(y);
    if (fx == fy) {
      return;
    }
    if (sz[fx] < sz[fy]) {
      swap(fx, fy);
    }
    fa[fy] = fx;
    sz[x] += sz[y];
  };

  while (m--) {
    int q, x, y;
    cin >> q >> x >> y;
    if (q == 1) {
      merge(x, y);
    } else {
      cout << ((find(x) == find(y)) ? "Y" : "N") << endl;
    }
  }
}

删除与移动

删除:将父亲设为自己,为了保证删除的元素都是叶节点,设置副本并初始化父亲为副本。 移动:保重移动的元素都在叶子节点。

实现以下功能:

  1. 合并两个元素所处集合。
  2. 移动 集合。
  3. 查询元素所在集合大小和元素和。
void solve() {
  int n, m;
  if (!(cin >> n >> m)) {
    exit(0);
  }
  vector<int> fa((n + 1) * 2), sz((n + 1) * 2, 1);
  vector<int> su((n + 1) * 2); // 添加统计大小

  // 初始化
  iota(fa.begin(), fa.begin() + n + 1, n + 1);
  iota(fa.begin() + n + 1, fa.end(), n + 1);
  iota(su.begin() + n + 1, su.end(), 0);

  // 找父亲
  function<int(int)> find = [&](int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
  };

  // 启发式合并(按大小)
  auto merge = [&](int x, int y) {
    int fx = find(x), fy = find(y);
    if (fx == fy) {
      return;
    }
    if (sz[fx] < sz[fy]) {
      swap(fx, fy);
    }
    fa[fy] = fx;
    sz[fx] += sz[fy];
    su[fx] += su[fy];
  };

  // 删除
  auto remove = [&](int x) {
    --sz[find(x)];
    fa[x] = x;
  };

  // 移动
  auto move = [&](int x, int y) {
    auto fx = find(x), fy = find(y);
    if (fx == fy) {
      return;
    }
    fa[x] = fy;
    --sz[fx], ++sz[fy];
    su[fx] -= x, su[fy] += x;
  };

  while (m--) {
    int q;
    cin >> q;
    if (q == 1) {
      int x, y;
      cin >> x >> y;
      merge(x, y);
    } else if (q == 2) {
      int x, y;
      cin >> x >> y;
      move(x, y);
    } else if (q == 3) {
      int x;
      cin >> x;
      int fx = find(x);
      cout << sz[fx] << ' ' << su[fx] << endl;
    }
  }
}

单调队列

滑动窗口最值(Luogu-P1886):

void solve() {
  int n, k;
  cin >> n >> k;
  vector<int> a(n);
  for (auto &i : a) {
    cin >> i;
  }
  // 滑动窗口最值
  auto calc = [&](function<bool(int, int)> cmp) {
    deque<int> q;
    for (int i = 0; i < k; ++i) {
      while (!q.empty() && cmp(a[i], a[q.back()])) {
        q.pop_back();
      }
      q.push_back(i);
    }
    cout << a[q.front()] << ' ';
    for (int i = k; i < n; ++i) {
      while (!q.empty() && cmp(a[i], a[q.back()])) {
        q.pop_back();
      }
      q.push_back(i);
      while (q.front() <= i - k) {
        q.pop_front();
      }
      cout << a[q.front()] << ' ';
    }
    cout << endl;
  };
  calc(less_equal<>());    // 最小值
  calc(greater_equal<>()); // 最大值
}

单调栈

求元素右侧第一个大于他的元素的下标(Luogu-P5788)。

void solve() {
  int n;
  cin >> n;
  vector<int> a(n);
  for (auto &i : a) {
    cin >> i;
  }
  stack<int> s;
  vector<int> ans(n);
  for (int i = n - 1; i >= 0; --i) {
    while (!s.empty() && a[s.top()] <= a[i]) {
      s.pop();
    }
    ans[i] = s.empty() ? -1 : s.top();
    s.push(i);
  }
  for (auto i : ans) {
    cout << i + 1 << ' ';
  }
}

树状数组

单点加,区间和。

需要运算满足结合律且可差分,如加法(和)、乘法(积)、异或等。

  • 结合律: ,其中 是一个二元运算符。
  • 可差分:具有逆运算的运算,即已知 可以求出
struct Fenwick {
  int n;
  vector<int> a;
  Fenwick(int _n): n(_n), a(_n + 10) {}
  Fenwick(const vector<int> &arr) {
    n = arr.size();
    a = vector<int>(n + 10);
    // O(n) 建树
    for (int i = 1; i <= n; ++i) {
      a[i] += arr[i - 1];
      int j = i + lowbit(i);
      if (j <= n) {
        a[j] += a[i];
      }
    }
  }
  static int lowbit(int x) {
    return x & -x;
  }
  // 单点加
  void add(int k, int x) {
    while (k <= n) {
      a[k] += x;
      k += lowbit(k);
    }
  }
  // 查询前缀和
  int query(int k) {
    int res = 0;
    while (k > 0) {
      res += a[k];
      k -= lowbit(k);
    }
    return res;
  }
  // 区间查询
  int query(int l, int r) {
    return query(r) - query(l - 1);
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> a(n);
  for (auto &x : a) {
    cin >> x;
  }
  Fenwick tree(a);
  while (m--) {
    int q;
    cin >> q;
    if (q == 1) {
      int pos, x;
      cin >> pos >> x;
      tree.add(pos, x);
    } else if (q == 2) {
      int l, r;
      cin >> l >> r;
      cout << tree.query(l, r) << endl;
    }
  }
}

区间加,单点查询。

维护差分数组即可。

区间加,区间和。

struct Tree {
private:
  vector<int> t1, t2;
  int n;

  void add_(int k, int x) {
    int v1 = k * x;
    while (k <= n) {
      t1[k] += x, t2[k] += v1;
      k += lowbit(k);
    }
  }

  int query_(vector<int> &t, int k) {
    int res = 0;
    while (k) {
      res += t[k];
      k -= lowbit(k);
    }
    return res;
  }

public:
  Tree(int _n): t1(_n + 2), t2(_n + 2), n(_n) {}

  static int lowbit(int x) {
    return x & -x;
  }

  // 区间加
  void add(int l, int r, int v) {
    add_(l, v);
    add_(r + 1, -v);
  }

  // 求区间和
  int query(int l, int r) {
    return (r + 1) * query_(t1, r) - l * query_(t1, l - 1)
           - (query_(t2, r) - query_(t2, l - 1));
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  Tree tr(n);
  for (int i = 1; i <= n; ++i) {
    int x;
    cin >> x;
    tr.add(i, i, x);
  }
  while (m--) {
    int q;
    cin >> q;
    if (q == 1) {
      int l, r, x;
      cin >> l >> r >> x;
      tr.add(l, r, x);
    } else if (q == 2) {
      int l, r;
      cin >> l >> r;
      cout << tr.query(l, r) << endl;
    }
  }
}

二维,子矩阵加,单点查询

改一改就是单点修改,子矩阵查询了。

LOJ-133

struct Tree {
  int n, m;
  vector<vector<int>> t;

  Tree(int _n, int _m): n(_n), m(_m), t(_n + 2, vector<int>(_m + 2, 0)) {}

  static int lowbit(int x) {
    return x & -x;
  }

  // 单点加
  void add(int x, int y, int v) {
    for (int i = x; i <= n; i += lowbit(i)) {
      for (int j = y; j <= m; j += lowbit(j)) {
        t[i][j] += v;
      }
    }
  }

  // 查询前缀和
  int query(int x, int y) {
    int res = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
      for (int j = y; j > 0; j -= lowbit(j)) {
        res += t[i][j];
      }
    }
    return res;
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  Tree tree(n, m);

  // 区间加,维护差分数组时使用
  auto addRange = [&](int x1, int y1, int x2, int y2, int v) {
    tree.add(x1, y1, v);
    tree.add(x1, y2 + 1, -v);
    tree.add(x2 + 1, y2 + 1, v);
    tree.add(x2 + 1, y1, -v);
  };

  // 区间查询,维护原数组时使用
  auto queryRange = [&](int x1, int y1, int x2, int y2) {
    return tree.query(x2, y2) - tree.query(x2, y1 - 1) - tree.query(x1 - 1, y2)
           + tree.query(x1 - 1, y1 - 1);
  };

  int op;

  while (cin >> op) {
    if (op == 1) {
      int x, y, k;
      cin >> x >> y >> k;
      tree.add(x, y, k);
    } else if (op == 2) {
      int x1, y1, x2, y2;
      cin >> x1 >> y1 >> x2 >> y2;
      cout << queryRange(x1, y1, x2, y2) << endl;
    }
  }
}

二维,子矩阵加,子矩阵查询

LOJ-135

struct Tree {
private:
  vector<vector<int>> t1, t2, t3, t4;
  int n, m;

  static int lowbit(int x) {
    return x & -x;
  }

  // 单点加
  void add(int x, int y, int v) {
    for (int i = x; i <= n; i += lowbit(i)) {
      for (int j = y; j <= m; j += lowbit(j)) {
        t1[i][j] += v;
        t2[i][j] += v * x;
        t3[i][j] += v * y;
        t4[i][j] += v * x * y;
      }
    }
  }

  // 查询前缀和
  int query(int x, int y) {
    int res = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
      for (int j = y; j > 0; j -= lowbit(j)) {
        res += (x + 1) * (y + 1) * t1[i][j] - (y + 1) * t2[i][j]
               - (x + 1) * t3[i][j] + t4[i][j];
      }
    }
    return res;
  }

public:
  Tree(int _n, int _m): n(_n), m(_m) {
    t1 = t2 = t3 = t4 = vector<vector<int>>(_n + 2, vector<int>(_m + 2));
  }

  // 子矩阵加
  void addRange(int x1, int y1, int x2, int y2, int v) {
    add(x1, y1, v);
    add(x1, y2 + 1, -v);
    add(x2 + 1, y1, -v);
    add(x2 + 1, y2 + 1, v);
  }

  // 子矩阵查询
  int queryRange(int x1, int y1, int x2, int y2) {
    return query(x2, y2) - query(x2, y1 - 1) - query(x1 - 1, y2)
           + query(x1 - 1, y1 - 1);
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  Tree tree(n, m);

  int op;
  while (cin >> op) {

    if (op == 1) {
      int x1, y1, x2, y2, k;
      cin >> x1 >> y1 >> x2 >> y2 >> k;
      tree.addRange(x1, y1, x2, y2, k);
    } else if (op == 2) {
      int x1, y1, x2, y2;
      cin >> x1 >> y1 >> x2 >> y2;
      cout << tree.queryRange(x1, y1, x2, y2) << endl;
    }
  }
}

权值树状数组

TODO 完善代码

维护 中出现的次数,对 建树状数组。

单点修改,求 kth:

// 权值树状数组查询第 k 小
int kth(int k) {
  int sum = 0, x = 0;
  for (int i = log2(n); ~i; --i) {
    x += 1 << i;                   // 尝试扩展
    if (x >= n || sum + t[x] >= k) // 如果扩展失败
      x -= 1 << i;
    else
      sum += t[x];
  }
  return x + 1;
}

线段树

区间加,区间求和

Luogu-P3372

struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)

  int n;
  vector<int> tag, sum;

  Tree(int _n): n(_n), tag((_n + 2) * 4), sum((_n + 2) * 4) {}
  Tree(const vector<int> &a): Tree(a.size()) {
    function<void(int, int, int)> build = [&](int l, int r, int u) {
      if (l == r) {
        sum[u] = a[l - 1];
        return;
      }
      int mid = (r - l) / 2 + l;
      build(l, mid, ls);
      build(mid + 1, r, rs);
      sum[u] = sum[ls] + sum[rs];
    };
    build(1, n, 1);
  }

  void pushdown(int u, int len) {
    tag[ls] += tag[u];
    tag[rs] += tag[u];
    sum[ls] += tag[u] * ((len + 1) >> 1);
    sum[rs] += tag[u] * (len >> 1);
    tag[u] = 0;
  }

  void pushup(int u) {
    sum[u] = sum[ls] + sum[rs];
  }

  void add(int l, int r, int x, int cl, int cr, int u) {
    int len = cr - cl + 1;
    if (cl >= l && cr <= r) {
      tag[u] += x;
      sum[u] += len * x;
      return;
    }
    if (tag[u]) {
      pushdown(u, len);
    }
    int mid = ((cr - cl) >> 1) + cl;
    if (l <= mid) {
      add(l, r, x, cl, mid, ls);
    }
    if (r > mid) {
      add(l, r, x, mid + 1, cr, rs);
    }
    pushup(u);
  }

  int query(int l, int r, int cl, int cr, int u) {
    int len = cr - cl + 1;
    if (cl >= l && cr <= r) {
      return sum[u];
    }
    if (tag[u]) {
      pushdown(u, len);
    }
    int mid = ((cr - cl) >> 1) + cl, res = 0;
    if (l <= mid) {
      res += query(l, r, cl, mid, ls);
    }
    if (r > mid) {
      res += query(l, r, mid + 1, cr, rs);
    }
    return res;
  }

  void add(int l, int r, int x) {
    return add(l, r, x, 1, n, 1);
  }

  int query(int l, int r) {
    return query(l, r, 1, n, 1);
  }

#undef ls
#undef rs
};

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> a(n);
  for (auto &x : a) {
    cin >> x;
  }
  Tree tree(a);
  while (m--) {
    int op;
    cin >> op;
    if (op == 1) {
      int l, r, x;
      cin >> l >> r >> x;
      tree.add(l, r, x);
    } else if (op == 2) {
      int l, r;
      cin >> l >> r;
      cout << tree.query(l, r) << endl;
    }
  }
}

区间加,区间最大(小)值

// 维护最大值
struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)

private:
  int n;
  vector<int> tag, ma;

  void pushdown(int u, int len) {
    tag[ls] += tag[u];
    tag[rs] += tag[u];
    ma[ls] += tag[u];
    ma[rs] += tag[u];
    tag[u] = 0;
  }

  void pushup(int u) {
    ma[u] = max(ma[ls], ma[rs]);
  }

  void add(int l, int r, int L, int R, int x, int u) {
    int len = r - l + 1;
    if (l >= L && r <= R) {
      tag[u] += x;
      ma[u] += x;
      return;
    }
    if (tag[u]) {
      pushdown(u, len);
    }
    int mid = ((r - l) >> 1) + l;
    if (L <= mid) {
      add(l, mid, L, R, x, ls);
    }
    if (R > mid) {
      add(mid + 1, r, L, R, x, rs);
    }
    pushup(u);
  }

  int query(int l, int r, int L, int R, int u) {
    int len = r - l + 1;
    if (l >= L && r <= R) {
      return ma[u];
    }
    if (tag[u]) {
      pushdown(u, len);
    }
    // 改成最小值的话别忘了改这里 res = 1e9
    int mid = ((r - l) >> 1) + l, res = 0;
    if (L <= mid) {
      res = max(res, query(l, mid, L, R, ls));
    }
    if (R > mid) {
      res = max(res, query(mid + 1, r, L, R, rs));
    }
    return res;
  }

public:
  Tree(int _n): n(_n), tag((_n + 2) * 4), ma((_n + 2) * 4) {}

  void add(int l, int r, int x) {
    return add(1, n, l, r, x, 1);
  }

  int query(int l, int r) {
    return query(1, n, l, r, 1);
  }

#undef ls
#undef rs
};

区间加,区间乘,区间求和

Luogu-P3373

int mod;

struct Tree {
#define ls (u << 1)
#define rs (u << 1 | 1)

private:
  int n;
  vector<int> sum, mu, tag;

  void pushup(int u) {
    sum[u] = (sum[ls] + sum[rs]) % mod;
  }

  void pushdown(int l, int r, int u) {
    if (mu[u] != 1) {
      mu[ls] = mu[ls] * mu[u] % mod;
      mu[rs] = mu[rs] * mu[u] % mod;
      tag[ls] = tag[ls] * mu[u] % mod;
      tag[rs] = tag[rs] * mu[u] % mod;
      sum[ls] = sum[ls] * mu[u] % mod;
      sum[rs] = sum[rs] * mu[u] % mod;
      mu[u] = 1;
    }
    int mid = ((r - l) >> 1) + l;
    if (tag[u]) {
      sum[ls] = (sum[ls] + tag[u] * (mid - l + 1)) % mod;
      sum[rs] = (sum[rs] + tag[u] * (r - mid)) % mod;
      tag[ls] = (tag[ls] + tag[u]) % mod;
      tag[rs] = (tag[rs] + tag[u]) % mod;
      tag[u] = 0;
    }
  }

  void mul(int l, int r, int L, int R, int x, int u) {
    if (l >= L && r <= R) {
      mu[u] = mu[u] * x % mod;
      tag[u] = tag[u] * x % mod;
      sum[u] = sum[u] * x % mod;
      return;
    }
    if (mu[u] != 1 || tag[u]) {
      pushdown(l, r, u);
    }
    int mid = ((r - l) >> 1) + l;
    if (mid >= L) {
      mul(l, mid, L, R, x, ls);
    }
    if (mid < R) {
      mul(mid + 1, r, L, R, x, rs);
    }
    pushup(u);
  }

  void add(int l, int r, int L, int R, int x, int u) {
    int len = r - l + 1;
    if (l >= L && r <= R) {
      sum[u] = (sum[u] + x * len % mod) % mod;
      tag[u] = (tag[u] + x) % mod;
      return;
    }
    int mid = ((r - l) >> 1) + l;
    pushdown(l, r, u);
    if (mid >= L) {
      add(l, mid, L, R, x, ls);
    }
    if (mid < R) {
      add(mid + 1, r, L, R, x, rs);
    }
    pushup(u);
  }

  int query(int l, int r, int L, int R, int u) {
    if (l >= L && r <= R) {
      return sum[u];
    }
    int mid = ((r - l) >> 1) + l, res = 0;
    pushdown(l, r, u);
    if (mid >= L) {
      res = (res + query(l, mid, L, R, ls)) % mod;
    }
    if (mid < R) {
      res = (res + query(mid + 1, r, L, R, rs)) % mod;
    }
    return res;
  }

  void build(const vector<int> &a, int l, int r, int u) {
    sum[u] = 0, tag[u] = 0, mu[u] = 1;
    if (l == r) {
      sum[u] = a[l];
      return;
    }
    int mid = ((r - l) >> 1) + l;
    build(a, l, mid, ls);
    build(a, mid + 1, r, rs);
    pushup(u);
  }

public:
  Tree(int _n): n(_n) {
    sum = mu = tag = vector<int>((_n + 2) * 4);
  }

  void add(int l, int r, int x) {
    return add(1, n, l, r, x, 1);
  }

  void mul(int l, int r, int x) {
    return mul(1, n, l, r, x, 1);
  }

  int query(int l, int r) {
    return query(1, n, l, r, 1);
  }

  void build(const vector<int> &a) {
    build(a, 1, n, 1);
  }

#undef ls
#undef rs
};

void solve() {
  int n, m;
  cin >> n >> m >> mod;
  Tree tr(n);
  vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> a[i];
  }
  tr.build(a);
  while (m--) {
    int op;
    cin >> op;
    if (op == 1) {
      int l, r, x;
      cin >> l >> r >> x;
      tr.mul(l, r, x);
    } else if (op == 2) {
      int l, r, x;
      cin >> l >> r >> x;
      tr.add(l, r, x);
    } else if (op == 3) {
      int l, r;
      cin >> l >> r;
      cout << tr.query(l, r) << endl;
    }
  }
}

区间加,区间求和,动态开点

常用于权值线段树。

Luogu-P3369

普通平衡树

  1. 插入
  2. 删除 数(若有多个相同的数,应只删除一个)
  3. 查询 数的排名(排名定义为比当前数小的数的个数 )
  4. 查询排名为 的数
  5. 的前驱(前驱定义为小于 ,且最大的数)
  6. 的后继(后继定义为大于 ,且最小的数)
// 权值线段树
struct Tree {
#define ls get_ls(u)
#define rs get_rs(u)

private:
  // n 是最多元素总数,[mi, ma] 是值域
  int n;
  int mi, ma;
  vector<int> a, sum;
  vector<int> lson, rson;
  int cnt = 1;

  int get_ls(int u) {
    if (!lson[u]) {
      lson[u] = ++cnt;
    }
    return lson[u];
  }

  int get_rs(int u) {
    if (!rson[u]) {
      rson[u] = ++cnt;
    }
    return rson[u];
  }

  void pushdown(int u, int len) {
    a[ls] += a[u];
    a[rs] += a[u];
    sum[ls] += a[u] * ((len + 1) >> 1);
    sum[rs] += a[u] * (len >> 1);
    a[u] = 0;
  }

  void pushup(int u) {
    sum[u] = sum[ls] + sum[rs];
  }

  void add(int l, int r, int L, int R, int x, int u) {
    int len = r - l + 1;
    if (l >= L && r <= R) {
      a[u] += x;
      sum[u] += len * x;
      return;
    }
    if (a[u]) {
      pushdown(u, len);
    }
    int mid = ((r - l) >> 1) + l;
    if (L <= mid) {
      add(l, mid, L, R, x, ls);
    }
    if (R > mid) {
      add(mid + 1, r, L, R, x, rs);
    }
    pushup(u);
  }

  int query(int l, int r, int L, int R, int u) {
    int len = r - l + 1;
    if (l >= L && r <= R) {
      return sum[u];
    }
    if (a[u]) {
      pushdown(u, len);
    }
    int mid = ((r - l) >> 1) + l, res = 0;
    if (L <= mid) {
      res += query(l, mid, L, R, ls);
    }
    if (R > mid) {
      res += query(mid + 1, r, L, R, rs);
    }
    return res;
  }

public:
  Tree(int _n, int _mi, int _ma): n(_n), mi(_mi), ma(_ma) {
    a = sum = lson = rson = vector<int>((_n + 2) * 2);
  }

  void add(int l, int r, int x) {
    return add(mi, ma, l, r, x, 1);
  }

  int query(int l, int r) {
    return query(mi, ma, l, r, 1);
  }

#undef ls
#undef rs
};

void solve() {
  int n;
  cin >> n;
  int mi = -2e7, ma = 2e7;
  Tree tr(n * 50, mi, ma);

  // 前驱后继操作用 multiset 维护
  multiset<int> b;
  auto rank = [&](int x) {
    int l = mi, r = ma;
    while (l < r) {
      int mid = ((r - l + 1) >> 1) + l;
      int rk = tr.query(mi, mid - 1) + 1;
      if (rk > x) {
        r = mid - 1;
      } else {
        l = mid;
      }
    }
    return *b.lower_bound(l);
  };
  while (n--) {
    int op, x;
    cin >> op >> x;
    if (op == 1) {
      tr.add(x, x, 1);
      b.insert(x);
    } else if (op == 2) {
      tr.add(x, x, -1);
      b.erase(b.find(x));
    } else if (op == 3) {
      cout << tr.query(mi, x - 1) + 1 << endl;
    } else if (op == 4) {
      cout << rank(x) << endl;
    } else if (op == 5) {
      cout << *(--b.lower_bound(x)) << endl;
    } else if (op == 6) {
      cout << *b.lower_bound(x + 1) << endl;
    }
  }
}

线段树合并

TODO 完善代码

常用于权值线段树,动态开点

int merge(int u, int v, int l, int r) {
  if(!u) {
    return v;
  }
  if(!v) {
    return u;
  }
  if(u == v) {
    sum[u] += sum[v];
    return a;
  }
  int mid = ((r - l) >> 1) + l;
  ls[u] = merge(ls[u], ls[v], l, mid);
  rs[u] = merge(rs[u], rs[v], mid + 1, r);
  pushup(u);
  return u;
}

线段树分裂

TODO 完善代码

只能用于有序的序列,常用于动态开点的权值线段树

void split(int &u, int &v, int l, int r, int L, int R) {
  if(l < L || r > R) {
    return;
  }
  if(!u) {
    return;
  }
  if(l >= L && r <= R) {
    v = u;
    u = 0;
    return;
  }
  if(!q) {
    q = newNode();
  }
  int mid = ((r - l) >> 1) + l;
  if(L <= mid) {
    split(ls[u], ls[v], l, mid, L, R);
  }
  if(R > mid) {
    split(rs[u], rs[v], mid + 1, r, L, R);
  }
  pushup(u);
  pushup(v);
}

平衡树

gnu pbds 中的 tree:

Luogu-P3369

普通平衡树

  1. 插入
  2. 删除 数(若有多个相同的数,应只删除一个)
  3. 查询 数的排名(排名定义为比当前数小的数的个数 )
  4. 查询排名为 的数
  5. 的前驱(前驱定义为小于 ,且最大的数)
  6. 的后继(后继定义为大于 ,且最小的数)

__gnu_pbds::tree 不支持可重复集合,需要自己手动处理。

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <iostream>

using namespace std;
using namespace __gnu_pbds;

#define int long long

tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>
  a;

signed main() {
  int n;
  cin >> n;
  for (int i = 1; i <= n; ++i) {
    int op, x;
    cin >> op >> x;
    if (op == 1) {
      a.insert((x << 20) + i);
    } else if (op == 2) {
      a.erase(a.lower_bound(x << 20));
    } else if (op == 3) {
      cout << a.order_of_key(x << 20) + 1 << endl;
    } else if (op == 4) {
      cout << (*a.find_by_order(x - 1) >> 20) << endl;
    } else if (op == 5) {
      cout << (*(--a.lower_bound(x << 20)) >> 20) << endl;
    } else if (op == 6) {
      cout << (*a.lower_bound((x << 20) + n) >> 20) << endl;
    }
  }
  return 0;
}

std::rope 实现区间翻转:

#pragma GCC optimize(2)
#include <ext/rope>
#include <iostream>

using namespace std;
using namespace __gnu_cxx;

void solve() {
  rope<int> s, rs;
  int n, m;
  cin >> n >> m;
  for (int i = 1; i <= n; ++i) {
    s.push_back(i);
    rs.push_back(n - i + 1);
  }
  while (m--) {
    int l, r;
    cin >> l >> r;
    int rl = n - r + 1;
    int rr = n - l + 1;
    --l, --r;
    --rl, --rr;
    auto tmp = s.substr(l, r - l + 1);
    auto rtmp = rs.substr(rl, rr - rl + 1);
    s = s.substr(0, l) + rtmp + s.substr(r + 1, n - r);
    rs = rs.substr(0, rl) + tmp + rs.substr(rr + 1, n - rr);
  }
  for (auto i : s) {
    cout << i << ' ';
  }
}

signed main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int t = 1;
  while (t--) {
    solve();
  }
}

Treap

有旋

struct Treap {
  vector<int> ls, rs, a, p, sz, cnt;
  int tot = 0, rt = 0;
  int notFound;

  Treap(int _n, int _notFound = (1u << 31) - 1): notFound(_notFound) {
    ls = rs = a = p = sz = cnt = vector<int>(_n + 10);
  }

  void pushup(int u) {
    sz[u] = sz[ls[u]] + sz[rs[u]] + cnt[u];
  }

  void rotateL(int &u) {
    int t = rs[u];
    rs[u] = ls[t];
    ls[t] = u;
    sz[t] = sz[u];
    pushup(u);
    u = t;
  }

  void rotateR(int &u) {
    int t = ls[u];
    ls[u] = rs[t];
    rs[t] = u;
    sz[t] = sz[u];
    pushup(u);
    u = t;
  }

  int newNode(int x) {
    int u = ++tot;
    sz[u] = 1;
    cnt[u] = 1;
    a[u] = x;
    p[u] = rand();
    return u;
  }

  void insert(int &u, int x) {
    if (!u) {
      u = newNode(x);
      return;
    }
    ++sz[u];
    if (a[u] == x) {
      ++cnt[u];
    } else if (a[u] < x) {
      insert(rs[u], x);
      if (p[rs[u]] < p[u]) {
        rotateL(u);
      }
    } else {
      insert(ls[u], x);
      if (p[ls[u]] < p[u]) {
        rotateR(u);
      }
    }
  }

  void insert(int x) {
    return insert(rt, x);
  }

  bool remove(int &u, int x) {
    if (!u) {
      return false;
    }
    if (a[u] == x) {
      if (cnt[u] > 1) {
        --cnt[u];
        --sz[u];
        return true;
      }
      if (ls[u] == 0 || rs[u] == 0) {
        u = ls[u] + rs[u];
        return true;
      } else if (p[ls[u]] < p[rs[u]]) {
        rotateR(u);
        return remove(u, x);
      } else {
        rotateL(u);
        return remove(u, x);
      }
    } else if (a[u] < x) {
      bool res = remove(rs[u], x);
      if (res) {
        --sz[u];
      }
      return res;
    } else {
      bool res = remove(ls[u], x);
      if (res) {
        --sz[u];
      }
      return res;
    }
  }

  bool remove(int x) {
    return remove(rt, x);
  }

  int queryRank(int u, int x) {
    if (!u) {
      return 1;
    }
    if (a[u] == x) {
      return sz[ls[u]] + 1;
    } else if (a[u] < x) {
      return sz[ls[u]] + cnt[u] + queryRank(rs[u], x);
    } else {
      return queryRank(ls[u], x);
    }
  }

  int queryRank(int x) {
    return queryRank(rt, x);
  }

  int queryByRank(int u, int x) {
    if (!u) {
      return 0;
    }
    if (x <= sz[ls[u]]) {
      return queryByRank(ls[u], x);
    } else if (x > sz[ls[u]] + cnt[u]) {
      return queryByRank(rs[u], x - sz[ls[u]] - cnt[u]);
    } else {
      return a[u];
    }
  }

  auto queryByRank(int x) {
    return queryByRank(rt, x);
  }

  int queryPrev(int u, int x) {
    int res = notFound;
    while (u) {
      if (a[u] < x) {
        res = a[u];
        u = rs[u];
      } else {
        u = ls[u];
      }
    }
    return res;
  }

  auto queryPrev(int x) {
    return queryPrev(rt, x);
  }

  int queryNext(int u, int x) {
    int res = notFound;
    while (u) {
      if (a[u] > x) {
        res = a[u];
        u = ls[u];
      } else {
        u = rs[u];
      }
    }
    return res;
  }

  auto queryNext(int x) {
    return queryNext(rt, x);
  }
};

void solve() {
  int n;
  cin >> n;
  Treap tr(n);
  while (n--) {
    int op, x;
    cin >> op >> x;
    if (op == 1) {
      tr.insert(x);
    } else if (op == 2) {
      tr.remove(x);
    } else if (op == 3) {
      cout << tr.queryRank(x) << endl;
    } else if (op == 4) {
      cout << tr.queryByRank(x) << endl;
    } else if (op == 5) {
      cout << tr.queryPrev(x) << endl;
    } else if (op == 6) {
      cout << tr.queryNext(x) << endl;
    }
  }
}

无旋

区间翻转

Luogu-P3391

struct Treap {
  vector<int> p, sz, a, tag, ls, rs;
  int tot = 0, rt = 0;

  Treap(int _n) {
    p = sz = a = tag = ls = rs = vector<int>(_n + 2);
  }

  void pushup(int u) {
    sz[u] = sz[ls[u]] + sz[rs[u]] + 1;
  }

  int newNode(int x) {
    a[++tot] = x, sz[tot] = 1, p[tot] = rand();
    return tot;
  }

  void pushdown(int u) {
    swap(ls[u], rs[u]);
    if (ls[u]) {
      tag[ls[u]] ^= 1;
    }
    if (rs[u]) {
      tag[rs[u]] ^= 1;
    }
    tag[u] = 0;
  }

  int merge(int u, int v) {
    if (!u || !v) {
      return u + v;
    }
    if (p[u] < p[v]) {
      if (tag[u]) {
        pushdown(u);
      }
      rs[u] = merge(rs[u], v);
      pushup(u);
      return u;
    }
    if (tag[v]) {
      pushdown(v);
    }
    ls[v] = merge(u, ls[v]);
    pushup(v);
    return v;
  }

  pair<int, int> split(int u, int x) {
    if (!u) {
      return {0, 0};
    }
    if (tag[u]) {
      pushdown(u);
    }
    pair<int, int> res;
    if (sz[ls[u]] < x) {
      auto tmp = split(rs[u], x - sz[ls[u]] - 1);
      rs[u] = tmp.first;
      res = {u, tmp.second};
    } else {
      auto tmp = split(ls[u], x);
      ls[u] = tmp.second;
      res = {tmp.first, u};
    }
    pushup(u);
    return res;
  }

  void dfs(int u) {
    if (!u) {
      return;
    }
    if (tag[u]) {
      pushdown(u);
    }
    dfs(ls[u]);
    cout << a[u] << ' ';
    dfs(rs[u]);
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  Treap tr(n);
  for (int i = 1; i <= n; ++i) {
    tr.rt = tr.merge(tr.rt, tr.newNode(i));
  }
  for (int i = 1; i <= m; ++i) {
    int l, r;
    cin >> l >> r;
    auto t1 = tr.split(tr.rt, l - 1);
    auto t2 = tr.split(t1.second, r - l + 1);
    tr.tag[t2.first] ^= 1;
    tr.rt = tr.merge(t1.first, tr.merge(t2.first, t2.second));
  }
  tr.dfs(tr.rt);
}

Splay

区间翻转

Luogu-P3391

struct Splay {
  vector<int> fa, ls, rs, sz, rev;
  int rt;

  Splay(int _n) {
    fa = ls = rs = sz = rev = vector<int>(_n + 10);
  }

  void pushup(int u) {
    sz[u] = sz[ls[u]] + sz[rs[u]] + 1;
  }

  void pushdown(int u) {
    if (rev[u]) {
      swap(ls[u], rs[u]);
      rev[ls[u]] ^= 1;
      rev[rs[u]] ^= 1;
      rev[u] = 0;
    }
  }

  void rotate(int u, int &v) {
    int y = fa[u], z = fa[y];
    int ca = ls[y] == u ? 1 : 0;
    if (y == v) {
      v = u;
    } else {
      if (ls[z] == y) {
        ls[z] = u;
      } else {
        rs[z] = u;
      }
    }
    if (ca == 0) {
      rs[y] = ls[u];
      fa[rs[y]] = y;
      ls[u] = y;
      fa[y] = u;
      fa[u] = z;
    } else {
      ls[y] = rs[u];
      fa[ls[y]] = y;
      rs[u] = y;
      fa[y] = u;
      fa[u] = z;
    }
    pushup(u);
    pushup(y);
  }

  void splay(int u, int &v) {
    while (u != v) {
      int y = fa[u], z = fa[y];
      if (y != v) {
        if ((ls[y] == u) ^ (ls[z] == y)) {
          rotate(u, v);
        } else {

          rotate(y, v);
        }
      }
      rotate(u, v);
    }
  }

  void build(int l, int r, int u) {
    if (l > r) {
      return;
    }
    int mid = (l + r) / 2;
    if (mid < u) {
      ls[u] = mid;
    } else {
      rs[u] = mid;
    }
    fa[mid] = u;
    sz[mid] = 1;
    if (l == r) {
      return;
    }
    build(l, mid - 1, mid);
    build(mid + 1, r, mid);
    pushup(mid);
  }

  auto build(int l, int r) {
    return build(l, r, rt);
  }

  int query(int u, int x) {
    pushdown(u);
    int s = sz[ls[u]];
    if (x == s + 1) {
      return u;
    }
    if (x <= s) {
      return query(ls[u], x);
    } else {
      return query(rs[u], x - s - 1);
    }
  }

  auto query(int x) {
    return query(rt, x);
  }

  void reverse(int l, int r) {
    int x = query(rt, l), y = query(rt, r + 2);
    splay(x, rt);
    splay(y, rs[x]);
    int z = ls[y];
    rev[z] ^= 1;
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  Splay tr(n);
  tr.rt = (n + 3) / 2;
  tr.build(1, n + 2);
  for (int i = 1; i <= m; ++i) {
    int l, r;
    cin >> l >> r;
    tr.reverse(l, r);
  }
  for (int i = 2; i <= n + 1; ++i) {
    cout << tr.query(i) - 1 << ' ';
  }
}

TODO 普通平衡树

主席树

静态区间第

Luogu-P3834

struct Tree {
  int cnt;
  vector<int> su, ls, rs;
  Tree(int n): cnt(0) {
    n = (n + 1) << 5;
    su = ls = rs = vector<int>(n);
  };
  int build(int l, int r) {
    ++cnt;
    su[cnt] = 0;
    int mid = (l + r) >> 1;
    if (l < r) {
      ls[cnt] = build(l, mid), rs[cnt] = build(mid + 1, r);
    }
    return cnt;
  }
  int update(int pre, int l, int r, int x) {
    int rt = ++cnt;
    ls[rt] = ls[pre];
    rs[rt] = rs[pre];
    su[rt] = su[pre] + 1;
    if (l < r) {
      int mid = (l + r) >> 1;
      if (x <= mid) {
        ls[rt] = update(ls[pre], l, mid, x);
      } else {
        rs[rt] = update(rs[pre], mid + 1, r, x);
      }
    }
    return rt;
  }
  int query(int a, int b, int l, int r, int k) {
    if (l >= r) {
      return l;
    }
    int t = su[ls[b]] - su[ls[a]];
    int mid = (l + r) >> 1;
    if (t >= k) {
      return query(ls[a], ls[b], l, mid, k);
    } else {
      return query(rs[a], rs[b], mid + 1, r, k - t);
    }
  }
};

void solve() {
  int n, q;
  cin >> n >> q;
  vector<int> a(n), rt(n + 1);
  Tree tr(n);
  for (auto &i : a) {
    cin >> i;
  }
  auto b = a;
  sort(b.begin(), b.end());
  b.erase(unique(b.begin(), b.end()), b.end());
  rt[0] = tr.build(1, b.size());
  auto getId = [&](int x) {
    return lower_bound(b.begin(), b.end(), x) - b.begin();
  };
  for (int i = 1; i <= n; ++i) {
    rt[i] = tr.update(rt[i - 1], 1, b.size(), getId(a[i - 1]) + 1);
  }
  while (q--) {
    int l, r, k;
    cin >> l >> r >> k;
    int pos = tr.query(rt[l - 1], rt[r], 1, b.size(), k);
    cout << b[pos - 1] << endl;
  }
}

可持久化数组

Luogu-P3919

struct Tree {
  int cnt;
  vector<int> su, ls, rs, a;

  Tree(int n): cnt(0) {
    n = (n + 1) << 5;
    su = ls = rs = a = vector<int>(n);
  };

  int build(int l, int r) {
    int rt = ++cnt;
    if (l == r) {
      su[rt] = a[l];
      return rt;
    }
    int mid = (l + r) >> 1;
    ls[rt] = build(l, mid);
    rs[rt] = build(mid + 1, r);
    return rt;
  }

  int modify(int pre, int l, int r, int x, int u) {
    int rt = ++cnt;
    ls[rt] = ls[pre];
    rs[rt] = rs[pre];
    su[rt] = su[pre];
    if (l == r) {
      su[rt] = x;
      return rt;
    }
    int mid = (l + r) >> 1;
    if (u <= mid) {
      ls[rt] = modify(ls[pre], l, mid, x, u);
    }
    if (u > mid) {
      rs[rt] = modify(rs[pre], mid + 1, r, x, u);
    }
    return rt;
  }

  int query(int rt, int l, int r, int u) {
    if (l == r) {
      return su[rt];
    }
    int mid = (l + r) >> 1;
    if (u <= mid) {
      return query(ls[rt], l, mid, u);
    } else {
      return query(rs[rt], mid + 1, r, u);
    }
  }
};

void solve() {
  int n, q;
  cin >> n >> q;
  vector<int> rt(q + 1);
  Tree tr(n);
  for (int i = 1; i <= n; ++i) {
    cin >> tr.a[i];
  }
  rt[0] = tr.build(1, n);
  for (int i = 1; i <= q; ++i) {
    int pre, op;
    cin >> pre >> op;
    if (op == 1) {
      int pos, x;
      cin >> pos >> x;
      rt[i] = tr.modify(rt[pre], 1, n, x, pos);
    } else if (op == 2) {
      int pos;
      cin >> pos;
      cout << tr.query(rt[pre], 1, n, pos) << endl;
      rt[i] = rt[pre];
    }
  }
}

可持久化并查集

TODO

多项式

FFT

Luogu-P3803-【模板】多项式乘法(FFT)

// 看不懂,当黑盒
struct FFT {
  vector<complex<double>> f;
  vector<int> rev;
  int limit = 1;
  int l = -1;
  int n, m, t;
  auto read(const vector<int> &a, const vector<int> &b) {
    n = a.size() - 1;
    m = b.size() - 1;
    t = n + m;
    n = max(n, m);
    limit = 1, l = -1;
    while (limit <= (n << 1)) {
      limit <<= 1;
      ++l;
    }
    rev.clear();
    rev.resize(limit + 1);
    for (int i = 1; i <= limit; ++i) {
      rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
    }
    f.clear();
    f.resize(limit + 1);
    for (int i = 0; i < a.size(); ++i) {
      f[i].real(a[i]);
    }
    for (int i = 0; i < b.size(); ++i) {
      f[i].imag(b[i]);
    }
  }

  void fft(int type, int limit) {
    for (int i = 1; i <= limit; ++i) {
      if (i >= rev[i]) {
        continue;
      }
      swap(f[i], f[rev[i]]);
    }
    complex<double> rt, w, x, y;
    double pi = acos(-1);
    for (int mid = 1; mid < limit; mid <<= 1) {
      int r = mid << 1;
      rt = complex<double>(cos(pi / mid), type * sin(pi / mid));
      for (int j = 0; j < limit; j += r) {
        w = complex<double>(1, 0);
        for (int k = 0; k < mid; ++k) {
          x = f[j | k];
          y = w * f[j | k | mid];
          f[j | k] = x + y;
          f[j | k | mid] = x - y;
          w = w * rt;
        }
      }
    }
    if (type == 1) {
      return;
    }
    for (int i = 0; i <= limit; ++i) {
      f[i].imag(f[i].imag() / limit);
      f[i].real(f[i].real() / limit);
    }
  }

  auto mul() {
    fft(1, limit);
    for (int i = 0; i <= limit; ++i) {
      f[i] = f[i] * f[i];
    }
    fft(-1, limit);
    vector<int> c(t + 1);
    for (int i = 0; i <= t; ++i) {
      c[i] = f[i].imag() / 2 + 0.5;
    }
    return c;
  }
};

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> a(n + 1), b(m + 1);
  for (auto &i : a) {
    cin >> i;
  }
  for (auto &i : b) {
    cin >> i;
  }
  auto fft = FFT();
  fft.read(a, b);
  auto c = fft.mul();
  for (int i = 0; i <= n + m; ++i) {
    cout << c[i] << ' ';
  }
  cout << endl;
}

背包 DP

如无特殊说明,默认 v 为价值(value),w 为重量(weight)。

01 背包

Luogu-P2871

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> v(n + 1), w(n + 1);
  vector<int> dp(m + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> w[i] >> v[i];
  }
  for (int i = 1; i <= n; ++i) {
    for (int j = m; j >= w[i]; --j) {
      dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
    }
  }
  cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}

signed main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.tie(nullptr);
  int t = 1;
  // cin >> t;
  while (t--) {
    solve();
  }
}

完全背包

Luogu-P1616

void solve() {
  int n, m;
  cin >> m >> n;
  vector<int> v(n + 1), w(n + 1);
  vector<int> dp(m + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> w[i] >> v[i];
  }
  for (int i = 1; i <= n; ++i) {
    for (int j = w[i]; j <= m; ++j) {
      dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
    }
  }
  cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}

多重背包

AcWing-5

void solve() {
  int n, m;
  cin >> n >> m;
  vector<int> v(n + 1), w(n + 1), s(n + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> w[i] >> v[i] >> s[i];
  }
  vector<int> dp(m + 1);
  auto process_01 = [&](int v, int w) {
    for (int i = m; i >= w; --i) {
      dp[i] = max(dp[i], dp[i - w] + v);
    }
  };
  for (int i = 1; i <= n; ++i) {
    int base = 1;
    int cs = s[i], cw = w[i], cv = v[i];
    while (cs > base) {
      cs -= base;
      process_01(cv * base, cw * base);
      base *= 2;
    }
    process_01(cv * cs, cw * cs);
  }
  cout << dp[m] << endl;
}

混合背包

Luogu-P1833 按类型分别套用上面三种背包的代码即可。

void solve() {
  string s, e;
  int n;
  cin >> s >> e >> n;

  auto process_time = [&](string s) {
    int h = 0, m = 0;
    int pos = 0;
    while (pos < s.size() && s[pos] != ':') {
      h *= 10;
      h += s[pos] - '0';
      ++pos;
    }
    ++pos;
    while (pos < s.size()) {
      m *= 10;
      m += s[pos] - '0';
      ++pos;
    }
    return h * 60 + m;
  };

  int m = process_time(e) - process_time(s);

  vector<int> w(n + 1), v(n + 1), a(n + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> w[i] >> v[i] >> a[i];
  }
  vector<int> dp(m + 1);
  for (int i = 1; i <= n; ++i) {
    if (!a[i]) {
      for (int j = w[i]; j <= m; ++j) {
        dp[j] = max(dp[j], dp[j - w[i]] + v[i]);
      }
    } else {
      int base = 1;
      int k = a[i];
      while (k > base) {
        int cw = w[i] * base;
        int cv = v[i] * base;
        for (int j = m; j >= cw; --j) {
          dp[j] = max(dp[j], dp[j - cw] + cv);
        }
        k -= base;
        base <<= 1;
      }
      int cw = w[i] * k;
      int cv = v[i] * k;
      for (int j = m; j >= cw; --j) {
        dp[j] = max(dp[j], dp[j - cw] + cv);
      }
    }
  }
  cout << *max_element(dp.begin() + 1, dp.end()) << endl;
}

二维费用背包

Luogu-P1855

void solve() {
  int n, m, t;
  cin >> n >> m >> t;
  vector<int> w1(n + 1), w2(n + 1);
  for (int i = 1; i <= n; ++i) {
    cin >> w1[i] >> w2[i];
  }
  vector<vector<int>> dp(m + 1, vector<int>(t + 1));
  for (int i = 1; i <= n; ++i) {
    for (int j = m; j >= w1[i]; --j) {
      for (int k = t; k >= w2[i]; --k) {
        dp[j][k] = max(dp[j][k], dp[j - w1[i]][k - w2[i]] + 1);
      }
    }
  }
  cout << dp[m][t] << endl;
}

分组背包

Luogu-P1757

件物品和一个大小为 的背包,第 个物品的价值为 , 体积为 。同时,每个物品属于一个组,同组内最多只能选择一个物品。 求背包能装载物品的最大总价值。

void solve() {
  int m, n;
  cin >> m >> n;
  unordered_map<int, vector<pair<int, int>>> a;
  for (int i = 1; i <= n; ++i) {
    int w, v, k;
    cin >> w >> v >> k;
    a[k].emplace_back(w, v);
  }
  vector<int> dp(m + 1);
  for (auto &[_, vs] : a) {
    for (int i = m; i >= 0; --i) {
      for (auto &[w, v] : vs) {
        if (i >= w) {
          dp[i] = max(dp[i], dp[i - w] + v);
        }
      }
    }
  }
  cout << *max_element(dp.begin() + 1, dp.end());
}

可以转化成分组背包:有依赖的背包,把物品依赖的选择方案分到同一组。

背包问题变种

输出方案

表示第 件物品占用空间为 的时候是否被选择,转移时记录选或不选,输出:

int cur_w = m;
vector<int> selected;

for (int i = n; i >= 1; --i) {
  if (g[i][cur_w]) {
    selected.emplace_back(i);
    cur_w -= w[i];
  }
}

求方案数量

把转移中求最大值变为求和。

01 背包:

求最优方案数量

TODO 重写并测试

for (int i = 0; i < N; i++) {
  for (int j = V; j >= v[i]; j--) {
    int tmp = std::max(dp[j], dp[j - v[i]] + w[i]);
    int c = 0;
    if (tmp == dp[j]) {
      c += cnt[j]; // 如果从dp[j]转移
    }
    if (tmp == dp[j - v[i]] + w[i]) {
      c += cnt[j - v[i]]; // 如果从dp[j-v[i]]转移
    }
    dp[j] = tmp;
    cnt[j] = c;
  }
}
int max = 0; // 寻找最优解
for (int i = 0; i <= V; i++) {
  max = std::max(max, dp[i]);
}
int res = 0;
for (int i = 0; i <= V; i++) {
  if (dp[i] == max) {
    res += cnt[i]; // 求和最优解方案数
  }
}

求第 k 优解

TODO 重写并测试

memset(dp, 0, sizeof(dp));
int i, j, p, x, y, z;
scanf("%d%d%d", &n, &m, &K);
for (i = 0; i < n; i++) {
  scanf("%d", &w[i]);
}
for (i = 0; i < n; i++) {
  scanf("%d", &c[i]);
}
for (i = 0; i < n; i++) {
  for (j = m; j >= c[i]; j--) {
    for (p = 1; p <= K; p++) {
      a[p] = dp[j - c[i]][p] + w[i];
      b[p] = dp[j][p];
    }
    a[p] = b[p] = -1;
    x = y = z = 1;
    while (z <= K && (a[x] != -1 || b[y] != -1)) {
      if (a[x] > b[y]) {
        dp[j][z] = a[x++];
      } else {
        dp[j][z] = b[y++];
      }
      if (dp[j][z] != dp[j][z - 1]) {
        z++;
      }
    }
  }
}
printf("%d\n", dp[m][K]);

状压 DP

枚举子集的子集,时间复杂度

for (int i = 0; i < (1 << n); ++i) {
  for (int j = i; j; j = (j - 1) & i) {
    // j 为 i 的子集
  }
}

DP 优化

四边形不等式优化

形如以下转移方程:

需满足的条件:

  • 区间包含单调性:若 ,则
  • 四边形不等式:若 ,则

结论:

  • 也满足四边形不等式。
  • 假设 的最优决策点,那么

只需要记录 ,可以优化掉一维循环。

考场上可以直接打表记录 后验证单调性

满足四边形不等式的函数类:

性质 1:若函数 均满足四边形不等式(或区间包含单调性),则对于任意 ,函数 也满足四边形不等式(或区间包含单调性)。

性质 2:若存在函数 使得 ,则函数 满足四边形恒等式。当函数 单调增加时,函数 还满足区间包含单调性。

性质 3:设 是一个单调增加的凸函数,若函数 满足四边形不等式并且对区间包含关系具有单调性,则复合函数 也满足四边形不等式和区间包含单调性。

性质 4:设 是一个凸函数,若函数 满足四边形恒等式并且对区间包含关系具有单调性,则复合函数 也满足四边形不等式。

首先需要澄清一点,凸函数(Convex Function)的定义在国内教材中有分歧,此处的凸函数指的是下凸函数,即(可微时)一阶导数单调增加的函数。

CF321E

void solve() {
  int n, m;
  io.read(n);
  io.read(m);
  vector<vector<int>> a(n + 10, vector<int>(n + 10));
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
      io.read(a[i][j]);
    }
  }
  auto p = a;
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= n; ++j) {
      p[i][j] += p[i - 1][j] + p[i][j - 1] - p[i - 1][j - 1];
    }
  }
  auto get_sum = [&](int l, int r) {
    l = max(l, 1ll);
    if (r < l) {
      return 0ll;
    }
    int res = p[r][r] - p[l - 1][r] - p[r][l - 1] + p[l - 1][l - 1];
    return res;
  };
  vector<vector<int>> dp(n + 10, vector<int>(m + 10, 1e12));
  vector<vector<int>> bt(n + 10, vector<int>(m + 10));
  for (int i = 0; i <= m; ++i) {
    dp[i][i] = 0;
    bt[n + 1][i] = n;
  }
  for (int i = 0; i <= n; ++i) {
    bt[i][0] = 1;
  }
  for (int i = 0; i <= n; ++i) {}
  for (int j = 1; j <= m; ++j) {
    for (int i = n; i >= j + 1; --i) {
      for (int k = bt[i][j - 1]; k <= bt[i + 1][j]; ++k) {
        int cur = dp[k - 1][j - 1] + get_sum(k, i) / 2;
        if (cur < dp[i][j]) {
          dp[i][j] = cur;
          bt[i][j] = k;
        }
      }
    }
  }
  io.write(dp[n][m]);
}

虚树

Luogu-P2495 [SDOI2011] 消耗战

给一个有边权的有根树,每次询问给一堆点,求割掉一些边使得根节点无法到达。

点数量级 ,询问点之和数量级

对每次询问的点建虚树,跑个 dp。核心是建树。

这份代码常数极大。

void solve() {
  int n;
  cin >> n;
  vector<vector<pair<int, int>>> e(n + 1);
  for (int i = 1; i <= n - 1; ++i) {
    int u, v, w;
    cin >> u >> v >> w;
    e[u].emplace_back(v, w);
    e[v].emplace_back(u, w);
  }
  int cdfn = 0;
  vector<int> dfn(n + 1);
  vector<vector<int>> fa(n + 1, vector<int>(23));
  auto fad = fa;
  vector<int> dep(n + 1);
  function<void(int, int, int)> dfs_init = [&](int u, int pre, int w) {
    // 记录 dfs 序
    dfn[u] = ++cdfn;
    // 记录倍增父节点
    fa[u][0] = pre;
    // 记录到倍增父节点的边的最小值
    fad[u][0] = w;
    // 记录深度
    dep[u] = dep[pre] + 1;
    // 倍增处理
    for (int i = 1; i <= 20; ++i) {
      fa[u][i] = fa[fa[u][i - 1]][i - 1];
      fad[u][i] = min(fad[u][i - 1], fad[fa[u][i - 1]][i - 1]);
    }
    for (auto [v, w] : e[u]) {
      if (v == pre) {
        continue;
      }
      dfs_init(v, u, w);
    }
  };
  dfs_init(1, 0, 0);
  // 求 LCA
  auto lca = [&](int u, int v) {
    if (dep[u] > dep[v]) {
      swap(u, v);
    }
    int k = dep[v] - dep[u];
    for (int j = 0; k; ++j, k >>= 1) {
      if (k & 1) {
        v = fa[v][j];
      }
    }
    if (v == u) {
      return u;
    }
    for (int j = 20; j >= 0 && v != u; --j) {
      if (fa[u][j] != fa[v][j]) {
        u = fa[u][j];
        v = fa[v][j];
      }
    }
    return fa[u][0];
  };
  // 获取从节点 u 到他的第 k 级祖先路径上权值最小的边
  auto get_fad = [&](int u, int k) {
    int ans = 1e9;
    for (int i = 0; k; ++i, k >>= 1) {
      if (k & 1) {
        ans = min(ans, fad[u][i]);
        u = fa[u][i];
      }
    }
    return ans;
  };
  // 获取从 u 到 v 路径上权值最小的边
  auto get_dis = [&](int u, int v) {
    int lc = lca(u, v);
    return min(get_fad(u, dep[u] - dep[lc]), get_fad(v, dep[v] - dep[lc]));
  };
  int q;
  cin >> q;
  while (q--) {
    int k;
    cin >> k;
    vector<int> h(k + 1);
    // 把根节点插进去,便于后续处理
    h[0] = 1;
    for (int i = 1; i <= k; ++i) {
      cin >> h[i];
    }
    // 第一次按 dfs 序排序
    sort(h.begin() + 1, h.begin() + k + 1, [&](int a, int b) {
      return dfn[a] < dfn[b];
    });
    vector<int> a = {0};
    // 求 dfs 序相邻的节点的 lca,加入虚树点集中
    for (int i = 0; i < k; ++i) {
      a.emplace_back(h[i]);
      a.emplace_back(lca(h[i], h[i + 1]));
    }
    a.emplace_back(h[k]);
    // 第二次按 dfs 序排序
    sort(a.begin() + 1, a.end(), [&](int a, int b) {
      return dfn[a] < dfn[b];
    });
    // 去重
    a.erase(unique(a.begin() + 1, a.end()), a.end());
    int m = a.size() - 1;
    vector<vector<pair<int, int>>> ce(m + 1);
    // 重新分配 id,防止后面数组开大了复杂度退化
    map<int, int> id;
    int cid = 1;
    // 保证根节点还在 1 号
    id[1] = cid;
    auto get_id = [&](int u) {
      if (id.find(u) != id.end()) {
        return id[u];
      }
      return id[u] = ++cid;
    };
    // 虚树连边
    for (int i = 1; i < m; ++i) {
      int u = lca(a[i], a[i + 1]);
      int v = a[i + 1];
      ce[get_id(u)].emplace_back(get_id(v), get_dis(u, v));
      ce[get_id(v)].emplace_back(get_id(u), get_dis(v, u));
    }
    // 标记关键点
    vector<int> b(m + 1);
    for (int i = 1; i <= k; ++i) {
      b[get_id(h[i])] = 1;
    }
    // 问题求解的 DP
    vector<int> dp(m + 1);
    function<void(int, int)> dfs = [&](int u, int pre) {
      for (auto [v, w] : ce[u]) {
        if (v == pre) {
          continue;
        }
        dfs(v, u);
        if (!b[v]) {
          dp[u] += min(dp[v], w);
        } else {
          dp[u] += w;
        }
      }
    };
    dfs(1, 0);
    cout << dp[1] << endl;
  }
}

Manacher

计算出所有的回文子串。
代码来自 OI-Wiki

计算 d1,即长度为奇数的回文子串的中心:

vector<int> d1(n);
for (int i = 0, l = 0, r = -1; i < n; i++) {
  int k = (i > r) ? 1 : min(d1[l + r - i], r - i + 1);
  while (0 <= i - k && i + k < n && s[i - k] == s[i + k]) {
    k++;
  }
  d1[i] = k--;
  if (i + k > r) {
    l = i - k;
    r = i + k;
  }
}

计算 d2,即长度为偶数的回文子串的中心:

vector<int> d2(n);
for (int i = 0, l = 0, r = -1; i < n; i++) {
  int k = (i > r) ? 0 : min(d2[l + r - i + 1], r - i + 1);
  while (0 <= i - k - 1 && i + k < n && s[i - k - 1] == s[i + k]) {
    k++;
  }
  d2[i] = k--;
  if (i + k > r) {
    l = i - k - 1;
    r = i + k;
  }
}

自带取模类型

自带取模 int

template <int mod>
struct ModInt {
  int x;

  ModInt(): x(0) {}

  ModInt(int y): x(y >= 0 ? y : y + mod) {}

  ModInt(ll y): x(y >= 0 ? y % mod : (mod - (-y) % mod) % mod) {}

  inline int inc(const int &v) {
    return v >= mod ? v - mod : v;
  }

  inline int dec(const int &v) {
    return v < 0 ? v + mod : v;
  }

  inline ModInt &operator+=(const ModInt &p) {
    x = inc(x + p.x);
    return *this;
  }

  inline ModInt &operator-=(const ModInt &p) {
    x = dec(x - p.x);
    return *this;
  }

  inline ModInt &operator*=(const ModInt &p) {
    x = (int)((ll)x * p.x % mod);
    return *this;
  }

  inline ModInt inverse() const {
    int a = x, b = mod, u = 1, v = 0, t;
    while (b > 0) {
      t = a / b, std::swap(a -= t * b, b), std::swap(u -= t * v, v);
    }
    return u;
  }

  inline ModInt &operator/=(const ModInt &p) {
    *this *= p.inverse();
    return *this;
  }

  inline ModInt operator-() const {
    return -x;
  }

  inline friend ModInt operator+(const ModInt &lhs, const ModInt &rhs) {
    return ModInt(lhs) += rhs;
  }

  inline friend ModInt operator-(const ModInt &lhs, const ModInt &rhs) {
    return ModInt(lhs) -= rhs;
  }

  inline friend ModInt operator*(const ModInt &lhs, const ModInt &rhs) {
    return ModInt(lhs) *= rhs;
  }

  inline friend ModInt operator/(const ModInt &lhs, const ModInt &rhs) {
    return ModInt(lhs) /= rhs;
  }

  inline bool operator==(const ModInt &p) const {
    return x == p.x;
  }

  inline bool operator!=(const ModInt &p) const {
    return x != p.x;
  }

  inline ModInt qpow(ll n) const {
    ModInt ret(1), mul(x);
    while (n > 0) {
      if (n & 1) {
        ret *= mul;
      }
      mul *= mul, n >>= 1;
    }
    return ret;
  }

  inline friend std::ostream &operator<<(std::ostream &os, const ModInt &p) {
    return os << p.x;
  }

  inline friend std::istream &operator>>(std::istream &is, ModInt &a) {
    ll t;
    is >> t, a = ModInt<mod>(t);
    return is;
  }

  static int get_mod() {
    return mod;
  }

  inline bool operator<(const ModInt &A) const {
    return x < A.x;
  }

  inline bool operator>(const ModInt &A) const {
    return x > A.x;
  }
};