# Closest pair of points in Python (divide and conquer): the quick implementation

## Computing minimum distance between 2 points on a 2d plane

Given two lists holding the x and the respective y coordinates of a set of points, find the minimal distance between any pair of points.

Every battle with a hardcore algorithm should start somewhere. I suggest reading **Cormen et al., “Introduction to Algorithms”, 3rd edition (Section 33.4)**, but any decent book will do.

We start with a plain implementation of the divide-and-conquer approach to the closest pair of points problem.

Let us suppose that our inputs are two lists of size n, x and y, which correspond to the points (x1, y1), …, (xn, yn), where n is the number of points.
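Before the divide-and-conquer version, it helps to have a brute-force baseline that checks all n(n − 1)/2 pairs in O(n²) time; it doubles as a reference answer when testing the fast version. This is a minimal sketch: the name `naive_min_distance` is my own, and `math.dist` requires Python 3.8 or newer.

```python
import math
from itertools import combinations


def naive_min_distance(x, y):
    """Return the smallest Euclidean distance between any two points.

    x and y are equal-length lists; point k is (x[k], y[k]).
    Every unordered pair is checked, so this runs in O(n^2) time.
    """
    points = list(zip(x, y))
    if len(points) < 2:
        raise ValueError("need at least two points")
    return min(math.dist(p, q) for p, q in combinations(points, 2))


# Three points; the closest pair is (0, 0) and (3, 4)
print(naive_min_distance([0, 3, 10], [0, 4, 10]))  # 5.0
```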

First, let’s look at the following function:

```python
def solution(x, y):
    a = list(zip(x, y))  # This produces a list of tuples
    ax = sorted(a, key=lambda p: p[0])  # Presorting x-wise
    ay = sorted(a, key=lambda p: p[1])  # Presorting y-wise
    p1, p2, mi = closest_pair(ax, ay)  # Recursive D&C function
    return mi
```

Here we address the concept of presorting. As noted in the book:

> Note that in order to attain the $O(n \lg n)$ time bound, we cannot afford to sort in each recursive call; if we did, the recurrence for the running time would be $T(n) = 2T(n/2) + O(n \lg n)$, whose solution is $T(n) = O(n \lg^2 n)$.

Therefore, presorting once, outside the function that is called recursively, is what lets the whole solution run in $O(n \lg n)$ time.
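The reason presorting pays off is that splitting an already y-sorted list in a single linear pass keeps both halves y-sorted, so the recursion never has to sort again. A small sketch, assuming points are (x, y) tuples; the helper name `split_y_sorted` is hypothetical:

```python
def split_y_sorted(ay, left_points):
    """Split the y-sorted list ay into two lists in one O(n) pass.

    left_points is a set of the points that belong to the left half.
    Points are appended in the order they are read from ay,
    so each output list stays sorted by y.
    """
    left_y, right_y = [], []
    for p in ay:
        (left_y if p in left_points else right_y).append(p)
    return left_y, right_y


points = [(5, 1), (1, 9), (3, 4), (8, 2), (2, 7)]
ax = sorted(points, key=lambda p: p[0])
ay = sorted(points, key=lambda p: p[1])
mid = len(ax) // 2
left_y, right_y = split_y_sorted(ay, set(ax[:mid]))
print(left_y)   # [(2, 7), (1, 9)]
print(right_y)  # [(5, 1), (8, 2), (3, 4)]
```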

Let’s look at the recursive function (with the appropriate comments):

```python
def closest_pair(ax, ay):
    ln_ax = len(ax)  # It's quicker to store the length once
    if ln_ax <= 3:
        return brute(ax)  # A call to bruteforce comparison
    mid = ln_ax // 2  # Division without remainder, need int
    Qx = ax[:mid]  # Two-part split
    Rx = ax[mid:]
    # Determine midpoint on x-axis: x of the last point in Qx, so that
    # ax[mid] (which belongs to Rx) is not sent to Qy
    # (assumes distinct x-coordinates)
    midpoint = ax[mid - 1][0]
    Qy = list()
    Ry = list()
    for x in ay:  # split ay into 2 arrays using midpoint
        if x[0] <= midpoint:
            Qy.append(x)
        else:
            Ry.append(x)
    # Call recursively both arrays after split
    (p1, q1, mi1) = closest_pair(Qx, Qy)
    (p2, q2, mi2) = closest_pair(Rx, Ry)
    # Determine smaller distance between points of 2 arrays
    if mi1 <= mi2:
        d = mi1
        mn = (p1, q1)
    else:
        d = mi2
        mn = (p2, q2)
    # Call function to account for points on the boundary
    (p3, q3, mi3) = closest_split_pair(ax, ay, d, mn)
    # Determine smallest distance for the array
    if d <= mi3:
        return mn[0], mn[1], d
    else:
        return p3, q3, mi3
```
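The recursion relies on Qy holding exactly the points of Qx (and Ry those of Rx), so a cheap check during debugging can catch a split that puts a point on the wrong side. A sketch with a hypothetical helper `check_split`, assuming (x, y) tuples with distinct x-coordinates; it shows that comparing `x[0] <= ax[mid][0]` sends ax[mid] into Qy even though ax[mid] sits in Rx, while using the x of the last point in Qx keeps the halves consistent:

```python
def check_split(Qx, Rx, Qy, Ry):
    """True if the y-sorted halves hold the same points as the x-sorted halves."""
    return sorted(Qx) == sorted(Qy) and sorted(Rx) == sorted(Ry)


points = [(1, 9), (2, 7), (3, 4), (5, 1), (8, 2)]
ax = sorted(points)  # tuples sort by x first
ay = sorted(points, key=lambda p: p[1])
mid = len(ax) // 2
Qx, Rx = ax[:mid], ax[mid:]

# Threshold = x of the last point in Qx, so ax[mid] goes to the right half
threshold = ax[mid - 1][0]
Qy = [p for p in ay if p[0] <= threshold]
Ry = [p for p in ay if p[0] > threshold]
print(check_split(Qx, Rx, Qy, Ry))  # True

# Threshold = x of ax[mid]: that point lands in Qy although it is in Rx
threshold = ax[mid][0]
Qy = [p for p in ay if p[0] <= threshold]
Ry = [p for p in ay if p[0] > threshold]
print(check_split(Qx, Rx, Qy, Ry))  # False
```

Points that share the split x-coordinate would need extra care; the set-membership split discussed next sidesteps that as long as all points are distinct.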

The implementation above follows the book. However, while debugging the algorithm, I found a peculiar feature. If we replace the midpoint split logic with:

```python
qx = set(Qx)
Qy = list()
Ry = list()
for x in ay:
    if x in qx:
        Qy.append(x)
    else:
        Ry.append(x)
```

the code actually runs a little faster. I won’t dive into the low-level details, though a curious reader should compare the speed of the comparison `x[0] <= midpoint` with `x in qx` for a set. That’s the only reason I can think of.
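To check the speed claim on your own machine, `timeit` makes the comparison easy. The sketch below times both split strategies on random points; the sizes, repetition count and helper names are my own choices, and which strategy wins depends on your Python version and hardware, so no numbers are given. The set is built inside `split_by_membership` on every call, as in the variant above.

```python
import random
import timeit


def split_by_threshold(ay, midpoint):
    """Split y-sorted points by comparing each x-coordinate with midpoint."""
    Qy, Ry = [], []
    for p in ay:
        if p[0] <= midpoint:
            Qy.append(p)
        else:
            Ry.append(p)
    return Qy, Ry


def split_by_membership(ay, Qx):
    """Split y-sorted points by membership in the left half Qx."""
    qx = set(Qx)  # built on every call, as in the variant above
    Qy, Ry = [], []
    for p in ay:
        if p in qx:
            Qy.append(p)
        else:
            Ry.append(p)
    return Qy, Ry


# Distinct x-coordinates, so both strategies produce the same split
xs = random.sample(range(-10**9, 10**9), 10000)
points = [(x, random.randint(-10**9, 10**9)) for x in xs]
ax = sorted(points)
ay = sorted(points, key=lambda p: p[1])
mid = len(ax) // 2
midpoint = ax[mid - 1][0]  # x of the last point in the left half

assert split_by_threshold(ay, midpoint) == split_by_membership(ay, ax[:mid])
t1 = timeit.timeit(lambda: split_by_threshold(ay, midpoint), number=100)
t2 = timeit.timeit(lambda: split_by_membership(ay, ax[:mid]), number=100)
print(f"x[0] <= midpoint: {t1:.3f} s, x in qx: {t2:.3f} s")
```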

Now, this is where it gets interesting. First, the *brute(ax)* function:

```python
def brute(ax):
    mi = dist(ax[0], ax[1])
    p1 = ax[0]
    p2 = ax[1]
    ln_ax = len(ax)
    if ln_ax == 2:
        return p1, p2, mi
    for i in range(ln_ax - 1):
        for j in range(i + 1, ln_ax):
            # Skip only the pair (0, 1), already measured above;
            # "i != 0 and j != 1" would also wrongly skip (0, 2)
            if i != 0 or j != 1:
                d = dist(ax[i], ax[j])
                if d < mi:  # Update min_dist and points
                    mi = d
                    p1, p2 = ax[i], ax[j]
    return p1, p2, mi
```

Let us discuss it briefly. Why initialize mi with the distance between the first two points in the list, rather than with some arbitrary large number? It saves us a computation on each of the many calls to the *brute* function. That’s a win. Furthermore, if len(ax) == 2, we’re done and the result can be returned immediately.

The second important point concerns the ranges of our two loops, which are needed in the case of 3 points (recall that *brute* is called only if len(ax) ≤ 3). Why don’t we need to iterate over all len(ax) points for the i index? Because we compare two points, ax[i] and ax[j], and j runs from i + 1 to len(ax) − 1, so the last point is still reached as ax[j]. Furthermore, these bounds mean that we never compare a pair twice or a point with itself: dist(ax[0], ax[2]) and dist(ax[2], ax[0]) are never both computed, and dist(ax[1], ax[1]) never occurs. This cuts the number of distance computations by more than half compared with two full loops over len(ax).

Back to our first point: the if condition inside the loops skips the pair (ax[0], ax[1]), whose distance was already computed before the loops, which saves one extra distance computation.
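The pair-counting argument is easy to check directly. A quick sketch with hypothetical helper names: two full loops visit n² ordered index pairs, the i < j loops visit n(n − 1)/2 unordered pairs, and skipping only (0, 1) requires `i != 0 or j != 1` (writing `and` would also drop (0, 2)).

```python
def full_loop_pairs(n):
    """Index pairs visited by two full loops over range(n)."""
    return [(i, j) for i in range(n) for j in range(n)]


def triangular_pairs(n):
    """Index pairs visited by the i < j loops used in brute."""
    return [(i, j) for i in range(n - 1) for j in range(i + 1, n)]


for n in (2, 3):
    print(n, len(full_loop_pairs(n)), len(triangular_pairs(n)))
# 2 4 1
# 3 9 3

# Pairs left to compute once (0, 1) has been measured up front
remaining = [(i, j) for i, j in triangular_pairs(3) if i != 0 or j != 1]
print(remaining)  # [(0, 2), (1, 2)]
```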

The distance function (*dist*) is nothing special:

```python
import math


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
```
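The standard library also provides `math.hypot` and, since Python 3.8, `math.dist`, which compute the same Euclidean distance. A sketch showing that they agree with the hand-written version; whether either is faster on your setup is something to measure rather than assume.

```python
import math


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def dist_hypot(p1, p2):
    # math.hypot(dx, dy) returns sqrt(dx*dx + dy*dy)
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


def dist_builtin(p1, p2):
    # math.dist (Python 3.8+) takes two points of the same dimension
    return math.dist(p1, p2)


p, q = (0, 0), (3, 4)
print(dist(p, q), dist_hypot(p, q), dist_builtin(p, q))  # 5.0 5.0 5.0
```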

Finally, one of the most interesting pieces: *closest_split_pair*, the function responsible for finding the closest pair of points that straddles the split line.

```python
def closest_split_pair(p_x, p_y, delta, best_pair):
    ln_x = len(p_x)  # store length - quicker
    mx_x = p_x[ln_x // 2][0]  # select midpoint on x-sorted array
    # Create a subarray of points not further than delta from
    # midpoint on x-sorted array
    s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
    best = delta  # assign best value to delta
    ln_y = len(s_y)  # store length of subarray for quickness
    for i in range(ln_y - 1):
        # compare s_y[i] with at most the 7 points that follow it
        for j in range(i + 1, min(i + 8, ln_y)):
            p, q = s_y[i], s_y[j]
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair[0], best_pair[1], best
```

Again, the key lies in the ranges of the two loops. They follow the same idea as in the *brute* function, with one important distinction: the upper bound on the j index is min(i + 8, ln_y), so each point is compared with at most the seven points that follow it in s_y, for reasons discussed in the **Correctness** part of Section 33.4 of Cormen et al.

**In short: it is enough to check only the seven points following each point in the s_y subarray. You should really read through the proof of correctness, because it explains this ‘trick’, which gives a large speed-up, much better than I can.**
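The seven-neighbour bound can also be checked empirically. The sketch below rebuilds only the strip step on small random inputs, computes δ for each half by brute force, and compares the combined result with an exhaustive search. `min_pair`, `strip_min` and the `window` parameter are my own names, and a passing run is evidence, not a proof.

```python
import math
import random
from itertools import combinations


def min_pair(points):
    """Exhaustive minimum distance over all pairs."""
    return min(math.dist(p, q) for p, q in combinations(points, 2))


def strip_min(ax, delta, window=7):
    """Best distance across the split line, comparing each strip point
    with at most `window` following points in y-order."""
    mx = ax[len(ax) // 2][0]
    s_y = sorted((p for p in ax if abs(p[0] - mx) <= delta), key=lambda p: p[1])
    best = delta
    for i in range(len(s_y) - 1):
        for j in range(i + 1, min(i + 1 + window, len(s_y))):
            best = min(best, math.dist(s_y[i], s_y[j]))
    return best


random.seed(0)
for _ in range(200):
    xs = random.sample(range(1000), 20)  # distinct x-coordinates
    points = [(x, random.randrange(1000)) for x in xs]
    ax = sorted(points)
    mid = len(ax) // 2
    delta = min(min_pair(ax[:mid]), min_pair(ax[mid:]))
    assert strip_min(ax, delta) == min_pair(ax)
print("strip check agrees with exhaustive search")
```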

**P.S.: tips on debugging and testing**

Unit tests are mandatory. The *PyCharm* IDE is recommended (Ctrl + Shift + T creates a unit test for a method). Additional reading on stress testing is also advised.
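A stress test compares a fast solution with a brute-force one on many small random inputs and reports the first disagreement. The sketch below is self-contained, so it carries its own compact divide-and-conquer stand-in (`dc_min_distance`, written for this example, not the code from this post) and an O(n²) reference; to test the article’s code, pass `solution` as `fast`. Points are drawn without repetition from a small grid so shared x- and y-coordinates occur often; duplicate points are excluded because the set-based split assumes distinct points.

```python
import math
import random
from itertools import combinations


def naive(x, y):
    pts = list(zip(x, y))
    return min(math.dist(p, q) for p, q in combinations(pts, 2))


def dc_min_distance(x, y):
    """Compact divide-and-conquer stand-in used to exercise the harness."""
    def rec(ax, ay):
        if len(ax) <= 3:
            return min(math.dist(p, q) for p, q in combinations(ax, 2))
        mid = len(ax) // 2
        left = set(ax[:mid])
        qy = [p for p in ay if p in left]
        ry = [p for p in ay if p not in left]
        d = min(rec(ax[:mid], qy), rec(ax[mid:], ry))
        mx = ax[mid][0]
        strip = [p for p in ay if abs(p[0] - mx) <= d]
        for i in range(len(strip) - 1):
            for j in range(i + 1, min(i + 8, len(strip))):
                d = min(d, math.dist(strip[i], strip[j]))
        return d

    pts = list(zip(x, y))
    return rec(sorted(pts), sorted(pts, key=lambda p: p[1]))


def stress_test(fast, slow, trials=500, max_n=40, seed=1):
    """Return the first (x, y) input where fast and slow disagree, else None."""
    rng = random.Random(seed)
    grid = [(a, b) for a in range(30) for b in range(30)]
    for _ in range(trials):
        pts = rng.sample(grid, rng.randint(2, max_n))  # distinct points
        x, y = [p[0] for p in pts], [p[1] for p in pts]
        if not math.isclose(fast(x, y), slow(x, y)):
            return x, y
    return None


print(stress_test(dc_min_distance, naive))  # None when every trial agrees
```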

I used the following code to generate a large random test case:

```python
import random


def test_case(length: int = 10000):
    lst1 = [random.randint(-10**9, 10**9) for i in range(length)]
    lst2 = [random.randint(-10**9, 10**9) for i in range(length)]
    return lst1, lst2
```
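Wall-clock timings like the ones reported in this post can be collected with `time.perf_counter`. A sketch, where `time_on_random_input` is a hypothetical helper and `func` stands for any function with the `solution(x, y)` signature; the demo passes a trivial stand-in so the block runs on its own, and no timing figures are implied.

```python
import random
import time


def time_on_random_input(func, length=10000, repeats=3):
    """Run func(x, y) on fresh random inputs; return the best time in seconds."""
    best = float("inf")
    for _ in range(repeats):
        x = [random.randint(-10**9, 10**9) for _ in range(length)]
        y = [random.randint(-10**9, 10**9) for _ in range(length)]
        start = time.perf_counter()
        func(x, y)
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in with the solution(x, y) signature; pass the real solution instead
def sort_only(x, y):
    return sorted(zip(x, y))


print(f"{time_on_random_input(sort_only):.4f} s")
```

Taking the best of several runs reduces noise from other processes; averaging is the other common choice.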

It initially took about 40 seconds to run on my Intel i3 (2 cores, 4 threads, ~2.3 GHz) with 8 GB RAM and an SSD (~450 MB/s read/write); the runtime dropped to about 20–30 seconds after the optimizations I mentioned.

Another great debugging tool was my friend’s library of convenient timers (which I posted to my GitHub after some changes).

It let me time functions with convenient wrappers, and usage examples are built into the code.
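That library isn’t reproduced here, but the idea of a timing wrapper fits in a few lines. The decorator below is a hypothetical stand-in, not that library’s API: it records each call’s runtime under the function’s name.

```python
import functools
import time
from collections import defaultdict

# function name -> list of call durations in seconds
RUNTIMES = defaultdict(list)


def timed(func):
    """Decorator that records the wall-clock duration of every call to func."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            RUNTIMES[func.__name__].append(time.perf_counter() - start)
    return wrapper


@timed
def dist(p1, p2):
    return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5


dist((0, 0), (3, 4))
dist((1, 1), (4, 5))
print(len(RUNTIMES["dist"]))  # 2
```

Decorating a recursive function such as `closest_pair` records every recursive call, and each recorded time includes the time spent in its sub-calls.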

I wrapped the functions described above, ran the test case and saved the printed runtimes to a JSON file. Later I loaded the results into an SQLite database and used aggregate functions to get the average runtime of each function. I repeated the procedure after adding optimizations and compared the percentage change between the average runtimes, to see whether an optimization improved the runtime of a specific function (the overall runtime can be compared simply by running the unit test example above). I designed this procedure for an in-depth understanding of the results; it is not necessary for general debugging.
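The JSON-to-SQLite step can be sketched with the standard `json` and `sqlite3` modules. The record layout (`run`, `function`, `seconds`) and the sample values are made up purely to show the query; they are not measurements.

```python
import json
import sqlite3

# Hypothetical timing log: one record per call, tagged with its run
log = json.loads("""[
  {"run": "before", "function": "brute", "seconds": 0.004},
  {"run": "before", "function": "brute", "seconds": 0.006},
  {"run": "after",  "function": "brute", "seconds": 0.003},
  {"run": "after",  "function": "brute", "seconds": 0.005}
]""")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE timings (run TEXT, function TEXT, seconds REAL)")
conn.executemany("INSERT INTO timings VALUES (:run, :function, :seconds)", log)

# Average runtime per function for each run, then % change before -> after
query = """
SELECT b.function,
       b.avg_s AS before_avg,
       a.avg_s AS after_avg,
       100.0 * (a.avg_s - b.avg_s) / b.avg_s AS pct_change
FROM (SELECT function, AVG(seconds) AS avg_s FROM timings
      WHERE run = 'before' GROUP BY function) AS b
JOIN (SELECT function, AVG(seconds) AS avg_s FROM timings
      WHERE run = 'after' GROUP BY function) AS a
  ON a.function = b.function
"""
for row in conn.execute(query):
    print(row)
```

A negative `pct_change` means the function got faster after the optimization.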

Good luck and contact me for extra details on the algorithm or for other suggestions: *andriy.lazorenko@gmail.com*

P.S.: this story is part of my series on algorithmic challenges. Check out other cool algorithms decomposed with tests and Jupyter notebooks!