MENU

线段树例题

January 10, 2019 • Read: 8838 • 算法阅读设置

HDU 1754展开目录

线段树板子题,结构体存保存的是左右端点以及这个区间内的最大值

建树的过程其实是一个后序遍历的过程,先建立左子树然后右子树,根节点的最大值就是 max(左孩子的最大值,右孩子的最大值)

修改类似,更新当前节点的最大值为 max(value, 当前节点的最大值)

查询的时候用一个全局变量记录,每次查找到对应区间,就更新全局变量 res = max(res, 当前区间的最大值)

Java 会爆空间,只能用 c 做

  • #include <stdio.h>
  • #define Max 200001
  • struct data {
  • int l, r, max;
  • }node[4 * Max];
  • int score[Max];
  • int res;
  • int max(int a, int b) {
  • return a > b ? a : b;
  • }
  • void make(int l, int r, int idx) {
  • node[idx].l = l;
  • node[idx].r = r;
  • if (l == r)
  • node[idx].max = score[l];
  • else {
  • make(l, (l + r) >> 1, (idx << 1) + 1);
  • make(((l + r) >> 1) + 1, r, (idx << 1) + 2);
  • node[idx].max = max(node[(idx << 1) + 1].max, node[(idx << 1) + 2].max);
  • }
  • }
  • void update(int i, int value, int idx) {
  • node[idx].max = max(value, node[idx].max);
  • if (node[idx].l == node[idx].r)
  • return;
  • if (i <= (node[idx].l + node[idx].r) >> 1) // 左子树
  • update(i, value, (idx << 1) + 1);
  • else // 右子树
  • update(i, value, (idx << 1) + 2);
  • }
  • void query(int l, int r, int idx) {
  • if (l <= node[idx].l && r >= node[idx].r)
  • res = max(res, node[idx].max);
  • else {
  • int mid = (node[idx].l + node[idx].r) >> 1;
  • if (r <= mid)
  • query(l, r, (idx << 1) + 1);
  • else if (l > mid)
  • query(l, r, (idx << 1) + 2);
  • else {
  • query(l, r, (idx << 1) + 1);
  • query(l, r, (idx << 1) + 2);
  • }
  • }
  • }
  • int main() {
  • int N,M;
  • while(scanf("%d%d",&N,&M) != EOF) {
  • for (int i = 1; i <= N; i++)
  • scanf("%d",&score[i]);
  • getchar();
  • char c;
  • int s,e;
  • make(1,N,0);
  • for (int i = 0; i < M; i++) {
  • scanf("%c%d%d",&c,&s,&e);
  • getchar();
  • if (c == 'U')
  • update(s,e,0);
  • if (c == 'Q') {
  • res = -1;
  • query(s,e,0);
  • printf("%d\n",res);
  • }
  • }
  • }
  • return 0;
  • }

POJ 3468展开目录

题目大意是说给你一个数列 A,有两种操作,一是将第 a 到 b 的数都加上 c,二是询问 a 到 b 之间所有数的和

这是线段树的一种进阶应用 —— 区间更新,区间查询

区间更新是指更新某个区间内的叶子节点的值,因为涉及到的叶子节点不止一个,而叶子节点会影响其非叶的父节点,那么回溯需要更新的非叶节点也会有很多,如果一次性更新完,操作的时间复杂度肯定不止 O (logn),例如当我们更新区间 [1,7] 内的值,就需要更新下图所示标红的所有节点
为此引入线段树的延迟标记概念,也叫 lazy tag

延迟标记:节点结构体中新增一个标记,记录这个节点是否会进行某种修改,对于任意区间的修改,我们先按照区间查询的方式将其划分成线段树中的节点,然后修改这些节点的信息,并给这些节点打上标记。在修改和查询的时候,如果我们到了一个节点 P,并且要继续查看其子节点,那么我们就要看看节点 P 是否被标记,如果有,则需要按照其标记首先修改子节点的信息,并且给子节点都打上相同的标记,同时取消节点 P 的标记,这一操作称为标记下放,也叫 pushDown

可以这么理解,假设爷爷要给两个孙女压岁钱,所以爷爷就先把总的压岁钱给自己的儿子,让儿子给女儿
,但是儿子觉得自己的女儿还太小了,暂时用不到,于是就先保存着。突然有一天爷爷准备要问孙女拿到压岁钱了没有,此时爸爸着急了,就赶紧把压岁钱给了女儿

