Friday, April 18, 2008

k nearest neighbors speed up

I was working on implementing the k nearest neighbors algorithm in Python and very quickly discovered that my code was slow. Really slow. Unacceptably slow. For a training set of 1000 items with 150 attributes and a k of 4 it took about 30 seconds to classify 1000 testing items on a fast machine. When I tried to run it on real data I actually had a system administrator kill my process because he thought it was in an infinite loop. Obviously, my code needed some work.

After profiling, I found that by far the most time was being spent finding the neighbors. Reading in the data and choosing the most likely classification once the neighbors were found took comparatively little time.

My original implementation was something like the following:

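```python
# seed with the first k items, sorted so the farthest neighbor is at index 0
neighbors = sorted([(distance(x, item), i) for i, x in enumerate(training[:k])], reverse=True)

# look for closer items among the rest of the training set
for i, x in enumerate(training[k:], start=k):
    dis = distance(x, item)
    if dis < neighbors[0][0]:
        # replace the current farthest neighbor, then restore the ordering
        neighbors[0] = (dis, i)
        neighbors.sort(reverse=True)
```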
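The loop above hand-maintains the k best candidates seen so far. For comparison, here is a minimal sketch of the same one-pass idea using Python's standard `heapq.nsmallest`, which keeps only the k smallest `(distance, index)` pairs while it scans. This is a swapped-in standard-library tool, not the code I originally used; `squared_distance` and `nearest_k_heap` are hypothetical names, and the items are assumed to be plain sequences of numbers.

```python
import heapq


def squared_distance(x, y):
    # squared Euclidean distance, written with zip instead of index arithmetic
    return sum((a - b) ** 2 for a, b in zip(x, y))


def nearest_k_heap(training, item, k):
    # heapq.nsmallest scans the generator once and retains only the k smallest pairs
    pairs = ((squared_distance(x, item), i) for i, x in enumerate(training))
    return [i for _, i in heapq.nsmallest(k, pairs)]


training = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0), (6.0, 8.0)]
print(nearest_k_heap(training, (0.2, 0.1), 2))  # [0, 2]
```

It is still a pure-Python loop over every training item, so it shares the same basic speed limit as the version above.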

I computed the distance manually as below:

```python
def distance(x, y):
    d = 0
    for i in range(0, len(x)):
        d += (x[i]-y[i])**2
    #d = sqrt(d) #can be cut for speed
    return d
```
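Cutting the square root is safe for finding neighbors because sqrt is monotonically increasing: if one squared distance is smaller than another, so is its square root, so the neighbors come out in the same order either way. A quick check of that claim on a few made-up points (the names here are illustrative only):

```python
import math


def squared_distance(x, y):
    # squared Euclidean distance, the same quantity distance() computes
    return sum((a - b) ** 2 for a, b in zip(x, y))


training = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0), (6.0, 8.0)]
item = (0.2, 0.1)

# rank training indices by squared distance and by true Euclidean distance
by_squared = sorted(range(len(training)), key=lambda i: squared_distance(training[i], item))
by_euclid = sorted(range(len(training)), key=lambda i: math.sqrt(squared_distance(training[i], item)))
print(by_squared, by_euclid)  # [0, 2, 1, 3] [0, 2, 1, 3]
```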

One of the tricks you'll hear again and again when optimising Python is to let C code (compiled extensions or built-in libraries) do as much of the work as possible. I installed NumPy, stored each training and testing item as a NumPy array, and came up with this:

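```python
from numpy import dot

def distance(x, y):
    # x and y are NumPy arrays; dot(d, d) is the sum of squared differences
    d = x-y
    return dot(d, d)
```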
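To see that the NumPy version computes the same quantity as the hand-written loop, here is a small side-by-side check. `distance_loop` and `distance_numpy` are illustrative names for the two versions, and the vectors are made up. The only real difference is that `x - y` and `dot` do their looping inside NumPy's compiled code rather than in the Python interpreter.

```python
import numpy as np


def distance_loop(x, y):
    # element-by-element loop in Python
    d = 0
    for i in range(len(x)):
        d += (x[i] - y[i]) ** 2
    return d


def distance_numpy(x, y):
    # vector subtraction and dot product run inside NumPy
    d = x - y
    return np.dot(d, d)


x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 6.0, 3.0])
print(distance_loop(x, y), distance_numpy(x, y))  # both 25.0
```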

The new version ran about twice as fast as the original, but it still wasn't fast enough to be really usable. I toyed with an alternate version that calculated the distance to every training item up front, sorted by distance, and pulled out the first k items. It ran in about the same time for smallish training sets, but got worse and worse as the training data grew. It was becoming obvious to me that I'd need to do some real problem solving.
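A minimal sketch of that alternate version, assuming the training items are NumPy arrays as above (`nearest_k_sorted` is a hypothetical name). It sorts all n indices even though only k are needed, which is O(n log n) work that grows with the training set.

```python
import numpy as np


def nearest_k_sorted(training, item, k):
    # compute the distance to every training item up front
    distances = [np.dot(x - item, x - item) for x in training]
    # sort all indices by distance, then keep the first k
    return sorted(range(len(training)), key=distances.__getitem__)[:k]


training = [np.array(p) for p in ([0.0, 0.0], [3.0, 4.0], [1.0, 1.0], [6.0, 8.0])]
print(nearest_k_sorted(training, np.array([0.2, 0.1]), 2))  # [0, 2]
```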

I spent a while looking at the NumPy documentation and reconsidering the way I thought about the data. I knew there had to be a way to do the calculation in one (hopefully fast) set of operations, rather than all this silly looping and sorting/comparing. Then it occurred to me that what I really had was not a list but a matrix of training data, where each row was an individual item. All I really wanted to do was subtract the testing item from each row, square every entry in the resulting matrix, and sum along each row. The result would be an array of distances between the item and every row in the matrix.
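A tiny worked example of that idea, with made-up numbers: NumPy broadcasting subtracts the one testing item from every row of the training matrix in a single operation, and `sum(axis=1)` adds up the entries of each row, leaving one squared distance per training item.

```python
import numpy as np

training = np.array([[0.0, 0.0],
                     [3.0, 4.0],
                     [1.0, 1.0]])   # one row per training item
item = np.array([0.0, 1.0])         # a single testing item

diff = training - item              # item is broadcast against every row: shape (3, 2)
squared = diff ** 2                 # element-wise square
distances = squared.sum(axis=1)     # sum along each row: shape (3,)
print(distances)                    # squared distances 1, 18 and 1
```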

I did just that. The two major perks were that it could be done mostly (if not entirely) in compiled NumPy code, and that NumPy had a quick argmin function that would give me the index of the nearest neighbor's row in the training matrix (which corresponds to the index in the list of classifications). What came out of that was the solution below:

```python
from numpy import sqrt, inf

# get array of distances between item and all others
# (training is a 2-D array with one row per item; item is broadcast across the rows)
distances = sqrt(((training-item)**2).sum(axis=1))

# get closest item to start with
closest = distances.argmin()
neighbors = (closest, )

# set current closest to infinity so it can't be picked again, and find the next
for i in range(1, k):
    distances[closest] = inf
    closest = distances.argmin()
    neighbors += (closest, )
```
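For reference, here is a minimal self-contained sketch of a full classifier built on the same matrix idea, including the majority vote over the neighbors' classifications. It swaps the repeated `argmin` loop for `numpy.argpartition`, which returns the indices of the k smallest distances (in no particular order) in a single call; that is a substitution, not the code above. `knn_classify` is a hypothetical name, `labels` is assumed to be a sequence of classifications aligned with the rows of `training`, and a tied vote goes to whichever label `Counter` saw first.

```python
from collections import Counter

import numpy as np


def knn_classify(training, labels, item, k):
    # squared distances from item to every row; the square root doesn't change the ranking
    distances = ((training - item) ** 2).sum(axis=1)
    # indices of the k smallest distances (unordered); requires k <= len(training)
    nearest = np.argpartition(distances, k - 1)[:k]
    # majority vote over the neighbors' classifications
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


training = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0]])
labels = ["a", "a", "a", "b", "b"]
print(knn_classify(training, labels, np.array([0.2, 0.2]), 3))  # a
print(knn_classify(training, labels, np.array([5.5, 5.0]), 2))  # b
```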

It now runs at twice the speed of the version that used the original algorithm with the optimized distance function. For those keeping track, that's four times faster than my first attempt. It's definitely not as fast as the same thing written in pretty much any compiled language, but it's snappy. I even added back the slow square root operation, because a single call applies it to every entry in the distance array, and there's still plenty of time to spare.
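If you want to measure the difference on your own data, the standard `timeit` module is the usual tool. Below is a minimal sketch that times a per-row Python loop against the single vectorized expression, on random data shaped like my training set (1000 items, 150 attributes). The function names are hypothetical, the data is random rather than real, and the timings it prints will depend on your machine.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
training = rng.random((1000, 150))  # random stand-in: 1000 items, 150 attributes
item = rng.random(150)


def loop_distances():
    # one dot product per training row, looped in Python
    return np.array([np.dot(x - item, x - item) for x in training])


def vectorized_distances():
    # one broadcast expression over the whole matrix
    return ((training - item) ** 2).sum(axis=1)


for f in (loop_distances, vectorized_distances):
    print(f.__name__, timeit.timeit(f, number=10), "seconds for 10 runs")
```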

Oddly enough, what I'm doing isn't really all that different from the original design. What made the difference was being open to viewing the data as something other than what I'd initially assumed it to be. It wasn't a list of training items to be compared to the testing item one at a time; it was a matrix of points in a multidimensional space. I imagine the distances are still calculated in some sort of loop, but that loop isn't in my own slow Python code. It's in the trustworthy hands of the people behind NumPy, because their tool is the best one for the job.

As for broader application, I think this is a lesson about search problems in general. Many searches can be sped up by exploiting mathematical properties of the data's domain or by using clever data structures. For extra credit, check out k-d trees.
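For the curious, here is a minimal teaching sketch of a k-d tree nearest-neighbor search in Python and NumPy. The tree splits the points at the median on one coordinate per level, and the search skips any subtree whose splitting plane is farther away than the current k-th best distance. The function names and the tuple node layout are assumptions of this sketch; for real work, SciPy's `scipy.spatial.cKDTree` (or `KDTree`) is a tested implementation. Note that k-d trees tend to help most in low dimensions and can lose much of their advantage with as many attributes as 150.

```python
import heapq

import numpy as np


def build_kdtree(points, indices=None, depth=0):
    # A node is (index_of_point, split_axis, left_subtree, right_subtree); None means empty.
    if indices is None:
        indices = np.arange(len(points))
    if len(indices) == 0:
        return None
    axis = depth % points.shape[1]
    # order this subset along the split axis and use the median point as the node
    order = indices[np.argsort(points[indices, axis])]
    mid = len(order) // 2
    return (
        order[mid],
        axis,
        build_kdtree(points, order[:mid], depth + 1),
        build_kdtree(points, order[mid + 1:], depth + 1),
    )


def kdtree_query(tree, points, item, k):
    # best holds (-squared_distance, index), so best[0] is the farthest of the k kept so far
    best = []

    def search(node):
        if node is None:
            return
        idx, axis, left, right = node
        d = float(np.sum((points[idx] - item) ** 2))
        if len(best) < k:
            heapq.heappush(best, (-d, idx))
        elif d < -best[0][0]:
            heapq.heapreplace(best, (-d, idx))
        diff = item[axis] - points[idx, axis]
        near, far = (left, right) if diff < 0 else (right, left)
        search(near)
        # the far side can only hold closer points if the split plane is within the k-th best distance
        if len(best) < k or diff ** 2 < -best[0][0]:
            search(far)

    search(tree)
    # return indices nearest first
    return [int(i) for _, i in sorted(best, reverse=True)]


rng = np.random.default_rng(1)
points = rng.random((200, 3))
tree = build_kdtree(points)
print(kdtree_query(tree, points, np.array([0.5, 0.5, 0.5]), 4))
```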