mirror of
https://github.com/halfrost/LeetCode-Go.git
synced 2025-07-05 00:25:22 +08:00
添加 segmentTree 模板
This commit is contained in:
251
template/SegmentTree.go
Normal file
251
template/SegmentTree.go
Normal file
@ -0,0 +1,251 @@
|
||||
package template
|
||||
|
||||
// SegmentTree define
|
||||
type SegmentTree struct {
|
||||
data, tree, lazy []int
|
||||
left, right int
|
||||
merge func(i, j int) int
|
||||
}
|
||||
|
||||
func (st *SegmentTree) init(nums []int, oper func(i, j int) int) {
|
||||
st.merge = oper
|
||||
|
||||
data, tree, lazy := make([]int, len(nums)), make([]int, 4*len(nums)), make([]int, 4*len(nums))
|
||||
for i := 0; i < len(nums); i++ {
|
||||
data[i] = nums[i]
|
||||
}
|
||||
st.data, st.tree, st.lazy = data, tree, lazy
|
||||
if len(nums) > 0 {
|
||||
st.buildSegmentTree(0, 0, len(nums)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// 在 treeIndex 的位置创建 [left....right] 区间的线段树
|
||||
func (st *SegmentTree) buildSegmentTree(treeIndex, left, right int) {
|
||||
if left == right {
|
||||
st.tree[treeIndex] = st.data[left]
|
||||
return
|
||||
}
|
||||
leftTreeIndex, rightTreeIndex := st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
midTreeIndex := left + (right-left)/2
|
||||
st.buildSegmentTree(leftTreeIndex, left, midTreeIndex)
|
||||
st.buildSegmentTree(rightTreeIndex, midTreeIndex+1, right)
|
||||
st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex])
|
||||
}
|
||||
|
||||
func (st *SegmentTree) leftChild(index int) int {
|
||||
return 2*index + 1
|
||||
}
|
||||
|
||||
func (st *SegmentTree) rightChild(index int) int {
|
||||
return 2*index + 2
|
||||
}
|
||||
|
||||
// 查询 [left....right] 区间内的值
|
||||
func (st *SegmentTree) query(left, right int) int {
|
||||
if len(st.data) > 0 {
|
||||
return st.queryInTree(0, 0, len(st.data)-1, left, right)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 在以 treeIndex 为根的线段树中 [left...right] 的范围里,搜索区间 [queryLeft...queryRight] 的值
|
||||
func (st *SegmentTree) queryInTree(treeIndex, left, right, queryLeft, queryRight int) int {
|
||||
if left == queryLeft && right == queryRight {
|
||||
return st.tree[treeIndex]
|
||||
}
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
if queryLeft >= midTreeIndex+1 {
|
||||
return st.queryInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight)
|
||||
} else if queryRight <= midTreeIndex {
|
||||
return st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight)
|
||||
}
|
||||
return st.merge(st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, midTreeIndex),
|
||||
st.queryInTree(rightTreeIndex, midTreeIndex+1, right, midTreeIndex+1, queryRight))
|
||||
}
|
||||
|
||||
// 查询 [left....right] 区间内的值
|
||||
func (st *SegmentTree) queryLazy(left, right int) int {
|
||||
if len(st.data) > 0 {
|
||||
return st.queryLazyInTree(0, 0, len(st.data)-1, left, right)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (st *SegmentTree) queryLazyInTree(treeIndex, left, right, queryLeft, queryRight int) int {
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
if left > queryRight || right < queryLeft { // segment completely outside range
|
||||
return 0 // represents a null node
|
||||
}
|
||||
if st.lazy[treeIndex] != 0 { // this node is lazy
|
||||
for i := 0; i < right-left+1; i++ {
|
||||
st.tree[treeIndex] = st.merge(st.tree[treeIndex], st.lazy[treeIndex])
|
||||
// st.tree[treeIndex] += (right - left + 1) * st.lazy[treeIndex] // normalize current node by removing lazinesss
|
||||
}
|
||||
if left != right { // update lazy[] for children nodes
|
||||
st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], st.lazy[treeIndex])
|
||||
st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], st.lazy[treeIndex])
|
||||
// st.lazy[leftTreeIndex] += st.lazy[treeIndex]
|
||||
// st.lazy[rightTreeIndex] += st.lazy[treeIndex]
|
||||
}
|
||||
st.lazy[treeIndex] = 0 // current node processed. No longer lazy
|
||||
}
|
||||
if queryLeft <= left && queryRight >= right { // segment completely inside range
|
||||
return st.tree[treeIndex]
|
||||
}
|
||||
if queryLeft > midTreeIndex {
|
||||
return st.queryLazyInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight)
|
||||
} else if queryRight <= midTreeIndex {
|
||||
return st.queryLazyInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight)
|
||||
}
|
||||
// merge query results
|
||||
return st.merge(st.queryLazyInTree(leftTreeIndex, left, midTreeIndex, queryLeft, midTreeIndex),
|
||||
st.queryLazyInTree(rightTreeIndex, midTreeIndex+1, right, midTreeIndex+1, queryRight))
|
||||
}
|
||||
|
||||
// 更新 index 位置的值
|
||||
func (st *SegmentTree) update(index, val int) {
|
||||
if len(st.data) > 0 {
|
||||
st.updateInTree(0, 0, len(st.data)-1, index, val)
|
||||
}
|
||||
}
|
||||
|
||||
// 以 treeIndex 为根,更新 index 位置上的值为 val
|
||||
func (st *SegmentTree) updateInTree(treeIndex, left, right, index, val int) {
|
||||
if left == right {
|
||||
st.tree[treeIndex] = val
|
||||
return
|
||||
}
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
if index >= midTreeIndex+1 {
|
||||
st.updateInTree(rightTreeIndex, midTreeIndex+1, right, index, val)
|
||||
} else {
|
||||
st.updateInTree(leftTreeIndex, left, midTreeIndex, index, val)
|
||||
}
|
||||
st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex])
|
||||
}
|
||||
|
||||
// 更新 [updateLeft....updateRight] 位置的值
|
||||
func (st *SegmentTree) updateLazy(updateLeft, updateRight, val int) {
|
||||
if len(st.data) > 0 {
|
||||
st.updateLazyInTree(0, 0, len(st.data)-1, updateLeft, updateRight, val)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *SegmentTree) updateLazyInTree(treeIndex, left, right, updateLeft, updateRight, val int) {
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
if st.lazy[treeIndex] != 0 { // this node is lazy
|
||||
for i := 0; i < right-left+1; i++ {
|
||||
st.tree[treeIndex] = st.merge(st.tree[treeIndex], st.lazy[treeIndex])
|
||||
//st.tree[treeIndex] += (right - left + 1) * st.lazy[treeIndex] // normalize current node by removing laziness
|
||||
}
|
||||
if left != right { // update lazy[] for children nodes
|
||||
st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], st.lazy[treeIndex])
|
||||
st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], st.lazy[treeIndex])
|
||||
// st.lazy[leftTreeIndex] += st.lazy[treeIndex]
|
||||
// st.lazy[rightTreeIndex] += st.lazy[treeIndex]
|
||||
}
|
||||
st.lazy[treeIndex] = 0 // current node processed. No longer lazy
|
||||
}
|
||||
|
||||
if left > right || left > updateRight || right < updateLeft {
|
||||
return // out of range. escape.
|
||||
}
|
||||
|
||||
if updateLeft <= left && right <= updateRight { // segment is fully within update range
|
||||
for i := 0; i < right-left+1; i++ {
|
||||
st.tree[treeIndex] = st.merge(st.tree[treeIndex], val)
|
||||
//st.tree[treeIndex] += (right - left + 1) * val // update segment
|
||||
}
|
||||
if left != right { // update lazy[] for children
|
||||
st.lazy[leftTreeIndex] = st.merge(st.lazy[leftTreeIndex], val)
|
||||
st.lazy[rightTreeIndex] = st.merge(st.lazy[rightTreeIndex], val)
|
||||
// st.lazy[leftTreeIndex] += val
|
||||
// st.lazy[rightTreeIndex] += val
|
||||
}
|
||||
return
|
||||
}
|
||||
st.updateLazyInTree(leftTreeIndex, left, midTreeIndex, updateLeft, updateRight, val)
|
||||
st.updateLazyInTree(rightTreeIndex, midTreeIndex+1, right, updateLeft, updateRight, val)
|
||||
// merge updates
|
||||
st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex])
|
||||
}
|
||||
|
||||
// SegmentCountTree define
|
||||
type SegmentCountTree struct {
|
||||
data, tree []int
|
||||
left, right int
|
||||
merge func(i, j int) int
|
||||
}
|
||||
|
||||
func (st *SegmentCountTree) init(nums []int, oper func(i, j int) int) {
|
||||
st.merge = oper
|
||||
|
||||
data, tree := make([]int, len(nums)), make([]int, 4*len(nums))
|
||||
for i := 0; i < len(nums); i++ {
|
||||
data[i] = nums[i]
|
||||
}
|
||||
st.data, st.tree = data, tree
|
||||
}
|
||||
|
||||
// 在 treeIndex 的位置创建 [left....right] 区间的线段树
|
||||
func (st *SegmentCountTree) buildSegmentTree(treeIndex, left, right int) {
|
||||
if left == right {
|
||||
st.tree[treeIndex] = st.data[left]
|
||||
return
|
||||
}
|
||||
leftTreeIndex, rightTreeIndex := st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
midTreeIndex := left + (right-left)/2
|
||||
st.buildSegmentTree(leftTreeIndex, left, midTreeIndex)
|
||||
st.buildSegmentTree(rightTreeIndex, midTreeIndex+1, right)
|
||||
st.tree[treeIndex] = st.merge(st.tree[leftTreeIndex], st.tree[rightTreeIndex])
|
||||
}
|
||||
|
||||
func (st *SegmentCountTree) leftChild(index int) int {
|
||||
return 2*index + 1
|
||||
}
|
||||
|
||||
func (st *SegmentCountTree) rightChild(index int) int {
|
||||
return 2*index + 2
|
||||
}
|
||||
|
||||
// 查询 [left....right] 区间内的值
|
||||
func (st *SegmentCountTree) query(left, right int) int {
|
||||
if len(st.data) > 0 {
|
||||
return st.queryInTree(0, 0, len(st.data)-1, left, right)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 在以 treeIndex 为根的线段树中 [left...right] 的范围里,搜索区间 [queryLeft...queryRight] 的值,值是计数值
|
||||
func (st *SegmentCountTree) queryInTree(treeIndex, left, right, queryLeft, queryRight int) int {
|
||||
if queryRight < st.data[left] || queryLeft > st.data[right] {
|
||||
return 0
|
||||
}
|
||||
if queryLeft <= st.data[left] && queryRight >= st.data[right] || left == right {
|
||||
return st.tree[treeIndex]
|
||||
}
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
return st.queryInTree(rightTreeIndex, midTreeIndex+1, right, queryLeft, queryRight) +
|
||||
st.queryInTree(leftTreeIndex, left, midTreeIndex, queryLeft, queryRight)
|
||||
}
|
||||
|
||||
// 更新计数
|
||||
func (st *SegmentCountTree) updateCount(val int) {
|
||||
if len(st.data) > 0 {
|
||||
st.updateCountInTree(0, 0, len(st.data)-1, val)
|
||||
}
|
||||
}
|
||||
|
||||
// 以 treeIndex 为根,更新 [left...right] 区间内的计数
|
||||
func (st *SegmentCountTree) updateCountInTree(treeIndex, left, right, val int) {
|
||||
if val >= st.data[left] && val <= st.data[right] {
|
||||
st.tree[treeIndex]++
|
||||
if left == right {
|
||||
return
|
||||
}
|
||||
midTreeIndex, leftTreeIndex, rightTreeIndex := left+(right-left)/2, st.leftChild(treeIndex), st.rightChild(treeIndex)
|
||||
st.updateCountInTree(rightTreeIndex, midTreeIndex+1, right, val)
|
||||
st.updateCountInTree(leftTreeIndex, left, midTreeIndex, val)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user