一起学习交流~

线段树

算法 laomuji 8个月前 (02-12) 446次浏览 已收录 0个评论

什么是线段树

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,实际应用时一般还要开4N的数组以免越界,因此有时需要离散化让空间压缩。

构造线段树

区间求和查询

节点修改

代码

#include<iostream>
using namespace std;
#define MAXLEN 100

/// <summary>
/// 构造线段树
/// </summary>
/// <param name="arr">用来构造线段树的数组</param>
/// <param name="tree">返回构造的线段树数组</param>
/// <param name="currentNode">当前节点下标</param>
/// <param name="left">数组左侧开始下标</param>
/// <param name="right">数组右侧结束下标</param>
void buildSegmentTree(int *arr,int*tree,int currentNode,int left,int right) {
    if (left == right) {
        tree[currentNode] = arr[left];
        return;
    }
    int leftNode = 2 * currentNode + 1;
    int rightNode = 2 * currentNode + 2;
    int mid = (left + right) / 2;
    buildSegmentTree(arr, tree, leftNode, left, mid);
    buildSegmentTree(arr, tree, rightNode, mid + 1, right);
    tree[currentNode] = tree[leftNode] + tree[rightNode];
}
/// <summary>
/// 区间求和
/// </summary>
/// <param name="tree">线段树数组</param>
/// <param name="currentNode">当前节点下标</param>
/// <param name="left">数组左侧开始下标</param>
/// <param name="right">数组右侧结束下标</param>
/// <param name="start">区间开始下标</param>
/// <param name="end">区间结束下标</param>
/// <returns>区间和</returns>
int query(int* tree, int currentNode, int left, int right,int start,int end) {
    //start>right或者end<left时,与该区间不相交
    if (start > right || end < left)return 0;
    //left=right时 表示已经是叶子节点了,直接返回节点值
    if (left == right) return tree[currentNode];

    //如果 left和right 在 start和end的范围内,直接返回树上记录的值,不需要向下查询
    if (left >= start && right <= end)return tree[currentNode];

    //将区间分成两段,在左侧和右侧查询区间值
    int mid = (left + right) / 2;
    int leftNode  = 2 * currentNode + 1;
    int rightNode = 2 * currentNode + 2;
    int leftQuery = query(tree, leftNode, left, mid, start, end);
    int rightQuery = query(tree, rightNode, mid + 1, right, start, end);
    return leftQuery + rightQuery;
}
/// <summary>
/// 修改节点值
/// </summary>
/// <param name="arr">原始数组</param>
/// <param name="tree">线段树数组</param>
/// <param name="currentNode">当前下标</param>
/// <param name="left">数组左侧开始下标</param>
/// <param name="right">数组右侧结束下标</param>
/// <param name="idx">原始数组数据的下标</param>
/// <param name="val">需要修改的值</param>
void update(int* arr, int* tree, int currentNode, int left, int right, int idx, int val) {
    if (left == right) {
        //当left=right
        //此时left和right的值一定为idx
        //否则 idx一定不正确
        if (left == idx) {
            arr[idx] = val;
            tree[currentNode] = val;
        }
        else {
            cout << "idx不正确" << endl;
        }
        return;
    }
    int leftNode = 2 * currentNode + 1;
    int rightNode = 2 * currentNode + 2;
    int mid = (left + right) / 2;
    if (idx <= mid) {
        //如果需要修改的节点在左侧
        update(arr, tree, leftNode, left, mid, idx, val);
    }
    else {
        //如果需要修改的节点在右侧
        update(arr, tree, rightNode, mid + 1, right, idx, val);
    }
    tree[currentNode] = tree[leftNode] + tree[rightNode];
}

void prtTree(int *tree,int len) {
    for (int i = 0; i < len; i++) {
        cout << tree[i] << " ";
    }
    cout << endl;
}

int main() {
    int len = 6;
    int arr[MAXLEN] = { 4,5,3,7,8,2 };//构造线段树的数据
    int tree[MAXLEN*4] = {0};//存放构造的线段树,一般开4倍,防止溢出
    buildSegmentTree(arr, tree, 0, 0, len - 1);
    cout << "构造线段树:" << endl;
    prtTree(tree, 15);

    cout << "查询1-2:" << endl;
    cout << query(tree, 0, 0, len - 1, 1, 2) << endl;

    cout << "查询0-2:" << endl;
    cout << query(tree, 0, 0, len - 1, 0, 2) << endl;

    cout << "查询0-4:" << endl;
    cout << query(tree, 0, 0, len - 1, 0, 4) << endl;

    cout << "更新下标4的值为1" << endl;
    update(arr, tree, 0, 0, len - 1, 4, 1);
    prtTree(tree, 15);
    return 0;
}

力扣题目验证

力扣 307题 区域和检索 – 数组可修改
刚好可以测试,于是又重新写了一遍

class NumArray {
private:
    vector<int> arr;
    vector<int> tree;

    void initSegmentTree(int currentNode,int start,int end){
        if(start == end){
            tree[currentNode] = arr[start];
            return;
        }
        int mid = (start+end)/2;
        int leftNode  = currentNode * 2 + 1;
        int rightNode = currentNode * 2 + 2;
        initSegmentTree(leftNode,start,mid);
        initSegmentTree(rightNode,mid+1,end);
        tree[currentNode] = tree[leftNode]+tree[rightNode];
    }

    int queryRange(int left,int right,int currentNode,int start,int end){
        if(left > end || right < start)return 0;//查询区间和当前区间不相交
        if(start == end)return tree[currentNode];//叶子节点
        if(start>=left && end <=right)return tree[currentNode];//如果当前区间是查询区间的子集
        int mid = (start+end)/2;
        int leftNode  = currentNode * 2 + 1;
        int rightNode = currentNode * 2 + 2;
        int leftSum = queryRange(left,right,leftNode,start,mid);
        int rightSum= queryRange(left,right,rightNode,mid+1,end);
        return leftSum+rightSum;
    }

    void updateRange(int currentNode,int start,int end ,int index,int val){
        if(start == end){
            //实际上这里应该再判断一下 start == index,可以判断index是否合法
            //但输入的数据是合法的,就不用判断了
            arr[index]=val;
            tree[currentNode]=val;
            return;
        }

        int mid = (start+end)/2;
        int leftNode  = currentNode * 2 + 1;
        int rightNode = currentNode * 2 + 2;
        if(index<=mid){
            //index小于等于mid 表示在 左侧区间修改
            updateRange(leftNode,start,mid,index,val);
        }else{
            //index大于mid 表示在右侧区间修改
            updateRange(rightNode,mid+1,end,index,val);
        }
        tree[currentNode]=tree[leftNode]+tree[rightNode];
    }
public:
    NumArray(vector<int>& nums) {
        if(nums.size()==0)return;
        arr = nums;
        tree.resize(arr.size()*4);//开辟四倍大小防止溢出
        initSegmentTree(0,0,arr.size()-1);
    }

    void update(int index, int val) {
        updateRange(0,0,arr.size()-1,index,val);
    }

    int sumRange(int left, int right) {
        return queryRange(left,right,0,0,arr.size()-1);
    }
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(index,val);
 * int param_2 = obj->sumRange(left,right);
 */
喜欢 (0)
订阅评论
提醒
guest
0 评论
内联反馈
查看所有评论