상세 컨텐츠

본문 제목

세그먼트 트리 (Segment Tree)

본문

  • 세그먼트 트리란?

 세그먼트 트리란 데이터의 합을 가장 빠르고 간단하게 구하거나 수정할 수 있는 자료구조이다. 보통 배열을 통해 데이터의 합을 구하려면 O(N)의 시간복잡도가 나오지만 트리 구조를 이용한 세그먼트 트리의 구간합은 O(logN)의 시간 복잡도가 나온다. 따라서 훨씬 빠른 속도로 구간합을 구할 수 있다는 사실을 알 수 있다. 그렇다면 어떻게 구간합을 구하길래 이렇게 빠른 시간안에 찾을 수 있을까?

 

 세그먼트 트리를 이용하기 위해 먼저 주어진 값으로 세그먼트 트리를 만들어야 한다.

 

1. 구간 합 트리 만들기

 

만약 배열의 값이 다음과 같이 있다 해보자.

i 0 1 2 3 4 5 6 7
A[i] 1 9 5 4 8 7 2 3

배열 A의 크기 N = 8이고, 세그먼트 트리를 이용해 이 배열의 구간합을 구하고 싶을 때 구간 합 트리를 먼저 만들어야 한다. 구간 합 트리는 다음과 같다.

 

----------------------------------------------------------

                                           A[0:7]

----------------------------------------------------------

                   A[0:3]                                         A[4:7]              

----------------------------------------------------------

     A[0:1]                A[2:3]                A[4:5]                A[6:7]

----------------------------------------------------------

A[0]    A[1]        A[2]    A[3]        A[4]    A[5]        A[6]    A[7]      

----------------------------------------------------------

 

Root 노드는 배열의 맨 처음인 0에서 마지막인 7까지의 합을 넣는다.

그리고 그 범위를 2등분하여 왼쪽 자식 노드에는 0에서 7/2 = 3까지의 합을, 오른쪽 자식 노드에는 7/2 + 1 = 4에서 7까지의 합을 넣는다.

이런식으로 부모 노드의 범위를 2등분하며 내려가면 구간 합 트리가 완성된다.

이때 배열로 트리를 구현할 경우 배열의 크기 SIZE = (N보다 크거나 같은 2의 거듭제곱 값 중 가장 작은 값) * 2로 해주면 되는데 편하게 SIZE = 4 * N으로 설정하는 경우도 있다. 

 

위의 배열 A의 예시로 보면 SIZE = 8 * 2 = 16이 된다.

 

배열 A를 구간 합 트리로 만들면 다음과 같다

 

----------------------------------

                         39

----------------------------------

          19                            20           

----------------------------------

  10             9             15              5

----------------------------------

1    9        5    4        8    7        2    3      

----------------------------------

 

(구간 합 트리를 만들때는 재귀를 이용해서 자식 노드의 값을 더하는 식으로 만드는 것이 편하다.)

 

 

2. 원하는 구간의 합 구하기

 

구간 합 트리를 만들었으니 i ~ j 인덱스 값의 합을 구하는 방법을 알아보자.

원래라면 A[i] ~ A[j]의 합은 하나하나 다 더해야 구할 수 있기 때문에 시간 복잡도가 O(N)이 된다.

하지만 구간 합 트리를 이용하면 이를 O(logN)으로 만들 수 있다.

 

구간 합 트리를 만들때 처럼 재귀를 사용하여 구간 합을 구할 수 있다.

배열 A에 대하여 i = 3, j = 6일때 A[i] ~ A[j]의 합을 구하는 방법을 보면 다음과 같다.

 

우선 Root 노드부터 순서대로 탐색한다.

노드를 탐색할 때는 다음 3가지 중 1개를 골라 값을 return하면 된다.

 

    case 1. 만약 i ~ j가 노드의 범위를 포함하지 않으면 0을 return한다.

    case 2. 만약 i ~ j가 노드의 범위를 전부 포함한다면 노드의 값을 return 한다.

    case 3. 만약 i ~ j가 노드의 범위의 일부를 포함한다면 자식 노드들의 return값들의 합을 return한다.

 

1) 우선 i = 3, j = 6인데 Root 노드의 범위는 0 ~ 7이므로 case 3에 해당해 자식 노드의 return 값들을 구해야한다.

 

          ----------------------------------

                                   39

          ----------------------------------

                   19                            20                     

          ----------------------------------

            10             9             15              5

          ----------------------------------

          1    9        5    4        8    7        2    3      

          ----------------------------------

 