具体在 update 函数中的操作就是,如果当前更新的区间为 [l,r],我走到节点 P 对应的区间是 [curl,curr],如果 $[curl,curr] \in [l,r]$,那就先更新当前节点 P,然后给 P 打上标记,P 的子节点就不管了,直接 return,如果以后进行查询或者更新操作的时候,发现当前节点有标记,才将标记下放

除了 pushDown,还需要 pushUp,pushDown 的作用是将标记下放,而 pushUp 的作用是更新根节点的信息,因为子节点值改变了,根节点也会变,所以必须要更新根节点的信息

  • import java.io.InputStreamReader;
  • import java.util.Scanner;
  • public class Main {
  • static Node[] n;
  • static long[] t;
  • static long SUM;
  • public static void main(String[] args) {
  • Scanner cin = new Scanner(new InputStreamReader(System.in));
  • int N = cin.nextInt();
  • int M = cin.nextInt();
  • n = new Node[4 * N];
  • t = new long[N + 1];
  • for (int i = 1; i <= N; i++)
  • t[i] = cin.nextLong();
  • make(1, N, 0);
  • for (int i = 0; i < M; i++) {
  • String op = cin.next();
  • int a = cin.nextInt();
  • int b = cin.nextInt();
  • if ("Q".equals(op)) {
  • SUM = 0;
  • query(a, b, 0);
  • System.out.println(SUM);
  • } else { // "C".equals(op)
  • int c = cin.nextInt();
  • update(a, b, c, 0);
  • }
  • }
  • }
  • static void update(int l, int r, int c, int idx) {
  • if (l <= n[idx].l && r >= n[idx].r) {
  • n[idx].sum += (n[idx].r - n[idx].l + 1) * c;
  • n[idx].inc += c;
  • return;
  • }
  • if (n[idx].inc != 0)
  • pushDown(idx);
  • int mid = (n[idx].l + n[idx].r) >> 1;
  • if (r <= mid)
  • update(l, r, c, (idx << 1) | 1);
  • else if (l > mid)
  • update(l, r, c, (idx << 1) + 2);
  • else {
  • update(l, r, c, (idx << 1) | 1);
  • update(l, r, c, (idx << 1) + 2);
  • }
  • pushUp(idx);
  • }
  • static void pushDown(int idx) {
  • int mid = (n[idx].l + n[idx].r) >> 1;
  • n[(idx << 1) | 1].sum += (mid - n[idx].l + 1) * n[idx].inc;
  • n[(idx << 1) + 2].sum += (n[idx].r - mid) * n[idx].inc;
  • n[(idx << 1) | 1].inc += n[idx].inc;
  • n[(idx << 1) + 2].inc += n[idx].inc;
  • n[idx].inc = 0;
  • }
  • static void query(int l, int r, int idx) {
  • if (l <= n[idx].l && r >= n[idx].r)
  • SUM += n[idx].sum;
  • else {
  • if (n[idx].inc != 0)
  • pushDown(idx);
  • int mid = (n[idx].l + n[idx].r) >> 1;
  • if (r <= mid)
  • query(l, r, (idx << 1) | 1);
  • else if (l > mid)
  • query(l, r, (idx << 1) + 2);
  • else {
  • query(l, r, (idx << 1) | 1);
  • query(l, r, (idx << 1) + 2);
  • }
  • }
  • }
  • static void make(int l, int r, int idx) {
  • n[idx] = new Node();
  • n[idx].l = l;
  • n[idx].r = r;
  • if (l == r)
  • n[idx].sum = t[r];
  • else {
  • make(l, (l + r) >> 1, (idx << 1) | 1); // 左子树
  • make(((l + r) >> 1) + 1, r, (idx << 1) + 2); // 右子树
  • pushUp(idx);
  • }
  • }
  • static void pushUp(int idx) {
  • n[idx].sum = n[(idx << 1) | 1].sum + n[(idx << 1) + 2].sum;
  • }
  • }
  • class Node {
  • int l, r;
  • long sum;
  • long inc;
  • }

类似的线段树区间更新题目还有 CDOJ 秋实大哥与花

POJ 2528展开目录

题目大意,有一面墙,被等分成 1 千万份。现在往墙上贴 N 张海报,每张海报的宽度是任意的(必定是整数,且小于 1 千万)。后贴的海报若与先贴的海报有交集,后贴的海报就会全部或局部覆盖先贴的海报,现在给出每张海报所贴的位置(左端点和右端点),问贴完 N 张海报后,还能看见多少张海报(看见一部分也算看到)

这是一道区间压缩映射(离散化)线段树问题,首先抽象问题:给定一条数轴,长度为 1 千万,然后在数轴上的某些区间贴海报,第 i 次对区间贴的海报为 i,给出每次贴海报的区间,问最后能看见都少张海报

