From baad9d880dedfa4e0efb6951973376cbf8a78d42 Mon Sep 17 00:00:00 2001 From: YDZ Date: Mon, 9 Sep 2019 20:08:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20segmentTree=20=E6=A8=A1?= =?UTF-8?q?=E6=9D=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- template/SegmentTree.go | 251 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 template/SegmentTree.go diff --git a/template/SegmentTree.go b/template/SegmentTree.go new file mode 100644 index 00000000..9fcd2e4a --- /dev/null +++ b/template/SegmentTree.go @@ -0,0 +1,251 @@ +package template + +// SegmentTree define +type SegmentTree struct { + data, tree, lazy []int + left, right int + merge func(i, j int) int +} + +func (st *SegmentTree) init(nums []int, oper func(i, j int) int) { + st.merge = oper + + data, tree, lazy := make([]int, len(nums)), make([]int, 4*len(nums)), make([]int, 4*len(nums)) + for i := 0; i < len(nums); i++ { + data[i] = nums[i] + } + st.data, st.tree, st.lazy = data, tree, lazy + if len(nums) > 0 { + st.buildSegmentTree(0, 0, len(nums)-1) + } +} + +// 在 treeIndex 的位置创建 [left....right] 区间的线段树 +func (st *SegmentTree) buildSegmentTree(treeIndex, left, right int) { + if left == right { + st.tree[treeIndex] = st.data[left] + return + } + leftTreeIndex, rightTreeIndex := st.leftChild(treeIndex), st.rightChild(treeIndex) + midTreeIndex := left + (right-left)/2 + st.buildSegmentTree(leftTreeIndex, left, midTreeIndex) + st.buildSegmentTree(rightTreeIndex, midTreeIndex+1, right) + st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex]) +} + +func (st *SegmentTree) leftChild(index int) int { + return 2*index + 1 +} + +func (st *SegmentTree) rightChild(index int) int { + return 2*index + 2 +} + +// 查询 [left....right] 区间内的值 +func (st *SegmentTree) query(left, right int) int { + if len(st.data) > 0 { + return st.queryInTree(0, 0, len(st.data)-1, left, right) + } + return 0 +} + +// 在以 treeIndex 为根的线段树中 [left...right] 的范围里,搜索区间 [queryLeft...queryRight] 的值 +func (st *SegmentTree) queryInTree(treeIndex, left, right, queryLeft, queryRight int) int { + if left == queryLeft && right == queryRight { + return st.tree[treeIndex] + } + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + if queryLeft >= midTreeIndex+1 { + return st.queryInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight) + } else if queryRight <= midTreeIndex { + return st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight) + } + return st.merge(st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, midTreeIndex), + st.queryInTree(rightTreeIndex, midTreeIndex+1, right, midTreeIndex+1, queryRight)) +} + +// 查询 [left....right] 区间内的值 +func (st *SegmentTree) queryLazy(left, right int) int { + if len(st.data) > 0 { + return st.queryLazyInTree(0, 0, len(st.data)-1, left, right) + } + return 0 +} + +func (st *SegmentTree) queryLazyInTree(treeIndex, left, right, queryLeft, queryRight int) int { + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + if left > queryRight || right < queryLeft { // segment completely outside range + return 0 // represents a null node + } + if st.lazy[treeIndex] != 0 { // this node is lazy + for i := 0; i < right-left+1; i++ { + st.tree[treeIndex] = st.merge(st.tree[treeIndex], st.lazy[treeIndex]) + // st.tree[treeIndex] += (right - left + 1) * st.lazy[treeIndex] // normalize current node by removing lazinesss + } + if left != right { // update lazy[] for children nodes + st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], st.lazy[treeIndex]) + st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], st.lazy[treeIndex]) + // st.lazy[leftTreeIndex] += st.lazy[treeIndex] + // st.lazy[rightTreeIndex] += st.lazy[treeIndex] + } + st.lazy[treeIndex] = 0 // current node processed. No longer lazy + } + if queryLeft <= left && queryRight >= right { // segment completely inside range + return st.tree[treeIndex] + } + if queryLeft > midTreeIndex { + return st.queryLazyInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight) + } else if queryRight <= midTreeIndex { + return st.queryLazyInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight) + } + // merge query results + return st.merge(st.queryLazyInTree(leftTreeIndex, left, midTreeIndex, queryLeft, midTreeIndex), + st.queryLazyInTree(rightTreeIndex, midTreeIndex+1, right, midTreeIndex+1, queryRight)) +} + +// 更新 index 位置的值 +func (st *SegmentTree) update(index, val int) { + if len(st.data) > 0 { + st.updateInTree(0, 0, len(st.data)-1, index, val) + } +} + +// 以 treeIndex 为根,更新 index 位置上的值为 val +func (st *SegmentTree) updateInTree(treeIndex, left, right, index, val int) { + if left == right { + st.tree[treeIndex] = val + return + } + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + if index >= midTreeIndex+1 { + st.updateInTree(rightTreeIndex, midTreeIndex+1, right, index, val) + } else { + st.updateInTree(leftTreeIndex, left, midTreeIndex, index, val) + } + st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex]) +} + +// 更新 [updateLeft....updateRight] 位置的值 +func (st *SegmentTree) updateLazy(updateLeft, updateRight, val int) { + if len(st.data) > 0 { + st.updateLazyInTree(0, 0, len(st.data)-1, updateLeft, updateRight, val) + } +} + +func (st *SegmentTree) updateLazyInTree(treeIndex, left, right, updateLeft, updateRight, val int) { + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + if st.lazy[treeIndex] != 0 { // this node is lazy + for i := 0; i < right-left+1; i++ { + st.tree[treeIndex] = st.merge(st.tree[treeIndex], st.lazy[treeIndex]) + //st.tree[treeIndex] += (right - left + 1) * st.lazy[treeIndex] // normalize current node by removing laziness + } + if left != right { // update lazy[] for children nodes + st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], st.lazy[treeIndex]) + st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], st.lazy[treeIndex]) + // st.lazy[leftTreeIndex] += st.lazy[treeIndex] + // st.lazy[rightTreeIndex] += st.lazy[treeIndex] + } + st.lazy[treeIndex] = 0 // current node processed. No longer lazy + } + + if left > right || left > updateRight || right < updateLeft { + return // out of range. escape. + } + + if updateLeft <= left && right <= updateRight { // segment is fully within update range + for i := 0; i < right-left+1; i++ { + st.tree[treeIndex] = st.merge(st.tree[treeIndex], val) + //st.tree[treeIndex] += (right - left + 1) * val // update segment + } + if left != right { // update lazy[] for children + st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], val) + st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], val) + // st.lazy[leftTreeIndex] += val + // st.lazy[rightTreeIndex] += val + } + return + } + st.updateLazyInTree(leftTreeIndex, left, midTreeIndex, updateLeft, updateRight, val) + st.updateLazyInTree(rightTreeIndex, midTreeIndex+1, right, updateLeft, updateRight, val) + // merge updates + st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex]) +} + +// SegmentCountTree define +type SegmentCountTree struct { + data, tree []int + left, right int + merge func(i, j int) int +} + +func (st *SegmentCountTree) init(nums []int, oper func(i, j int) int) { + st.merge = oper + + data, tree := make([]int, len(nums)), make([]int, 4*len(nums)) + for i := 0; i < len(nums); i++ { + data[i] = nums[i] + } + st.data, st.tree = data, tree +} + +// 在 treeIndex 的位置创建 [left....right] 区间的线段树 +func (st *SegmentCountTree) buildSegmentTree(treeIndex, left, right int) { + if left == right { + st.tree[treeIndex] = st.data[left] + return + } + leftTreeIndex, rightTreeIndex := st.leftChild(treeIndex), st.rightChild(treeIndex) + midTreeIndex := left + (right-left)/2 + st.buildSegmentTree(leftTreeIndex, left, midTreeIndex) + st.buildSegmentTree(rightTreeIndex, midTreeIndex+1, right) + st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex]) +} + +func (st *SegmentCountTree) leftChild(index int) int { + return 2*index + 1 +} + +func (st *SegmentCountTree) rightChild(index int) int { + return 2*index + 2 +} + +// 查询 [left....right] 区间内的值 +func (st *SegmentCountTree) query(left, right int) int { + if len(st.data) > 0 { + return st.queryInTree(0, 0, len(st.data)-1, left, right) + } + return 0 +} + +// 在以 treeIndex 为根的线段树中 [left...right] 的范围里,搜索区间 [queryLeft...queryRight] 的值,值是计数值 +func (st *SegmentCountTree) queryInTree(treeIndex, left, right, queryLeft, queryRight int) int { + if queryRight < st.data[left] || queryLeft > st.data[right] { + return 0 + } + if queryLeft <= st.data[left] && queryRight >= st.data[right] || left == right { + return st.tree[treeIndex] + } + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + return st.queryInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight) + + st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight) +} + +// 更新计数 +func (st *SegmentCountTree) updateCount(val int) { + if len(st.data) > 0 { + st.updateCountInTree(0, 0, len(st.data)-1, val) + } +} + +// 以 treeIndex 为根,更新 [left...right] 区间内的计数 +func (st *SegmentCountTree) updateCountInTree(treeIndex, left, right, val int) { + if val >= st.data[left] && val <= st.data[right] { + st.tree[treeIndex]++ + if left == right { + return + } + midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex) + st.updateCountInTree(rightTreeIndex, midTreeIndex+1, right, val) + st.updateCountInTree(leftTreeIndex, left, midTreeIndex, val) + } +}