Advanced data structure - segment tree, weight segment tree (Java & JS & Python)

Primer

Given an array arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6] of length n (arr.length = n), suppose the following operations are performed repeatedly, in arbitrary order:

  • Query the maximum value max within the specified interval [l, r] of arr
  • Query the sum of the elements in the specified interval [l, r] of arr
  • Add C to the element at a given index i, or overwrite it with C
  • Add C to every element in a given interval [l, r], or overwrite each of them with C

With a plain array:

  • The time complexity of the query (interval max, interval sum) is O(n)
  • The time complexity of updating a single value is O(1)
  • The time complexity of interval update is O(n) 
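To make these costs concrete, here is a minimal Python sketch of the four operations on a plain list. The function names (range_max, range_sum, point_update, range_update) and the overwrite flag are hypothetical, chosen only for illustration:

```python
def range_max(arr, l, r):
    # O(n): visits every element of arr[l..r]
    return max(arr[l:r + 1])


def range_sum(arr, l, r):
    # O(n): visits every element of arr[l..r]
    return sum(arr[l:r + 1])


def point_update(arr, i, c, overwrite=False):
    # O(1): touches a single element
    arr[i] = c if overwrite else arr[i] + c


def range_update(arr, l, r, c, overwrite=False):
    # O(n): touches every element of arr[l..r]
    for i in range(l, r + 1):
        arr[i] = c if overwrite else arr[i] + c


arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6]
print(range_max(arr, 3, 8))  # 9
print(range_sum(arr, 3, 8))  # 23
range_update(arr, 0, 2, 10)  # add 10 to arr[0..2]
print(arr[:3])               # [14, 17, 15]
```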

If we only need to compute sums over given intervals of arr many times, prefix sums are the standard optimization. For details, see:

Algorithm Design - Prefix Sum & Differential Sequence - Blogs Outside of Fucheng - CSDN Blog

However, in the requirements above arr itself changes (single-value updates and interval updates), so its prefix sum array changes too. Every update to arr forces the prefix sum array to be rebuilt in O(n) time, so interval sums can no longer be answered in O(1) time.
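A minimal Python sketch of the problem, assuming a prefix array where prefix[i] holds the sum of the first i elements (build_prefix and prefix_range_sum are hypothetical names): each query is O(1), but any change to arr leaves prefix stale until it is rebuilt in O(n).

```python
from itertools import accumulate


def build_prefix(arr):
    # prefix[i] = arr[0] + ... + arr[i - 1]; building it costs O(n)
    return [0] + list(accumulate(arr))


def prefix_range_sum(prefix, l, r):
    # O(1) sum of arr[l..r], valid only while prefix matches arr
    return prefix[r + 1] - prefix[l]


arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6]
prefix = build_prefix(arr)
print(prefix_range_sum(prefix, 3, 8))  # 23

arr[5] += 10                # a single-value update ...
prefix = build_prefix(arr)  # ... forces an O(n) rebuild
print(prefix_range_sum(prefix, 3, 8))  # 33
```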

So if we perform m of the above operations (each may be of a different kind), the total time complexity is O(m * n).

So is there a more efficient algorithm?

Segment tree concept

A segment tree is a binary tree built on the idea of divide and conquer. Each node of the segment tree corresponds to an interval [l, r] of the arr array:

  • A leaf node corresponds to an interval with l == r
  • A non-leaf node corresponding to the interval [l, r] has two children; let mid = (l + r) / 2, rounded down:
  1. The left child node corresponds to the interval [l, mid]
  2. The right child node corresponds to the interval [mid + 1, r]
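The splitting rule can be sketched in Python by recursively listing the interval of every node. The helper name split and its (depth, l, r) output format are assumptions for illustration; mid is rounded down, as the article's implementations also do with (l + r) >> 1.

```python
def split(l, r, depth=0, out=None):
    """Collect (depth, l, r) for every node of the segment tree over [l, r], in preorder."""
    if out is None:
        out = []
    out.append((depth, l, r))
    if l < r:  # non-leaf: split at mid
        mid = (l + r) // 2
        split(l, mid, depth + 1, out)       # left child covers [l, mid]
        split(mid + 1, r, depth + 1, out)   # right child covers [mid + 1, r]
    return out


# Print the tree for arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6], indented by depth
for depth, l, r in split(0, 9):
    print("  " * depth + f"[{l}, {r}]")
```

For 10 elements this yields 19 nodes, 10 of which are leaves, since every non-leaf node has exactly two children.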

Each node also records the result value for its interval [l, r], such as the interval maximum or the interval sum.

In other words, each segment tree node holds three basic pieces of information:

  • the left boundary l of the interval
  • the right boundary r of the interval
  • the result value val for the interval

For example, for arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6], the corresponding segment tree is shown in the figure below:

Every leaf node has l == r. Writing i = l = r, the value of that leaf is arr[i].

For interval-maximum queries, the val of each parent node is the larger of its two children's vals, which gives the following segment tree:

With this structure, the maximum of any interval can be found in O(log n) time.

For example, to find the maximum of [3, 8], we divide and conquer starting from the root and obtain the result values of the three intervals [3, 4], [5, 7], and [8, 8]; the largest of these three values is the maximum of [3, 8].
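A minimal Python sketch of this query, assuming a linked representation in which each hypothetical Node object holds references to its two children (the article itself stores nodes in an array later on). query_max also collects the fully covered intervals it uses, so the decomposition of [3, 8] can be checked:

```python
class Node:
    def __init__(self, l, r):
        self.l, self.r = l, r
        self.left = self.right = None
        self.max = None


def build(arr, l, r):
    node = Node(l, r)
    if l == r:  # leaf: its maximum is the element itself
        node.max = arr[l]
    else:
        mid = (l + r) // 2
        node.left = build(arr, l, mid)
        node.right = build(arr, mid + 1, r)
        node.max = max(node.left.max, node.right.max)
    return node


def query_max(node, ql, qr, used):
    # No overlap with [ql, qr]: contribute the identity for max
    if node.r < ql or qr < node.l:
        return float("-inf")
    # Node interval fully inside [ql, qr]: use the stored value, stop descending
    if ql <= node.l and node.r <= qr:
        used.append((node.l, node.r))
        return node.max
    # Partial overlap (never a leaf): combine both children
    return max(query_max(node.left, ql, qr, used),
               query_max(node.right, ql, qr, used))


arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6]
root = build(arr, 0, len(arr) - 1)
used = []
print(query_max(root, 3, 8, used))  # 9
print(used)                         # [(3, 4), (5, 7), (8, 8)]
```

Stopping at fully covered nodes is what keeps the number of visited nodes small at every level of the tree.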

Querying interval information with a segment tree is therefore very efficient.

The underlying container of the segment tree

A segment tree is a binary tree in which every level except possibly the last is full.
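This property can be spot-checked with a small Python sketch (depth_counts and all_but_last_full are hypothetical helpers) that counts nodes per level when intervals are split at mid = (l + r) // 2. It is an empirical check for small n, not a proof:

```python
from collections import Counter


def depth_counts(l, r, depth=1, counts=None):
    """Count segment tree nodes per level; the root is level 1."""
    if counts is None:
        counts = Counter()
    counts[depth] += 1
    if l < r:
        mid = (l + r) // 2
        depth_counts(l, mid, depth + 1, counts)
        depth_counts(mid + 1, r, depth + 1, counts)
    return counts


def all_but_last_full(n):
    counts = depth_counts(0, n - 1)
    h = max(counts)
    # Level d of a full binary tree holds 2^(d - 1) nodes
    return all(counts[d] == 2 ** (d - 1) for d in range(1, h))


print(sorted(depth_counts(0, 9).items()))  # [(1, 1), (2, 2), (3, 4), (4, 8), (5, 4)]
print(all(all_but_last_full(n) for n in range(1, 513)))  # True
```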

A full binary tree can be stored in an array; take the full binary tree shown below as an example:

If the nodes of a full binary tree are numbered level by level, left to right, starting from 1 at the root, then a parent node numbered k (k >= 1) has its left child numbered 2*k and its right child numbered 2*k+1.

If each node number is used as an array index, we get the correspondence shown in the figure above.

That is, array index k stores the value of binary tree node number k.
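The numbering rule amounts to three one-line functions; a minimal Python sketch (left, right, and parent are hypothetical names):

```python
def left(k):
    # Left child of node k
    return 2 * k


def right(k):
    # Right child of node k
    return 2 * k + 1


def parent(k):
    # Floor division undoes both 2k and 2k + 1
    return k // 2


for k in range(1, 4):
    print(k, "->", left(k), right(k))  # 1 -> 2 3, then 2 -> 4 5, then 3 -> 6 7
```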

So if we picture the segment tree as a full binary tree, it can be stored in an array. How long does that array need to be?

Suppose the interval [l, r] covered by the segment tree has length n. Then the segment tree has n leaf nodes.

The second-to-last level has at most n nodes, and levels 1 through the second-to-last form a full binary tree. A full binary tree has the following property:

If the last level of a full binary tree has x nodes, then all preceding levels together have exactly x - 1 nodes.

The proof is easy. The number of nodes on each level of a full binary tree is:

Level 1: 2^0 nodes

Level 2: 2^1 nodes

Level 3: 2^2 nodes

....

With only 3 levels, for example: 2^0 + 2^1 = 2^2 - 1. In general, a full binary tree with h levels satisfies 2^0 + 2^1 + ... + 2^(h-2) = 2^(h-1) - 1.
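The general case follows from the standard telescoping argument for a geometric sum; here S is just a name for the sum of the nodes on levels 1 through h - 1 of a full binary tree with h levels:

$$
\begin{aligned}
S &= 2^0 + 2^1 + \cdots + 2^{h-2} \\
2S - S &= \left(2^1 + 2^2 + \cdots + 2^{h-1}\right) - \left(2^0 + 2^1 + \cdots + 2^{h-2}\right) = 2^{h-1} - 2^0 \\
S &= 2^{h-1} - 1
\end{aligned}
$$

Since the last level holds x = 2^{h-1} nodes, the levels above it hold x - 1 nodes; h = 3 gives the example 2^0 + 2^1 = 2^2 - 1.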

Since the second-to-last level of the segment tree has at most n nodes, the levels above it together have at most n - 1 nodes,

so levels 1 through the second-to-last hold at most 2n - 1 nodes.

If the last level were completely filled, it would hold twice as many nodes as the second-to-last level, i.e. at most 2n nodes.

Therefore the segment tree has at most (2n - 1) + 2n = 4n - 1 node positions in total. Node numbers run from 1 to at most 4n - 1 (index 0 is unused), so an array of length 4n is enough to store every node of the segment tree.
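The bound can be checked empirically with a short Python sketch. max_index is a hypothetical helper that returns the largest node number the recursive numbering (children 2k and 2k + 1) actually uses; the loop only covers n up to 500, so it illustrates rather than proves the 4n bound.

```python
def max_index(k, l, r):
    """Largest node number used in the subtree of node k, which covers [l, r]."""
    if l == r:  # a leaf uses only its own number
        return k
    mid = (l + r) // 2
    return max(max_index(2 * k, l, mid), max_index(2 * k + 1, mid + 1, r))


# Every node number must be a valid index into an array of length 4n
for n in range(1, 501):
    assert max_index(1, 0, n - 1) < 4 * n

print(max_index(1, 0, 9))  # 25 for the 10-element example array
```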

Segment tree construction

The underlying container of the segment tree is an array, which we call tree.

If the original array arr whose interval information we query has length n, the tree array must have length 4n.

Elements of the tree array correspond to segment tree nodes as follows:

  • a tree array element → a segment tree node
  • the index of a tree array element → the number of that segment tree node

Each segment tree node holds three basic pieces of information:

  • the left boundary l of the interval
  • the right boundary r of the interval
  • the result value val for the interval (such as the interval sum or interval maximum)

We can therefore define a Node class to record this information, which makes tree an array of Node objects.

The diagram below shows what the tree array looks like:

Building the segment tree means building the tree array shown in the figure above.

Index k of the tree array is segment tree node number k:

tree[k] = Node {l, r, max}

This pseudocode means that segment tree node k corresponds to the interval [l, r] of arr and records the maximum value max within that interval.

The segment tree can be built recursively by divide and conquer.

For example, the root node k = 1 maintains the arr interval [0, 9]. How do we find the maximum of this interval?

Since the segment tree is a binary tree based on divide and conquer, the interval [0, 9] can be split into [0, 4] and [5, 9].

That is, the problem of finding the maximum of [0, 9] becomes two smaller subproblems: finding the maximum of [0, 4] and the maximum of [5, 9].

The interval [0, 4] is exactly the interval maintained by node k = 2, and [5, 9] is the interval maintained by node k = 3.

We then apply the same logic recursively to find the maxima of [0, 4] and [5, 9].

This continues until splitting produces an interval with l == r, that is, until a leaf node is reached. The maximum of such an interval [l, r] is simply arr[l] (equivalently arr[r]), and from there the recursion starts backtracking.

While backtracking, the maximum of a parent node's interval equals the larger of the maxima of its two children's intervals.

The implementation is as follows (including test code):

JS code implementation

// Segment tree node definition
class Node {
  constructor(l, r) {
    this.l = l; // left boundary of the interval
    this.r = r; // right boundary of the interval
    this.max = undefined; // maximum value within the interval
  }
}

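// Segment tree definition
class SegmentTree {
  constructor(arr) {
    // arr is the original array on which interval-maximum queries are run
    this.arr = arr;
    // The segment tree is stored in a plain array called tree; if arr has length n, tree needs length 4n
    this.tree = new Array(arr.length * 4);
    // Build from the root: the root node number is k = 1 and it covers [0, arr.length - 1]
    this.build(1, 0, arr.length - 1);
  }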

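  /**
   * Build the segment tree
   * @param {*} k segment tree node number
   * @param {*} l left boundary of the interval covered by the node
   * @param {*} r right boundary of the interval covered by the node
   */
  build(k, l, r) {
    // Initialize node k, i.e. link node number k to the interval [l, r]
    this.tree[k] = new Node(l, r);

    // If l == r, node k is a leaf of the segment tree
    if (l == r) {
      // The value of a leaf is arr[l] (equivalently arr[r]) itself;
      // read it from this.arr, not from a global arr
      this.tree[k].max = this.arr[r];
      // Backtrack
      return;
    }

    // If l != r, node k is not a leaf, so it has left and right children split at mid
    const mid = (l + r) >> 1; // equivalent to Math.floor((l + r) / 2)

    // Recursively build the left child of k: number 2 * k, interval [l, mid]
    this.build(2 * k, l, mid);
    // Recursively build the right child of k: number 2 * k + 1, interval [mid + 1, r]
    this.build(2 * k + 1, mid + 1, r);

    // The value of node k is the larger of its two children's values
    this.tree[k].max = Math.max(this.tree[2 * k].max, this.tree[2 * k + 1].max);
  }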
}

// Test
const arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6];

const tree = new SegmentTree(arr).tree;

console.log("k\t| tree[k]");
for (let k = 0; k < tree.length; k++) {
  if (tree[k]) {
    console.log(
      `${k}\t| Node{ l: ${tree[k].l}, r: ${tree[k].r}, max: ${tree[k].max}}`
    );
  } else {
    console.log(`${k}\t| null`);
  }
}

Java code implementation

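// Segment tree definition
public class SegmentTree {
  // Segment tree node definition
  static class Node {
    int l; // left boundary of the interval
    int r; // right boundary of the interval
    int max; // maximum value within the interval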

    public Node(int l, int r) {
      this.l = l;
      this.r = r;
    }
  }

  int[] arr;

  Node[] tree;

  public SegmentTree(int[] arr) {
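    // arr is the original array on which interval-maximum queries are run
    this.arr = arr;
    // The segment tree is stored in a plain array called tree; if arr has length n, tree needs length 4n
    this.tree = new Node[arr.length * 4];
    // Build from the root: the root node number is k = 1 and it covers [0, arr.length - 1]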
    this.build(1, 0, arr.length - 1);
  }

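  /**
   * Build the segment tree
   *
   * @param k segment tree node number
   * @param l left boundary of the interval covered by the node
   * @param r right boundary of the interval covered by the node
   */
  private void build(int k, int l, int r) {
    // Initialize node k, i.e. link node number k to the interval [l, r]
    this.tree[k] = new Node(l, r);

    // If l == r, node k is a leaf of the segment tree
    if (l == r) {
      // The value of a leaf is arr[l] (equivalently arr[r]) itself
      this.tree[k].max = this.arr[r];
      // Backtrack
      return;
    }

    // If l != r, node k is not a leaf, so it has left and right children split at mid
    int mid = (l + r) >> 1;

    // Recursively build the left child of k: number 2 * k, interval [l, mid]
    this.build(2 * k, l, mid);
    // Recursively build the right child of k: number 2 * k + 1, interval [mid + 1, r]
    this.build(2 * k + 1, mid + 1, r);

    // The value of node k is the larger of its two children's values
    this.tree[k].max = Math.max(this.tree[2 * k].max, this.tree[2 * k + 1].max);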
  }

  // Test
  public static void main(String[] args) {
    int[] arr = {4, 7, 5, 3, 8, 9, 0, 1, 2, 6};

    Node[] tree = new SegmentTree(arr).tree;

    System.out.println("k\t| tree[k]");
    for (int k = 0; k < tree.length; k++) {
      if (tree[k] == null) {
        System.out.println(k + "\t| null");
      } else {
        System.out.println(
            k + "\t| Node{ l: " + tree[k].l + ", r: " + tree[k].r + ", max: " + tree[k].max + "}");
      }
    }
  }
}

Python code implementation

# Segment tree node definition
class Node:
    def __init__(self):
        self.l = None
        self.r = None
        self.mx = None


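# Segment tree definition
class SegmentTree:
    def __init__(self, lst):
        # lst is the original array on which interval-maximum queries are run
        self.lst = lst
        # The segment tree is stored in a plain array called tree; if lst has length n, tree needs length 4n
        self.tree = [Node() for _ in range(len(lst) * 4)]
        # Build from the root: the root node number is k = 1 and it covers [0, len(lst) - 1]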
        self.build(1, 0, len(lst) - 1)

    def build(self, k, l, r):
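        """
        Build the segment tree
        :param k: segment tree node number
        :param l: left boundary of the interval covered by the node
        :param r: right boundary of the interval covered by the node
        """

        # Initialize node k, i.e. link node number k to the interval [l, r]
        self.tree[k].l = l
        self.tree[k].r = r

        # If l == r, node k is a leaf of the segment tree
        if l == r:
            # The value of a leaf is lst[l] (equivalently lst[r]) itself
            self.tree[k].mx = self.lst[r]
            # Backtrack
            return

        # If l != r, node k is not a leaf, so it has left and right children split at mid
        mid = (l + r) >> 1

        # Recursively build the left child of k: number 2 * k, interval [l, mid]
        self.build(2 * k, l, mid)
        # Recursively build the right child of k: number 2 * k + 1, interval [mid + 1, r]
        self.build(2 * k + 1, mid + 1, r)

        # The value of node k is the larger of its two children's values
        self.tree[k].mx = max(self.tree[2 * k].mx, self.tree[2 * k + 1].mx)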


# Test code
lst = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6]
print("k\t| tree[k]")
for k, node in enumerate(SegmentTree(lst).tree):
    # Compare against None: a leaf holding 0 is a real node, even though 0 is falsy
    if node.mx is not None:
        print(f"{k}\t| Node[ l: {node.l}, r: {node.r}, mx: {node.mx} ]")
    else:
        print(f"{k}\t| null")
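The build above hard-codes max, but a node may equally store the interval sum. A minimal sketch, assuming a hypothetical build(arr, merge) helper that stores plain values in the array instead of Node objects, shows that only the line combining the two children changes:

```python
import operator


def build(arr, merge):
    """Array-based segment tree: tree[k] holds the merged value of node k's interval."""
    n = len(arr)
    tree = [None] * (4 * n)

    def rec(k, l, r):
        if l == r:  # leaf: its value is arr[l]
            tree[k] = arr[l]
            return
        mid = (l + r) // 2
        rec(2 * k, l, mid)
        rec(2 * k + 1, mid + 1, r)
        tree[k] = merge(tree[2 * k], tree[2 * k + 1])  # combine the two children

    rec(1, 0, n - 1)
    return tree


arr = [4, 7, 5, 3, 8, 9, 0, 1, 2, 6]
print(build(arr, max)[1])           # 9, maximum of the whole array
print(build(arr, operator.add)[1])  # 45, sum of the whole array
```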

Query any interval result value


Origin blog.csdn.net/qfc_128220/article/details/131720641