这道题单纯用线段树去求解需要建立一棵 [1,1 千万] 的线段树,MLE 是铁定的,所以必须离散化

通俗点说,离散化就是压缩区间,使原有的长区间映射到新的短区间,但是区间压缩前后的覆盖关系不变。举个例子:

有一条 1 到 10 的数轴,长度为 9,给定 4 个区间 [2,4] [3,6] [8,10] [6,9],后者覆盖前者,每个区间编号依次为 1 2 3 4

现在我们抽取这 4 个区间的 8 个端点 2 4 3 6 8 10 6 9,删除重复的端点,对其升序排序得 2 3 4 6 8 9 10,然后建立映射

23468910
1234567

那么新的 4 个区间为 [1,3] [2,4] [5,7] [4,6],覆盖关系也没有改变,新数轴为 1 到 7

这就完了吗,这样做真的对吗?考虑一组数据,假如三张海报的区间为 [1,10] [1,4] [6,10],离散化后 x [1]=1,x [2]=4,x [3]=6,x [4]=10

放第一张海报时:墙的 1-4 标记为 1
放第二张海报时:墙的 1-2 被标记为 2,3-4 仍为 1
放第三张海报时:墙的 3-4 被标记为 3,1-2 仍为 2

最终,第一张海报就被完全覆盖了,于是输出 2,但实际上正确输出应该为 3

正确的离散方法是:在相差大于 1 的数间加一个数,例如在上面 1 4 6 10 中间加 5

x[1]=1,x[2]=4,x[3]=5,x[4]=6,x[5]=10

这样之后,第一次是 1-5 变成 1,第二次 1-2 变成 2,第三次 4-5 变成 3

  • #include <iostream>
  • #include <algorithm>
  • #include <math.h>
  • using namespace std;
  • int n;
  • struct CPost { // 海报
  • int l,r;
  • } posters[10100];
  • int x[20200]; // 海报的端点瓷砖编号
  • int hashArr[10000010]; // hashArr[i]表示瓷砖i所处的离散化后区间编号
  • struct CNode {
  • int l,r;
  • bool bCovered; // 区间[l,r]是否被完全覆盖
  • CNode *pLeft, *pRight;
  • } Tree[1000000];
  • int nNodeCount = 0; // 记录线段树结点数
  • int mid(CNode *pRoot) {
  • return (pRoot->l + pRoot->r) >> 1;
  • }
  • void buildTree(CNode *pRoot, int l, int r) {
  • pRoot->l = l;
  • pRoot->r = r;
  • pRoot->bCovered = false;
  • if (l == r)
  • return;
  • nNodeCount++;
  • pRoot->pLeft = Tree + nNodeCount;
  • nNodeCount++;
  • pRoot->pRight = Tree + nNodeCount;
  • buildTree(pRoot->pLeft, l, (l + r) >> 1);
  • buildTree(pRoot->pRight, ((l +r) >> 1) + 1, r);
  • }
  • bool post(CNode *pRoot, int l, int r) {
  • if (pRoot->bCovered)
  • return false;
  • if (pRoot->l == l && pRoot->r == r) {
  • pRoot->bCovered = true;
  • return true;
  • }
  • bool result;
  • if (r <= mid(pRoot))
  • result = post(pRoot->pLeft, l, r);
  • else if (l > mid(pRoot))
  • result = post(pRoot->pRight, l, r);
  • else {
  • bool b1 = post(pRoot->pLeft, l, mid(pRoot));
  • bool b2 = post(pRoot->pRight, mid(pRoot) + 1, r);
  • result = b1 || b2;
  • }
  • if (pRoot->pLeft->bCovered && pRoot->pRight->bCovered)
  • pRoot->bCovered = true;
  • return result;
  • }
  • int main() {
  • int t,i,j,k;
  • scanf("%d",&t);
  • int nCaseNo = 0;
  • while (t--) {
  • nCaseNo++;
  • scanf("%d",&n);
  • int nCount = 0; // 记录海报端点数
  • for (i = 0; i < n; i++) {
  • scanf("%d%d", &posters[i].l, &posters[i].r);
  • x[nCount++] = posters[i].l;
  • x[nCount++] = posters[i].r;
  • }
  • sort(x, x + nCount);
  • nCount = unique(x, x + nCount) - x; // 去重
  • // 离散化
  • int nlntervalNo = 0;
  • for (i = 0; i < nCount; i++) {
  • hashArr[x[i]] = nlntervalNo;
  • if (i < nCount - 1) {
  • if (x[i + 1] - x[i] == 1)
  • nlntervalNo++;
  • else
  • nlntervalNo += 2;
  • }
  • }
  • buildTree(Tree, 0, nlntervalNo);
  • int nSum = 0;
  • for (i = n - 1; i >= 0; i--) {
  • if (post(Tree, hashArr[posters[i].l], hashArr[posters[i].r]))
  • nSum++;
  • }
  • printf("%d\n",nSum);
  • }
  • return 0;
  • }

