부부의 코딩 성장 일기

LeetCode 96(Unique Binary Search Trees, Python) 본문

Algorithm/LeetCode

LeetCode 96(Unique Binary Search Trees, Python)

펩시_콜라 2024. 2. 8. 19:00

1. 문제 링크 

 

Unique Binary Search Trees - LeetCode

Can you solve this real interview question? Unique Binary Search Trees - Given an integer n, return the number of structurally unique BST's (binary search trees) which has exactly n nodes of unique values from 1 to n.   Example 1: [https://assets.leetcode

leetcode.com

 

2. 문제 설명 

  • 정수 n이 주어졌을 때, 1부터 n까지의 n개의 node를 사용하여 만들 수 있는 BST의 개수를 반환
    • leetcode에서 Unique Binary Search Trees II가 해당 문제보다 앞에 있는데, 
    • 해당문제에서는 전체 가능한 Tree를 list에 append하여 반환하였다면, 
    • 이번에는 가능한 Tree의 개수를 int형태로 반환하는 문제이다. 
  • 예시) n=3 이면 아래처럼 총 5가지의 경우의 수가 가능하므로 5를 반환
    • BST란 노드의 왼쪽 subtree에는 노드의 값보다 작은 값만 있는 노드만 포함되어야 하고, 
    • 반대로 오른쪽 subtree에는 노드의 값보다 큰 값만 있는 노드만 포함이 되어야 한다. 
    • 그래서 예를들어 2의 왼쪽노드에 3, 오른쪽노드에 1이 들어가는 것은 불가능.

 

3. 처음 풀이

  • 처음에는 기존 Unique Binary Search Tree II에서 가능한 모든 경우의 수를 list에 저장했으므로, 
  • 해당 list의 length를 반환하는 것으로 return에 len만 추가하여 반환하였으나, time limit이 발생했다. 
    • 해당 코드에서 hash table을 쓰면 결국 잘 작동하긴 했지만, 굳이 tree 구조를 쌓을 필요는 없어서, 다르게 접근하였다.
class Solution:
    def numTrees(self, n: int) -> int:

        def subtree(left,right):
            if left == right: 
                return [TreeNode(left)]
            
            if left > right:
                return [None]
            
            res = []
            for val in range(left,right+1):
                for lefttree in subtree(left,val-1):
                    for righttree in subtree(val+1,right):
                        root = TreeNode(val, lefttree, righttree)
                        res.append(root)
            
            return res 

        return len(subtree(1,n))

 

  • 우선 n=5인 경우를 생각해보면,
    • 만약 root node의 값이 3이라면,
      • 왼쪽 subtree는 [1,2]로 구성이 되고
      • 오른쪽 subtree는 [4,5]로 구성이 된다.
        • 즉, 왼쪽의 경우의 수와 오른쪽의 경우의 수를 곱한 것이 root node가 3일 때 만들어질 수 있는 tree의 개수가 된다.
      • 이에 이를 root node가 1일때부터 5일때까지의 값을 구해서 더하면 최종 output이 만들어지게 된다. 
      • 코드로는 output += subtree(left,val-1) * subtree(val+1,right)로 구현하였다. 
  • tree의 경우 계속 가장 작은 단위로 break down 하게 되면, node가 1개일 때가 가장 작은 단위가 되고,
    • 이에 left = right이면 1을 반환하고, 
    • 그리고 만약 left>right이면 null밖에 경우의수가 남지 않으므로 마찬가지로 1을 반환한다. 
    • 나머지 case에 대해서 위 n=5일 때의 예시처럼 for문을 돌리면 최종 경우의 수가 반환이 된다. 
  • 하지만 위 로직으로만 제출을 하면 마찬가지로 timelimit이 뜬다. → 이전에 이미 계산했던 subtree에 대해서 hash table을 만들어서 저장하는 것이 필요.
    • 이에 hist = {}라는 빈 dictionary를 만들고,
    • for문이 끝날 때 과거 history를 저장해두었다. hist[(left,val-1)]  = subtree(left,val-1)과 같은 식
    • 그래서 만약 (left,right)가 hist에 있다면, 해당 값을 반환하는 것으로 중복 계산을 줄이게 되면 submission통과!
class Solution:
    def numTrees(self, n: int) -> int:

        hist = {}

        def subtree(left,right):
            if left == right: 
                return 1
            
            if left > right:
                return 1

            if (left,right) in hist:
                return hist[(left,right)]

            output = 0
            for val in range(left,right+1):
                output += subtree(left,val-1) * subtree(val+1,right)
                hist[(left,val-1)] = subtree(left,val-1)
                hist[(val+1,right)] = subtree(val+1,right)

            return output            

        return subtree(1,n)

 

4. 다른 풀이

  • 해결하고 뿌듯했는데, 좀 더 단순한 풀이(이해하는 과정은 단순하지 않았다.)가 있었다.
    • recursive하게 계산하는 것이 아닌, dynamic programming으로 접근. 
  • 우선 n+1개만큼 0으로 채운 dp라는 array를 만든다. (여기서 n+1개인 이유는 subtree의 개수가 0개일 수도 있기 때문)
    • node개수가 0일 때의 경우의 수는 1
    • node개수가 1일 때의 경우의 수 또한 1이므로, 
    • 우선 초기값에 dp[0] = 1, dp[1] = 1을 셋팅한다.
  • 그리고 dp[2] 부터 내가 원하는 n까지의 값을 채워줘야하는데, 
    • 이 때 dp[2]같은 경우, 왼쪽 서브트리의 경우의 수 * 오른쪽 서브트리의 경우의 수의 곱으로 계산이 되고,
      • 여기서 dp[j]가 왼쪽 서브트리의 경우의 수, dp[i-1-j]가 오른쪽 서브트리의 경우의 수가 된다. 
      • 즉, dp[2]를 식으로 표현하게 되면
        • dp[0] * dp[1] + dp[1]*dp[0]인데, 해당 코드를 해석하게 되면, 
        • 왼쪽 서브트리가 0개이고, 오른쪽 서브트리가 1개로 구성될 수 있는 Tree의 경우의 수와 
        • 왼쪽 서브트리가 1개이고, 오른쪽 서브트리가 0개로 구성될 수 있는 Tree의 경우의 수를 곱하게 되는 것이다. 
    • 즉, 원하는 n까지 dp array값을 채운 뒤, dp[n]을 반환하면, 해당 시점에서의 경우의 수를 구할 수 있다.

 

class Solution:
    def numTrees(self, n: int) -> int:
        dp = [0] * (n + 1)
        dp[0] = 1
        dp[1] = 1

        for i in range(2, n + 1):
            for j in range(0, i):
                dp[i] += dp[j] * dp[i - 1 - j]
        
        return dp[n]

 

5. 배운 점

  • 이전 값을 참조해야 할 때, hash table을 dictionary형태로 구성할 수도 있지만, Dynamic Programming의 접근으로 구성하는 아이디어도 떠올릴 수 있어야겠다.