RMQ 问题展开目录
RMQ(Range Minimum/Maximum Query),即区间最值查询。对于长度为 n 的数列 arr,回答若干询问 Q (i,j),返回数列 arr 中下标在 i,j 之间的最大 / 小值。如果只有一次询问,那一遍 for 就可以搞定,但是如果有多次询问就无法在很快的时间处理出来。
ST 算法展开目录
ST 算法是一个在线算法,它可以在 O (nlogn) 时间内进行预处理,然后在 O (1) 的时间内回答每个查询,假设现在的数组为 arr [] = {1,3,6,7,4,2,5,9},算法步骤如下:
一、预处理(以处理区间最小值为例)展开目录
$dp [i][j]$ 表示从第 i 位开始连续 $2^j$ 个数(也就是到 $i+2^j-1$)中的最小值。例如 $dp [2][1]$ 表示从第 2 个数开始,连续 2 个数的最小值,即 3,6 之间的最小值,即 $dp [2][1]=3$,从 dp 数组的含义我们就知道,$dp [i][0]=arr [i]$(下标均是从 1 开始),初值有了,剩下的就是状态转移方程。首先把 $dp [i][j]$ 平均分成两段(因为一定是偶数个数字),从 i 到 $i+2^{j-1}-1$ 为一段,$i+2 {j-1}$ 到 $i+2^j-1$ 为一段(每段长度都为 $2^{j-1}$)。假设 i=1,j=3 时就是 1,3,6,7 和 4,2,5,9 这两段。$dp [i][j]$ 就是这两段最大值的最大值。于是得到了状态转移方程式 $dp[i][j] = max(dp[i][j-1],dp[i+2^{j-1}][j-1])$
- for(int i = 1;i <= n;i++)
- dp[i][0] = arr[i];
- for(int j = 1;(1 << j) <= n;j++)
- for(int i = 1;i + (1 << j) - 1 <= n;i++)
- dp[i][j] = Math.min(dp[i][j-1],dp[i + (1<<(j - 1))][j-1]);
二、查询展开目录
假设我们需要查询区间 [L,R] 中的最小值,令 $k=log_2 (R-L+1)$,则区间 [L,R] 的最小值 $res=min (dp [L][k],dp [R-(1<<k)+1][k])$,为什么这样就可以保证区间最值?dpL 维护的是 $[L,L+2^k-1]$,$dp [L][R - 2^k+1][k]$ 维护的是 $[R-2^k+1,R]$,因此只要证明 $R-2^k+1 ≤ l+2^k-1$ 即可,这里证明省略
- int k = (int) (Math.log(r - l + 1) / Math.log(2));
- int min = Math.min(dp_min[l][k],dp_min[r - (1 << k) + 1][k]);
举个栗子展开目录
$L=4,R=6$,此时 $k=log_2 (R-L+1)=log_23=1$,所以 $RMQ (4,6)=min (dp [4][1],dp [5][1])=min (4,2)=2$,很容易看出来答案是正确的
题目链接:POJ3264展开目录
ST 算法板子题,用 java 的同学要注意的就是把你所有会的输入输出优化全用上,不然会 TLE
- import java.io.InputStreamReader;
- import java.util.Scanner;
-
- public class CF522A {
- final static int N = 50005;
- static int[][] dp_min = new int[N][25];
- static int[][] dp_max = new int[N][25];
-
- public static void main(String[] args) {
- Scanner cin = new Scanner(new InputStreamReader(System.in));
- int n = Integer.parseInt(cin.next());
- int m = Integer.parseInt(cin.next());
- for(int i = 1;i <= n;i++) {
- int tmp = cin.nextInt();
- dp_min[i][0] = tmp;
- dp_max[i][0] = tmp;
- }
- //预处理
- for(int j = 1;(1 << j) <= n;j++)
- for(int i = 1;i + (1 << j) <= n + 1;i++) {
- dp_min[i][j] = Math.min(dp_min[i][j - 1],dp_min[i + (1 << j - 1)][j - 1]);//加减优先级高于位运算
- dp_max[i][j] = Math.max(dp_max[i][j - 1],dp_max[i + (1 << j - 1)][j - 1]);
- }
-
- while((m--) != 0) {
- int l = Integer.parseInt(cin.next());
- int r = Integer.parseInt(cin.next());
- int k = (int) (Math.log(r - l + 1) / Math.log(2));
- int min = Math.min(dp_min[l][k],dp_min[r - (1 << k) + 1][k]);
- int max = Math.max(dp_max[l][k],dp_max[r - (1 << k) + 1][k]);
- System.out.println(max - min);
- }
- }
- }
LCA展开目录
求 LCA(最近公共祖先)的算法有好多种,按在线和离线分为在线算法和离线算法,离线算法有基于搜索的 Tarjan 算法,而在线算法则是基于 DP 的 ST 算法。首先给定一棵树通过深搜,可以得到这样的一个序列:
数组下标:1 2 3 4 5 6 7 8 9 10 11 12 13
遍历顺序: A B D B E F E G E B A C A
结点在树中的深度:1 2 3 2 3 4 3 4 3 2 1 2 1
要查询 D 和 G 的 LCA:
- 在遍历序列中找到 D 和 G 第一次出现的位置,first [D]=3,first [G]=8(3,8 指数组下标)
- 取深度数组的 [3,8] 那一段序列,查询一个最小值 min (3,2,3,4,3,4)=2,对应遍历数组中的结点是 B,所以 D,G 的 LCA 是 B
- #include<cstring>
- #include<iostream>
- using namespace std;
- int n,m,s,tot = 0,cnt = 0;
- //vis[i]:dfs第i个访问的结点
- //r[i]:vis[i]所在的层数
- //fir[i]:vis[i]第一次出现的下标
- int head[1000100],nxt[1000100],to[1000100];
- int fir[1000100],vis[1000100],r[1000100];
- int f[20][1000100],rec[20][1000100];
- void addEdge(int x,int y) {
- cnt++;
- nxt[cnt] = head[x];
- head[x] = cnt;
- to[cnt] = y;
- }
- void dfs(int u,int dep) {//dfs处理出三个数组
- fir[u] = ++tot,vis[tot] = u,r[tot] = dep;
- for(int i = head[u];i != -1;i = nxt[i]) {
- int v = to[i];
- if(!fir[v]) {
- dfs(v,dep + 1);
- vis[++tot] = u,r[tot] = dep;
- }
- }
- }
- int main() {
- memset(head,-1,sizeof(head));
- scanf("%d%d%d",&n,&m,&s);
- for(int i = 1;i < n;i++) {
- int x,y;
- scanf("%d%d",&x,&y);
- addEdge(x,y);
- addEdge(y,x);
- }
- dfs(s,1);
- //ST表求RMQ
- for(int i = 1;i <= tot;i++)
- f[0][i] = r[i],rec[0][i] = vis[i];
- for(int i = 1;(1 << i) <= tot;i++)
- for(int j = 1;j + (1 << i) <= tot + 1;j++)
- if(f[i - 1][j] < f[i - 1][j + (1 << (i - 1))])
- f[i][j] = f[i - 1][j],rec[i][j] = rec[i - 1][j];
- else
- f[i][j] = f[i - 1][j + (1 << (i - 1))],rec[i][j] = rec[i - 1][j + (1 << (i - 1))];
- //rec记录的是区间内深度最小值的编号
- for(int i = 1;i <= m;i++) {
- int l,r,k = 0;
- scanf("%d%d",&l,&r);
- l = fir[l],r = fir[r];
- if(l > r)
- swap(l,r);
- while((1 << k) <= r - l + 1)
- k++;
- k--;
- if(f[k][l] < f[k][r - (1 << k) + 1])
- printf("%d\n",rec[k][l]);
- else
- printf("%d\n",rec[k][r - (1 << k) + 1]);
- }
- return 0;
- }