HDU 1698展开目录

题目大意是说:整个区间内点的初始值是 1,有 m 次操作,将 l 到 r 区间的值改为 val

简单的线段树区间更新区间查询问题,注意这里是更新,不是累加累减,因此前面的修改对后面的修改不具有传递性或者相关性,因此在传递 lazy 标记的时候,不能用 +=,而应该是 =

  • import java.util.*;
  • public class segmentTree {
  • static Node[] node;
  • static int SUM;
  • public static void main(String[] args) {
  • Scanner cin = new Scanner(System.in);
  • int T = cin.nextInt();
  • for (int p = 1; p <= T; p++) {
  • int n = cin.nextInt();
  • node = new Node[n << 2];
  • make(1, n, 0);
  • int m = cin.nextInt();
  • for (int i = 0; i < m; i++) {
  • int l = cin.nextInt();
  • int r = cin.nextInt();
  • int val = cin.nextInt();
  • update(l, r, val, 0);
  • }
  • SUM = 0;
  • query(1, n, 0);
  • System.out.println("Case " + p + ": The total value of the hook is " + SUM + ".");
  • }
  • }
  • static void query(int l, int r, int idx) {
  • if (l <= node[idx].l && r >= node[idx].r)
  • SUM += node[idx].sum;
  • else {
  • if (node[idx].inc != 0)
  • pushDown(idx);
  • int mid = (node[idx].l + node[idx].r) >> 1;
  • if (r <= mid)
  • query(l, r, (idx << 1) | 1);
  • else if (l > mid)
  • query(l, r, (idx << 1) + 2);
  • else {
  • query(l, r, (idx << 1) | 1);
  • query(l, r, (idx << 1) + 2);
  • }
  • }
  • }
  • static void update(int l, int r, int val, int idx) {
  • if (l <= node[idx].l && r >= node[idx].r) {
  • node[idx].sum = (node[idx].r - node[idx].l + 1) * val;
  • node[idx].inc = val;
  • }
  • else {
  • if (node[idx].inc != 0)
  • pushDown(idx);
  • int mid = (node[idx].l + node[idx].r) >> 1;
  • if (r <= mid)
  • update(l, r, val, (idx << 1) | 1);
  • else if (l > mid)
  • update(l, r, val, (idx << 1) + 2);
  • else {
  • update(l, r, val, (idx << 1) | 1);
  • update(l, r, val, (idx << 1) + 2);
  • }
  • pushUp(idx);
  • }
  • }
  • static void pushDown(int idx) {
  • int val = node[idx].inc;
  • int mid = (node[idx].l + node[idx].r) >> 1;
  • node[(idx << 1) | 1].sum = (mid - node[(idx << 1) | 1].l + 1) * val;
  • node[(idx << 1) + 2].sum = (node[(idx << 1) + 2].r - mid) * val;
  • node[(idx << 1) | 1].inc = val;
  • node[(idx << 1) + 2].inc = val;
  • node[idx].inc = 0;
  • }
  • static void make(int l, int r, int idx) {
  • node[idx] = new Node(l, r);
  • if (l == r)
  • node[idx].sum = 1;
  • else {
  • int mid = (node[idx].l + node[idx].r) >> 1;
  • make(l, mid, (idx << 1) | 1);
  • make(mid + 1, r, (idx << 1) + 2);
  • pushUp(idx);
  • }
  • }
  • static void pushUp(int idx) {
  • node[idx].sum = node[(idx << 1) | 1].sum + node[(idx << 1) + 2].sum;
  • }
  • }
  • class Node {
  • int l, r, sum, inc;
  • Node(int l, int r) {
  • this.l = l;
  • this.r = r;
  • }
  • }
Last Modified: August 17, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

4 Comments
  1. LINNO LINNO

    find 错别字 争取的离散方法

    1. mathor mathor

      @LINNO 已修改,感谢提醒

  2. Jk的狗 Jk 的狗

    如果让一个区间内的数变相反数呢

    1. _xm_ _xm_

      @Jk 的狗这个不难吧,相当于区间乘 - 1