9.1 Recursive Sorting Algorithms
This week we’re switching gears a little bit. For the bulk of the course we’ve talked about different abstract data types and the data structures we use to implement them. Now, we’re going to talk about one very specific data-processing operation, which is one of the most fundamental in computer science: sorting. Just as we saw multiple data structures that could be used to represent the same ADT, we’ll look at a few different ways to implement sorting.
You’ve studied sorting before in CSC108. The basic sorting algorithms you probably saw there (bubblesort, selection sort, insertion sort) were all iterative, meaning they involved multiple loops through the list.[1] You probably also discussed how their running time is quadratic in the size of the list: each of these algorithms sorts a list of size \(n\) in \(O(n^2)\) steps. (Why? Briefly, each involves \(n\) inner loops, where each inner loop has between 1 and \(n\) iterations, and \(1 + 2 + 3 + \cdots + n = \frac{n(n+1)}{2}\), which is \(O(n^2)\).)
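To make the quadratic count concrete, here is a minimal sketch of selection sort (one of the iterative algorithms mentioned above) that also counts how many comparisons it performs. The name `selection_sort_count` is hypothetical and chosen only for illustration. For a list of length \(n\), the inner loop runs \(n-1, n-2, \ldots, 1, 0\) times, so the total is \(\frac{n(n-1)}{2}\) comparisons, which is also \(O(n^2)\).

```python
def selection_sort_count(lst: list) -> tuple[list, int]:
    """Return a sorted copy of <lst> and the number of comparisons made.

    Hypothetical helper for illustration: a non-mutating selection sort
    instrumented with a comparison counter.
    """
    result = lst[:]  # work on a copy so <lst> is not mutated
    comparisons = 0
    n = len(result)
    for i in range(n):
        # Find the index of the smallest item in result[i:].
        min_index = i
        for j in range(i + 1, n):
            comparisons += 1
            if result[j] < result[min_index]:
                min_index = j
        # Swap that smallest item into position i.
        result[i], result[min_index] = result[min_index], result[i]
    return result, comparisons


sorted_lst, count = selection_sort_count([5, 3, 8, 1, 9, 2, 7, 4, 6, 0])
print(sorted_lst)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(count)       # 45, i.e. 10 * 9 / 2
```

Doubling the list length roughly quadruples the comparison count. This is the behaviour the recursive algorithms below aim to improve on.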
In this lecture, we’re going to use recursion to develop two sorting algorithms, mergesort and quicksort, that are typically much faster than these quadratic ones. Both are divide-and-conquer algorithms, a general class of recursive algorithms that follow these steps:
1. Split up the input into two or more parts.
2. Recurse on each part separately.
3. Combine the results of the previous step into a single result.
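As a small illustration of these three steps on a problem simpler than sorting, here is a hypothetical `dc_sum` function. It sums a list by splitting it in half, recursing on each half, and adding the two results. It is only a sketch of the divide-and-conquer pattern; Python’s built-in `sum` does the same job directly.

```python
def dc_sum(lst: list[int]) -> int:
    """Return the sum of <lst> using divide-and-conquer (illustrative only)."""
    if len(lst) == 0:
        return 0
    elif len(lst) == 1:
        return lst[0]
    else:
        # 1. Split the input into two parts.
        mid = len(lst) // 2
        left, right = lst[:mid], lst[mid:]
        # 2. Recurse on each part separately.
        left_sum = dc_sum(left)
        right_sum = dc_sum(right)
        # 3. Combine the results into a single result.
        return left_sum + right_sum


print(dc_sum([3, 1, 4, 1, 5, 9, 2, 6]))  # 31
```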
Where these two algorithms differ is in the splitting and combining: mergesort does the “hard” (algorithmically complex) work in the combine step, and quicksort does it in the divide step.
Mergesort
The first algorithm we’ll study is called mergesort, and it takes the “divide-and-conquer” philosophy very literally. The basic idea of this algorithm is that it divides its input list into two halves, recursively sorts each half, and then merges the two sorted halves into the final sorted list.
```python
def mergesort(lst: list) -> list:
    """Return a sorted list with the same elements as <lst>.

    This is a *non-mutating* version of mergesort; it does not mutate the
    input list.
    """
    if len(lst) < 2:
        return lst[:]
    else:
        # Divide the list into two parts, and sort them recursively.
        mid = len(lst) // 2
        left_sorted = mergesort(lst[:mid])
        right_sorted = mergesort(lst[mid:])

        # Merge the two sorted halves.
        return _merge(left_sorted, right_sorted)
```
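To see how the divide step breaks a list down, the sketch below records every sublist that `mergesort` would recurse on, together with its recursion depth. The function `split_trace` is hypothetical. It does no sorting or merging; it only mirrors the `lst[:mid]` / `lst[mid:]` slicing in `mergesort`.

```python
def split_trace(lst: list, depth: int = 0) -> list[tuple[int, list]]:
    """Return (depth, sublist) pairs for each call mergesort would make on <lst>.

    Hypothetical helper: it mirrors mergesort's slicing but skips the merge.
    """
    trace = [(depth, lst)]
    if len(lst) >= 2:
        mid = len(lst) // 2
        trace.extend(split_trace(lst[:mid], depth + 1))
        trace.extend(split_trace(lst[mid:], depth + 1))
    return trace


# Print each sublist, indented by its recursion depth.
for depth, sub in split_trace([7, 2, 5, 3, 1]):
    print("    " * depth + str(sub))
```

For `[7, 2, 5, 3, 1]`, the first split produces `[7, 2]` and `[5, 3, 1]`. Splitting continues until every sublist has length 1, which is mergesort’s base case. All of the actual sorting work then happens as `_merge` combines these pieces back together.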
The merge operation
While this code looks very straightforward, we’ve hidden the main complexity in the helper function `_merge`, which needs to take two lists and combine them into one sorted list.
For two arbitrary lists, there isn’t an “efficient” way of combining them.[2] For example, the first element of the returned list should be the minimum value in either list, and to find this value we’d need to iterate through each element in both lists.
But if we assume that the two lists are sorted, this changes things dramatically.
For example, to find the minimum item of `lst1` and `lst2` when both lists are sorted, we only need to compare `lst1[0]` and `lst2[0]`, since the minimum must be one of these two values.
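Here is a minimal sketch of that observation, assuming both lists are sorted and non-empty. The name `min_of_sorted` is hypothetical. The point is that it makes a single comparison, whereas `min(lst1 + lst2)` would have to look at every element.

```python
from typing import Any


def min_of_sorted(lst1: list, lst2: list) -> Any:
    """Return the minimum item across <lst1> and <lst2>.

    Precondition (assumed): both lists are sorted and non-empty.
    Only the two front elements are compared.
    """
    if lst1[0] <= lst2[0]:
        return lst1[0]
    else:
        return lst2[0]


print(min_of_sorted([3, 8, 10], [1, 4, 20]))  # 1
```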
We can generalize this idea so that after every comparison, we add one new element to the merged list. This is the key insight that makes the `_merge` operation efficient, and it is what gives mergesort its name.
```python
def _merge(lst1: list, lst2: list) -> list:
    """Return a sorted list with the elements in <lst1> and <lst2>.

    Precondition: <lst1> and <lst2> are sorted.
    """
    index1 = 0
    index2 = 0
    merged = []
    while index1 < len(lst1) and index2 < len(lst2):
        if lst1[index1] <= lst2[index2]:
            merged.append(lst1[index1])
            index1 += 1
        else:
            merged.append(lst2[index2])
            index2 += 1

    # Now either index1 == len(lst1) or index2 == len(lst2).
    assert index1 == len(lst1) or index2 == len(lst2)
    # The remaining elements of the other list
    # can all be added to the end of <merged>.
    # Note that at most ONE of lst1[index1:] and lst2[index2:]
    # is non-empty, but to keep the code simple, we include both.
    return merged + lst1[index1:] + lst2[index2:]
```
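Python’s standard library provides a related operation. `heapq.merge` takes any number of sorted iterables and lazily yields their items in sorted order. It is not the same as the `_merge` function above, because it uses a heap internally so that it can handle more than two inputs. For two sorted lists, though, it should produce the same result, which makes it a handy way to check a hand-written merge.

```python
import heapq

lst1 = [1, 4, 7, 10]
lst2 = [2, 3, 8]

# heapq.merge returns an iterator, so wrap it in list() to see the result.
merged = list(heapq.merge(lst1, lst2))
print(merged)                         # [1, 2, 3, 4, 7, 8, 10]
print(merged == sorted(lst1 + lst2))  # True
```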
Quicksort
While quicksort also uses a divide-and-conquer approach, it divides up its input list in a different way. Here’s some intuition for this approach: suppose we’re sorting a group of people alphabetically by surname. We first divide the people into two groups: those whose surnames start with A-L, and those whose surnames start with M-Z. This can be seen as an “approximate sort”: even though neither smaller group is sorted, we do know that everyone in the A-L group should come before everyone in the M-Z group. Then after sorting each group separately, we’re done: we simply concatenate the two sorted groups to obtain a fully sorted list.
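The surname analogy can be written out directly. In this sketch the names are made up for illustration, and they are assumed to be capitalized. Python’s built-in `sorted` stands in for “sorting each group separately”. Once the names are split at the A-L / M-Z boundary, concatenating the two sorted groups gives a fully sorted list.

```python
surnames = ["Nguyen", "Ahmed", "Smith", "Lee", "Zhang", "Brown", "Martin"]

# Divide: an "approximate sort" into A-L and M-Z groups by first letter.
first_group = [name for name in surnames if name[0] <= "L"]
second_group = [name for name in surnames if name[0] > "L"]

# Sort each group separately, then concatenate the results.
result = sorted(first_group) + sorted(second_group)

print(result)
# ['Ahmed', 'Brown', 'Lee', 'Martin', 'Nguyen', 'Smith', 'Zhang']
```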
The formal quicksort algorithm uses exactly this idea:
1. First, it picks some element in its input list and calls it the pivot.
2. It then splits up the list into two parts: the elements less than or equal to the pivot, and those greater than the pivot.[3] This is traditionally called the partitioning step.
3. Next, it sorts each part recursively.
4. Finally, it concatenates the two sorted parts, putting the pivot in between them.
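To trace these four steps on a concrete list, the sketch below performs a single level of quicksort by hand. It picks the first element as the pivot, partitions the rest, and then uses Python’s built-in `sorted` as a stand-in for the two recursive calls. The input list is arbitrary, chosen for illustration.

```python
lst = [4, 7, 1, 4, 9, 2]

# Step 1: pick the first element as the pivot.
pivot = lst[0]                             # 4
rest = lst[1:]                             # [7, 1, 4, 9, 2]

# Step 2: partition the remaining elements around the pivot.
smaller = [x for x in rest if x <= pivot]  # [1, 4, 2]
bigger = [x for x in rest if x > pivot]    # [7, 9]

# Step 3: sort each part (sorted() stands in for the recursive calls).
smaller_sorted = sorted(smaller)           # [1, 2, 4]
bigger_sorted = sorted(bigger)             # [7, 9]

# Step 4: concatenate, with the pivot in between.
result = smaller_sorted + [pivot] + bigger_sorted
print(result)  # [1, 2, 4, 4, 7, 9]
```

The duplicate `4` ends up in `smaller` because the partition puts elements equal to the pivot on the “less than or equal” side.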
```python
def quicksort(lst: list) -> list:
    """Return a sorted list with the same elements as <lst>.

    This is a *non-mutating* version of quicksort; it does not mutate the
    input list.
    """
    if len(lst) < 2:
        return lst[:]
    else:
        # Pick pivot to be first element.
        # Could make lots of other choices here (e.g., last, random)
        pivot = lst[0]

        # Partition rest of list into two parts
        smaller, bigger = _partition(lst[1:], pivot)

        # Recurse on each partition
        smaller_sorted = quicksort(smaller)
        bigger_sorted = quicksort(bigger)

        # Return! Notice the simple combining step
        return smaller_sorted + [pivot] + bigger_sorted
```
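The comment in `quicksort` notes that the pivot does not have to be the first element. As one possible variation, here is a sketch that picks the pivot uniformly at random with `random.randrange`. The name `quicksort_random` is hypothetical. The partition is done inline with list comprehensions rather than with `_partition`, so that the block runs on its own.

```python
import random


def quicksort_random(lst: list) -> list:
    """Return a sorted list with the same elements as <lst>, without mutating it.

    Hypothetical variant of quicksort that chooses a random pivot.
    """
    if len(lst) < 2:
        return lst[:]
    else:
        # Choose a random index, and set that element aside as the pivot.
        pivot_index = random.randrange(len(lst))
        pivot = lst[pivot_index]
        rest = lst[:pivot_index] + lst[pivot_index + 1:]

        # Partition the remaining elements (items equal to the pivot go to <smaller>).
        smaller = [item for item in rest if item <= pivot]
        bigger = [item for item in rest if item > pivot]

        # Recurse on each part and combine, exactly as in quicksort.
        return quicksort_random(smaller) + [pivot] + quicksort_random(bigger)


print(quicksort_random([3, 1, 4, 1, 5, 9, 2, 6]))  # [1, 1, 2, 3, 4, 5, 6, 9]
```

Whichever pivot is chosen, the output is the same sorted list. The choice only affects how evenly the list is split at each level of the recursion.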
It turns out that implementing the `_partition` helper is simpler than implementing the `_merge` helper above: we can do it using just one loop through the list.
```python
from typing import Any


def _partition(lst: list, pivot: Any) -> tuple[list, list]:
    """Return a partition of <lst> with the chosen pivot.

    Return two lists, where the first contains the items in <lst>
    that are <= pivot, and the second is the items in <lst> that are > pivot.
    """
    smaller = []
    bigger = []

    for item in lst:
        if item <= pivot:
            smaller.append(item)
        else:
            bigger.append(item)

    return smaller, bigger
```