Binary Search Trees: insert and search

Binary Search Tree

In this post, we will work through inserting into and searching a Binary Search Tree, with a brief look at deleting. All three run in O(log n) time on average, and in O(n) time in the worst case, when the tree is badly unbalanced.

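Binary Search Trees are nice because their nodes are ordered: every value in a node's left subtree is smaller than the node's value, and every value in its right subtree is larger. Because of this ordering, each comparison lets us discard a whole subtree, so we walk a single path down from the root instead of visiting every node. Inserting, searching and deleting therefore each take time proportional to the height of the tree. On average that is O(log n), where n is the number of nodes in the tree, but it degrades to O(n) if the tree turns into a long chain, for example when values are inserted in sorted order.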
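To see why the log(n) figure depends on the shape of the tree, here is a minimal stand-alone sketch that builds two trees from the same seven values in different orders and compares their heights. It stores nodes as plain dicts, and the helper names `insert`, `height` and `build` are illustrative assumptions, not part of the Node and Tree classes defined later in this post.

```python
# Stand-alone sketch: nodes are plain dicts, helper names are hypothetical.

def insert(tree, value):
    """Insert value into a BST of nested dicts and return the root."""
    if tree is None:
        return {"value": value, "left": None, "right": None}
    node = tree
    while True:
        if value == node["value"]:
            return tree  # duplicate: nothing to add
        side = "left" if value < node["value"] else "right"
        if node[side] is None:
            node[side] = {"value": value, "left": None, "right": None}
            return tree
        node = node[side]

def height(tree):
    """Count the nodes on the longest root-to-leaf path, level by level."""
    levels = 0
    current = [tree] if tree is not None else []
    while current:
        levels += 1
        current = [child for node in current
                   for child in (node["left"], node["right"])
                   if child is not None]
    return levels

def build(values):
    """Insert the values one by one into an initially empty tree."""
    root = None
    for v in values:
        root = insert(root, v)
    return root

balanced_order = [4, 2, 6, 1, 3, 5, 7]  # middle values first
sorted_order = [1, 2, 3, 4, 5, 6, 7]    # every node becomes a right child

print(height(build(balanced_order)))  # 3
print(height(build(sorted_order)))    # 7
```

With seven nodes, the first tree needs at most 3 comparisons per lookup, while the second behaves like a linked list and may need all 7.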

Define Node class

# a Node holds a value plus references to its left and right children

class Node(object):

    def __init__(self,value = None):
        self.value = value
        self.left = None
        self.right = None

    def set_value(self,value):
        self.value = value

    def get_value(self):
        return self.value

    def set_left_child(self,left):
        self.left = left

    def set_right_child(self, right):
        self.right = right

    def get_left_child(self):
        return self.left

    def get_right_child(self):
        return self.right

    def has_left_child(self):
        return self.left is not None

    def has_right_child(self):
        return self.right is not None

    # define __repr__ to decide what a print statement displays for a Node object
    def __repr__(self):
        return f"Node({self.get_value()})"

    def __str__(self):
        return f"Node({self.get_value()})"

from collections import deque
class Queue():
    def __init__(self):
        self.q = deque()

    def enq(self,value):
        self.q.appendleft(value)

    def deq(self):
        if len(self.q) > 0:
            return self.q.pop()
        else:
            return None

    def __len__(self):
        return len(self.q)

    def __repr__(self):
        if len(self.q) > 0:
            s = "<enqueue here>\n_________________\n"
            s += "\n_________________\n".join([str(item) for item in self.q])
            s += "\n_________________\n<dequeue here>"
            return s
        else:
            return "<queue is empty>"

Insertion

Using iteration: Let’s build some intuition around this. Say we have a tree and a new value comes in. Where should we place it? We start at the root and compare. The comparison has three possible outcomes:

  • if the value is the same, then replace it, or just print a message saying the value already exists.
  • if the value is less than the current node's value, then go left
  • if the value is more than the current node's value, then go right

Put that logic in a loop. What should the loop variable be, and when should we stop looping? The loop variable is the current node in the tree. In each iteration we compare the new value against the current node's value and then move to either its left or its right child, so the variable that refers to the current node has to be updated on every pass. We stop when the direction we need to go in has no child: we attach the new node there and break out of the loop.

Using recursion: Time and again, when working with trees, we encounter this pattern of walking down from node to node. Instead of looping and reassigning the current node, we can use recursion to solve the same problem. One of the tenets of a recursive function is that we change its input in some way before calling it again, so that it moves towards a base case. Here, we pass the left or right child as the node argument, just as the loop moved to that child, and the base case is a missing child. Apart from the very first insertion, which sets the root, we can use recursion to insert all the remaining values.

class Tree():
    def __init__(self):
        self.root = None

    def set_root(self,value):
        self.root = Node(value)

    def get_root(self):
        return self.root

    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1 # traverse left
        else: #new_node > node
            return 1  # traverse right

    def insert_with_loop(self,new_value):
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while(True):
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # override with new node's value
                node.set_value(new_node.get_value())
                break # override node, and stop looping
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break #inserted node, so stop looping
            else: #comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break # inserted node, so stop looping

    def insert_with_recursion(self,value):

        if self.get_root() is None:
            self.set_root(value)
            return
        #otherwise, use recursion to insert the node
        self.insert_recursively(self.get_root(), Node(value))

    def insert_recursively(self,node,new_node):
        comparison = self.compare(node,new_node)
        if comparison == 0:
            # equal
            node.set_value(new_node.get_value())
        elif comparison == -1:
            # traverse left
            if node.has_left_child():
                self.insert_recursively(node.get_left_child(),new_node)
            else:
                node.set_left_child(new_node)

        else: #comparison == 1
            # traverse right
            if node.has_right_child():
                self.insert_recursively(node.get_right_child(), new_node)
            else:
                node.set_right_child(new_node)


    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node is None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node)
            else:
                s += "\n" + str(node)
                previous_level = level


        return s

tree = Tree()
tree.insert_with_loop(5)
tree.insert_with_loop(6)
tree.insert_with_loop(4)
tree.insert_with_loop(2)
tree.insert_with_loop(5) # insert duplicate
tree.insert_with_loop(1)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
Node(1) | <empty>
<empty> | <empty>

tree = Tree()
tree.insert_with_recursion(5)
tree.insert_with_recursion(6)
tree.insert_with_recursion(4)
tree.insert_with_recursion(2)
tree.insert_with_recursion(5) # insert duplicate
tree.insert_with_recursion(1)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
Node(1) | <empty>
<empty> | <empty>
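
One way to sanity-check either insert method is an in-order traversal (left subtree, node, right subtree), which visits the values of a valid Binary Search Tree in ascending order. The sketch below is stand-alone: it uses a trimmed-down `Node` with the same attribute names as the class above and rebuilds the printed tree by hand, and `in_order` is an illustrative helper name rather than part of the original code.

```python
class Node:
    # Trimmed-down stand-in for the Node class above (same attribute names).
    def __init__(self, value=None):
        self.value = value
        self.left = None
        self.right = None

def in_order(node):
    """Yield the values of the subtree rooted at node in ascending order."""
    if node is None:
        return
    yield from in_order(node.left)
    yield node.value
    yield from in_order(node.right)

# Hand-build the tree printed above: 5 has children 4 and 6,
# 4 has left child 2, and 2 has left child 1.
root = Node(5)
root.left = Node(4)
root.right = Node(6)
root.left.left = Node(2)
root.left.left.left = Node(1)

print(list(in_order(root)))  # [1, 2, 4, 5, 6]
```

With the Tree class from this post, the equivalent call would be `list(in_order(tree.get_root()))`.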

Search

Define a search function that takes a value and returns True if a node containing that value exists in the tree, otherwise False.

class Tree():
    def __init__(self):
        self.root = None

    def set_root(self,value):
        self.root = Node(value)

    def get_root(self):
        return self.root

    def compare(self,node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert(self,new_value):
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while(True):
            comparison = self.compare(node, new_node)
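            if comparison == 0:
                # override with new node's value
                # (rebinding the local name `node` would leave the tree unchanged)
                node.set_value(new_node.get_value())
                break # value already present, so stop looping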
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    break #inserted node, so stop looping
            else: #comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    break # inserted node, so stop looping

    """
    implement search
    """
    def search(self,value):

        # start from this tree's own root (not the global `tree` variable)
        node = self.get_root()
        # wrap the value once so that compare() can be reused
        s_node = Node(value)

        while node is not None:
            comparison = self.compare(node, s_node)
            if comparison == 0:
                print("Found")
                return True
            elif comparison == -1:
                # value is smaller: continue in the left subtree (may be None)
                node = node.get_left_child()
            else:
                # value is larger: continue in the right subtree (may be None)
                node = node.get_right_child()

        # walked off the bottom of the tree without finding the value
        print("Not Found")
        return False

    def search_v2(self,value):
        node = self.get_root()
        if node is None:
            return False # an empty tree contains nothing
        s_node = Node(value)
        while(True):
            comparison = self.compare(node,s_node)
            if comparison == 0:
                return True
            elif comparison == -1:
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    return False
            else:
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    return False

    def __repr__(self):
        level = 0
        q = Queue()
        visit_order = list()
        node = self.get_root()
        q.enq( (node,level) )
        while(len(q) > 0):
            node, level = q.deq()
            if node is None:
                visit_order.append( ("<empty>", level))
                continue
            visit_order.append( (node, level) )
            if node.has_left_child():
                q.enq( (node.get_left_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

            if node.has_right_child():
                q.enq( (node.get_right_child(), level +1 ))
            else:
                q.enq( (None, level +1) )

        s = "Tree\n"
        previous_level = -1
        for i in range(len(visit_order)):
            node, level = visit_order[i]
            if level == previous_level:
                s += " | " + str(node)
            else:
                s += "\n" + str(node)
                previous_level = level


        return s

tree = Tree()
tree.insert(5)
tree.insert(6)
tree.insert(4)
tree.insert(2)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")
print(tree)

Not Found
Found

search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>
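
The loop-based search above can also be written recursively, in the same spirit as insert_recursively. The following is a minimal stand-alone sketch, assuming a trimmed-down `Node` with `value`, `left` and `right` attributes. The function name `search_recursively` is an illustrative choice and not part of the Tree class above.

```python
class Node:
    # Trimmed-down stand-in for the Node class above (same attribute names).
    def __init__(self, value=None):
        self.value = value
        self.left = None
        self.right = None

def search_recursively(node, value):
    """Return True if value is stored in the subtree rooted at node."""
    if node is None:
        return False  # base case: ran off the bottom of the tree
    if value == node.value:
        return True   # base case: found it
    if value < node.value:
        return search_recursively(node.left, value)   # smaller: go left
    return search_recursively(node.right, value)      # larger: go right

# Hand-build the tree from the example above: 5 has children 4 and 6,
# and 4 has left child 2.
root = Node(5)
root.left = Node(4)
root.right = Node(6)
root.left.left = Node(2)

print(search_recursively(root, 8))  # False
print(search_recursively(root, 2))  # True
```

Each call hands a smaller subtree to the next call, so the recursion depth is at most the height of the tree.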

Deletion

check out this explanation here

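Deletion from a binary search tree is more involved than insertion and searching, because the node being removed may have zero, one, or two children, and the remaining nodes must still be in order afterwards. Read the explanation linked above.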
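
As a rough illustration of the extra cases, here is a minimal stand-alone sketch of one common approach: a node with at most one child is replaced by that child, and a node with two children takes the value of its in-order successor (the smallest value in its right subtree), which is then removed from that subtree. This is not necessarily the method described in the linked explanation, and `Node`, `insert`, `delete` and `in_order` here are illustrative helpers rather than the classes defined earlier in this post.

```python
class Node:
    # Minimal node for this sketch only.
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None

def insert(node, value):
    """Insert value into the subtree rooted at node; return the subtree root."""
    if node is None:
        return Node(value)
    if value < node.value:
        node.left = insert(node.left, value)
    elif value > node.value:
        node.right = insert(node.right, value)
    return node  # duplicates are ignored

def delete(node, value):
    """Remove value from the subtree rooted at node; return the new subtree root."""
    if node is None:
        return None  # value not present
    if value < node.value:
        node.left = delete(node.left, value)
    elif value > node.value:
        node.right = delete(node.right, value)
    else:
        # Zero or one child: the other child (possibly None) takes this node's place.
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: copy the in-order successor's value, then delete the successor.
        successor = node.right
        while successor.left is not None:
            successor = successor.left
        node.value = successor.value
        node.right = delete(node.right, successor.value)
    return node

def in_order(node):
    """Return the values of the subtree in ascending order."""
    if node is None:
        return []
    return in_order(node.left) + [node.value] + in_order(node.right)

root = None
for v in [5, 3, 8, 2, 4, 7, 9]:
    root = insert(root, v)

root = delete(root, 2)  # leaf node
root = delete(root, 3)  # node with one child (4)
root = delete(root, 5)  # node with two children (the root)
print(in_order(root))   # [4, 7, 8, 9]
print(root.value)       # 7
```

After all three deletions the in-order traversal is still sorted, which is the property deletion has to preserve.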