Saturday, February 27, 2016

Locality Sensitive Hashing in Tweets


Spark's DIMSUM algorithm might have a limit to how much data it can process, but the code for it came from Twitter, who process gazillions of tweets every day. So how do you apply it to huge amounts of data?

Well, one way is to break the data down into smaller chunks, and one way to do that without losing its meaningfulness is to use locality sensitive hashing (LSH). The best (but rather verbose) description I've found of LSH is here, a chapter from a free book on data mining.

(Note: Ullman uses the term "characteristic matrix" in the document differently from how a mathematician might use it. A characteristic matrix in maths is λI - A, where A is an n x n matrix, I is the n x n identity matrix and λ is an eigenvalue. Solving the characteristic equation, |λI - A| = 0, gives you at most n real roots, the eigenvalues, and allows you to calculate the eigenvectors, v, by substituting each discovered eigenvalue into the homogeneous system of equations, (λI - A)v = 0.)

To summarise, the key points follow.

Jaccard Similarity

First, we define something called the Jaccard similarity between two sets. This is simply:

s  =  (number of elements shared by the two sets) / (number of elements in their union)

Next, we convert the two sets into a matrix where the columns represent the sets and the rows represent the elements that appear in them. The cell value where a row and column meet simply records whether that element is in that set (a value of 1) or not (a value of 0). This is the matrix Ullman calls the characteristic matrix.
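To make this concrete, here is a minimal sketch in Python (the toy sets are mine, purely for illustration):

import numpy as np

def jaccard(a, b):
    # Jaccard similarity: |intersection| / |union|
    return len(a & b) / len(a | b)

s1 = {"the", "cat", "sat"}
s2 = {"the", "cat", "ran"}
print(jaccard(s1, s2))  # 2 shared / 4 in the union = 0.5

# Characteristic matrix: one row per element, one column per set;
# a cell is 1 if the element is in the set, 0 otherwise.
elements = sorted(s1 | s2)
matrix = np.array([[int(e in s) for s in (s1, s2)] for e in elements])
print(matrix)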

Minhashing

Next, we define a function that gives us the index of the first row in a given column that has the value 1 (going from top to bottom). We call this minhashing. Typically, we will have randomly reordered the rows of the original matrix first.
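A minimal sketch, reusing the toy characteristic matrix from above (the permutation is random, so your output may differ):

import numpy as np

matrix = np.array([[1, 1],   # the characteristic matrix from
                   [0, 1],   # the previous snippet
                   [1, 0],
                   [1, 1]])

rng = np.random.default_rng(42)
permuted = matrix[rng.permutation(matrix.shape[0])]
# the minhash of each column is the index of its first 1;
# argmax returns the first occurrence of the maximum (here, 1)
print(permuted.argmax(axis=0))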

The Minhash and Jaccard relationship

The chance that two columns have the same minhash value is actually the Jaccard similarity of their sets. The proof of this follows.

Given two columns, say they both have a 1 in x rows; only one of them has a 1 in y rows; and we ignore the rest (rows where both are 0 cannot affect the minhash). If we pick any of those x + y rows at random, then the probability that both sets have a 1 is unsurprisingly:

x / (x + y)

which is exactly the Jaccard similarity, s. Since a random permutation makes each of those x + y rows equally likely to appear first, this is also the probability that the two columns share a minhash value.

Minhash Signatures

Now, let's make another matrix with a column per set (as before), but this time each row represents one of n random hash functions. Each function simply takes a row index from the original matrix and returns a deterministic value. The cell values all start at positive infinity, but that will change.

For each function and each row of the original matrix, we use the row's index to generate a hash value. Then, for each set whose column has a non-zero value in that row, we place the function's value in the corresponding cell of our new matrix if it is less than what is already there.

If there are m sets, this results in an n x m matrix that is our (much reduced) signature.
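A minimal sketch of that algorithm (hash functions of the form (a*r + b) mod p are my choice for illustration, not prescribed by the text):

import numpy as np

matrix = np.array([[1, 1],   # the toy characteristic matrix again
                   [0, 1],
                   [1, 0],
                   [1, 1]])

n, prime = 4, 7  # n hash functions; prime should exceed the row count
rng = np.random.default_rng(0)
coeffs = rng.integers(1, prime, size=(n, 2))  # (a, b) per function

rows, m = matrix.shape
signature = np.full((n, m), np.inf)  # every cell starts at infinity

for r in range(rows):
    hashes = (coeffs[:, 0] * r + coeffs[:, 1]) % prime  # h_i(r)
    for col in range(m):
        if matrix[r, col] == 1:
            # keep the smallest hash value seen so far for this set
            signature[:, col] = np.minimum(signature[:, col], hashes)

print(signature)  # the n x m signature matrix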

Banding

Now let's break our signature matrix into b bands of r rows each.

The probability that all r rows within a band have the same value is s^r.

The probability that at least one row is not the same is (1 - s^r).

The probability that all b bands each have at least one row that is not the same is (1 - s^r)^b.

Therefore, the probability that at least one of the b bands has all r rows the same is 1 - (1 - s^r)^b.

This is the probability that we flag a pair of sets as candidates for similarity.

The nice thing about this is that it describes an S-shaped curve no matter what the values of s, r and b. This is good because the sharp threshold means any given pair of sets tends to be either clearly probable or clearly improbable as a candidate.
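Here is a minimal sketch of the banding step itself, assuming a toy signature matrix (columns that land in the same bucket for at least one band become candidate pairs):

from collections import defaultdict
from itertools import combinations

import numpy as np

signature = np.array([[1, 1],   # a toy n x m signature matrix;
                      [3, 3],   # in practice n = b * r is much larger
                      [0, 2],
                      [4, 4]])

b, r = 2, 2  # b bands of r rows each, so n = b * r
candidates = set()
for band in range(b):
    buckets = defaultdict(list)
    for col in range(signature.shape[1]):
        # the band's slice of this column acts as the bucket key
        key = tuple(signature[band * r:(band + 1) * r, col])
        buckets[key].append(col)
    for cols in buckets.values():
        candidates.update(combinations(cols, 2))

print(candidates)  # pairs of columns that agree in at least one band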

For b = 10, the probabilities for r = [2, 4, 6, 8] look like this:

[3-D plot of p = 1 - (1 - s^r)^b against s for each r, generated by the code at the end of this post.]

Distance functions

There are many types of distance measure. Euclidean distance is the one we're taught most in school but there are others. The only rule is that for any distance function, d, that gives the distance between x and y, the following must hold:
d(x, y) >= 0 
d(x, y) = 0 iff x == y 
d(x, y) = d(y, x) [they're symmetric] 
d(x, y) <= d(x, z) + d(z, y)  [the triangle inequality]
Nothing too clever there. They're just common sense if you think about navigating around town. The thing to bear in mind is that they apply to more than just the 3D space in which we live.
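For example, 1 minus the Jaccard similarity is a distance function (the Jaccard distance). A quick spot check of the rules on toy sets (a sanity check, not a proof):

def jaccard_distance(a, b):
    # Jaccard distance: 1 - |intersection| / |union|
    return 1 - len(a & b) / len(a | b)

x, y, z = {"a", "b"}, {"b", "c"}, {"a", "b", "c"}

assert jaccard_distance(x, y) >= 0
assert jaccard_distance(x, x) == 0
assert jaccard_distance(x, y) == jaccard_distance(y, x)
# triangle inequality: 2/3 <= 1/3 + 1/3 here
assert jaccard_distance(x, y) <= jaccard_distance(x, z) + jaccard_distance(z, y)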

Family of functions

We define a family of functions as a collection of functions that say "yes, make x and y a candidate pair" or "no, do not make x and y a candidate pair unless some other function concludes we should do so". Minhashing is one such family.

Furthermore, we describe such a family as (d1, d2, p1, p2)-sensitive to summarize the curve it generates, where
d1 < d2
p1 > p2
and d1 and d2 are values of a given distance function d(x, y); p1 is the minimum probability that a function in the family makes x and y a candidate pair when d(x, y) <= d1, and p2 is the maximum probability that it does so when d(x, y) >= d2.

This gives us a nice short-hand way to describe a complicated graph.
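For example, since the probability that two sets share a minhash value is their Jaccard similarity, 1 - d(x, y), the minhash family is (d1, d2, 1 - d1, 1 - d2)-sensitive for the Jaccard distance. The thresholds below are my own, for illustration:

# pick any two distance thresholds with d1 < d2
d1, d2 = 0.3, 0.6

# minhash agreement probability is 1 - d(x, y), so:
p1 = 1 - d1  # functions agree with probability >= 0.7 when d <= 0.3
p2 = 1 - d2  # functions agree with probability <= 0.4 when d >= 0.6

print(f"({d1}, {d2}, {p1}, {p2})-sensitive")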

Amplifying the effect

It is desirable to maximize the steepness of the S-curve to reduce the number of false negatives and false positives. We can do this by applying a second family of functions to the first. We could demand that all the bands agree for two points (x, y): that is an AND operator. Or we could demand that at least one of the bands agrees: that is an OR operator. Each changes the curve in a different way, depending on what we want.

This is analogous to the banding (see above) and the maths is the same, only this time, instead of s, we're dealing with p, the probability that two functions give the same answer for a row.
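A minimal sketch of the two constructions (an r-way AND followed by a b-way OR reproduces the banding formula from earlier):

def p_and(p, r):
    # all r functions must agree
    return p ** r

def p_or(p, b):
    # at least one of b functions agrees
    return 1 - (1 - p) ** b

# composing them gives the banding S-curve from earlier
for p in (0.2, 0.5, 0.8):
    print(p, p_or(p_and(p, r=6), b=10))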

The downside is that we increase the demand on our CPUs. With large datasets, a small increase in CPU cost is magnified in proportion to the amount of data we have to run the algorithm over, so it's worth bearing this in mind.

Code

I stole the Python code from here to generate the graphs and slightly modified it as below:

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import PolyCollection
from matplotlib.colors import to_rgba
import matplotlib.pyplot as plt
import numpy as np

# from http://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html#line-plots

fig = plt.figure()
ax = fig.add_subplot(projection='3d')


def cc(arg):
    # semi-transparent fill colour for each curve's polygon
    return to_rgba(arg, alpha=0.6)


def p(s, r):
    # probability that at least one of b bands has all r rows the same
    b = 10
    return 1 - ((1 - s**r) ** b)


xs = np.arange(0, 1, 0.04)
verts = []
zs = [2.0, 4.0, 6.0, 8.0]
for z in zs:
    ys = [p(x, z) for x in xs]  # a list, not map(): map() is lazy in Python 3
    ys[0], ys[-1] = 0, 0        # pin the polygon's ends to the floor
    verts.append(list(zip(xs, ys)))

poly = PolyCollection(verts, facecolors=[cc('r'), cc('g'), cc('b'),
                                         cc('y')])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')

ax.set_xlabel('s')
ax.set_xlim3d(0, 1)
ax.set_ylabel('r')
ax.set_ylim3d(min(zs), max(zs))
ax.set_zlabel('p')
ax.set_zlim3d(0, 1)

plt.show()