2) 첫번째 자식은 노드의 범위가 0 ~ 3이므로 case 3에 해당하고, 두번째 자식도 노드의 범위가 4 ~ 7이므로 case 3에 해당한다.

 

          ----------------------------------

                                   39

          ----------------------------------

                       19                            20                     

          ----------------------------------

            10             9             15              5

          ----------------------------------

          1    9        5    4        8    7        2    3      

          ----------------------------------

 

 3) 첫번째 자식은 노드의 범위가 0 ~ 1이므로 case 1에 해당해 0을 return 한다.

     두번째 자식은 노드의 범위가 2 ~ 3이므로 case 3에 해당한다.

     세번째 자식은 노드의 범위가 4 ~ 5이므로 case 2에 해당해 노드 값을 return한다.

     네번째 자식은 노드의 범위가 6 ~ 7이므로 case 3에 해당한다.

 

          ----------------------------------

                                   39

          ----------------------------------

                       19                            20                     

          ----------------------------------

            10             9             15              5

          ----------------------------------

          1    9        5    4        8    7        2    3      

          ----------------------------------

 

 4) 첫번째 자식은 노드의 범위가 2이므로 case 1에 해당해 0을 return 한다.

     두번째 자식은 노드의 범위가 3이므로 case 2에 해당해 노드 값을 return한다.

     세번째 자식은 노드의 범위가 6이므로 case 2에 해당해 노드 값을 return한다.

     네번째 자식은 노드의 범위가 7이므로 case 1에 해당해 0을 return 한다.

 

          ----------------------------------

                                   39

          ----------------------------------

                       19                            20                     

          ----------------------------------

            10             9             15              5

          ----------------------------------

          1    9        5    4        8    7        2    3      

          ----------------------------------

 

 따라서 case 2인 경우의 노드 값을 전부 더하면 A[i] ~ A[j]의 구간 합인 4 + 15 + 2 = 21을 구할 수 있다.

 

3. 특정 원소의 값을 수정하기

 

세그먼트 트리에서 배열의 원소값이 바뀐 경우 구간 합 트리를 수정해 줄 필요가 있다. 

이는 '2. 원하는 구간의 합 구하기' 단계와 비슷하게 진행된다.

 

diff = 바뀐 원소의 값 - 원래 원소의 값이라 하자.

   

    case 1. Root 노드부터 만약 노드의 범위가 바꾸려는 원소의 index를 포함하는 경우 노드 값에 diff를 더해주고 자식 노드를 탐색한다.

                  (만약 노드의 범위 == index라면 diff를 더해주고 return한다)

    case 2. 아닌 경우 아무 것도 하지 않는다. 

 

따라서 배열 A에서 A[4] = 5로 바꾸면 diff = 5 - 8 = -3이 되고 다음과 같이 구간 합 트리가 수정된다.

 

----------------------------------

                         39

----------------------------------

          19                            20           

----------------------------------

  10             9             15              5

----------------------------------

1    9        5    4        8    7        2    3      

----------------------------------

 

----------------------------------

                         36

----------------------------------

          19                            20           

----------------------------------

  10             9             15              5

----------------------------------

1    9        5    4        8    7        2    3      

----------------------------------

 

----------------------------------

                         36

----------------------------------

          19                            17           

----------------------------------

  10             9             15              5

----------------------------------

1    9        5    4        8    7        2    3      

----------------------------------

 

----------------------------------

                         36

----------------------------------

          19                            17           

----------------------------------

  10             9             12              5

----------------------------------

1    9        5    4        8    7        2    3      

----------------------------------

 

----------------------------------

                         36

----------------------------------

          19                            17           

----------------------------------

  10             9             12              5

----------------------------------

1    9        5    4        5    7        2    3      

----------------------------------

 

이를 코드로 구현하면 다음과 같다.

 

class SegmentTree{
    long[] tree;
    int size;
    SegmentTree(long[] A, int size){
        this.size = size * 4;
        tree = new long[this.size];
        this.init(A, 1, 1, size);
    }

    // 1. 구간 합 트리 만들기
    private long init(long[] A, int node, int start, int end) {
        if (start == end)
            return tree[node] = A[start];
        else {
            int mid = (start + end) / 2;
            return tree[node] = this.init(A, node * 2, start, mid) + this.init(A, node * 2 + 1, mid + 1, end);
        }
    }
    // 2. 원하는 구간의 합 구하기
    long sum(int start, int end, int i, int j, int node){
        if (j < start || end < i) return 0;
        else if (i <= start && end <= j) return tree[node];
        else{
            int mid = (start + end) / 2;
            return sum(start, mid, i, j, node * 2) + sum(mid + 1, end, i, j, node * 2 + 1);
        }
    }
    // 3. 특정 원소의 값을 수정하기
    void update(int start, int end, int i, int node, long diff){
        if (i >= start && i <= end){
            tree[node] += diff;
            if (start != end){
                int mid = (start + end) / 2;
                update(start, mid, i, node * 2, diff);
                update(mid + 1, end, i, node * 2 + 1, diff);
            }
        }
    }
}

 


reference

-https://blog.naver.com/ndb796/221282210534

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

 

 

 

 

 

 

 

 

 

 

 

관련글 더보기