A nice overview, detailed but not too intricate, can be found in the blog of SteelPh0enix (AKA Wojciech Olech).
Note that when using a fully trained LLM, things are conceptually much simpler because the model is more or less just a feedforward network. That is, the weights are immutable. State lives outside of the ANN and is updated by the output after each token runs through the feedforward network.
Attention!
The attention mechanism of an RNN takes the encodings and, for each token, augments the input vector in both directions, forwards and backwards. "The rationale behind this is to capture additional information since current inputs may have a dependence on sequence elements that came either before or after it in a sentence, or both." [1] The two vectors are then concatenated to make one long vector. "We can consider this concatenated hidden state as the annotation of the source word since it contains the information of the jth word in both directions." [1]
The self-attention mechanism has three vectors for each word: query, key and value. It compares the query of each word to the key of the others. This process is done in parallel ("multi-head attention") using different Q/K/V weights, and the results are combined.
The transformer architecture has superseded RNNs.
The cat sat on the mat
Conceptually, q represents the query (eg, the word "sat" is looking for something in the sentence to do the sitting); k is the key, where the word "cat" is saying "I am a noun that can sit"; and v links the verb looking for a noun and the noun looking for a verb.
This is similar to when we apply singular value decomposition to a document/term space and create a concept space.
Basic self attention
Imagine T input vectors x(i).
Tokens are embedded in a space of size d.
The T output vectors of self-attention are vectors z(i).
These vectors are calculated thus:
z(i) = Σj=1..T αij x(j)
where α is the matrix of the dot products of all the x vectors with each other, with softmax applied row-wise (remember, softmax preserves the ordering of the logits while converting them into probabilities that sum to 1).
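The formula above can be sketched in a few lines of NumPy. This is a minimal illustration, not production code; the function names and the toy 3-token input are my own.

```python
import numpy as np

def softmax(v):
    # Subtract the max for numerical stability; the result is unchanged.
    e = np.exp(v - v.max())
    return e / e.sum()

def basic_self_attention(X):
    """X is a (T, d) matrix of T token embeddings of dimension d."""
    T = X.shape[0]
    Z = np.zeros_like(X, dtype=float)
    for i in range(T):
        # alpha_i: softmax over the dot products of x(i) with every x(j)
        alpha = softmax(X @ X[i])
        # z(i) is the alpha-weighted sum of all input vectors x(j)
        Z[i] = alpha @ X
    return Z

# Three toy tokens embedded in a 2-dimensional space
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Z = basic_self_attention(X)
```

Note that there is nothing learnable here: the weights α come purely from the dot products of the inputs with each other.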
Multi-headed attention is the same algorithm calculated for multiple heads (that is, multiple swimlanes of data that represent nuances in language structure).
Self attention with learnable parameters
Here we project each x vector onto Uq, Uk and Uv. Note that Q = xUq etc., and that Uq etc. are fixed for a feedforward network. That is, somebody has done the hard work of calculating them during training.
The projections onto q and k are then multiplied together and the result is put into a matrix ωij where i is a token and j any other token.
The matrix ω is divided by √dk (where dk is the dimension of the key vectors) and softmaxed row-wise.
That is:
Attention(Q, K, V) = softmax ( Q KT / √dk ) V
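The steps above (project, multiply q and k, scale by √dk, softmax, weight v) can be sketched as follows. This is a minimal single-head sketch; the random projection matrices stand in for weights that training would have produced.

```python
import numpy as np

def softmax_rows(M):
    # Row-wise softmax with max subtraction for numerical stability.
    e = np.exp(M - M.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(X, Uq, Uk, Uv):
    """X: (T, d) token embeddings; Uq/Uk/Uv: (d, d_k) learned projections."""
    Q, K, V = X @ Uq, X @ Uk, X @ Uv
    d_k = K.shape[-1]
    # omega_ij = q(i) . k(j), scaled by sqrt(d_k), then softmaxed per row
    omega = softmax_rows(Q @ K.T / np.sqrt(d_k))
    return omega @ V

rng = np.random.default_rng(0)
T, d, d_k = 6, 8, 4
X = rng.normal(size=(T, d))
Uq, Uk, Uv = (rng.normal(size=(d, d_k)) for _ in range(3))
out = scaled_dot_product_attention(X, Uq, Uk, Uv)  # shape (T, d_k)
```

Multi-head attention would simply run this function once per head with different Uq/Uk/Uv and concatenate the results.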
Tiling
Tiling is a technique for performing matrix operations when the data won't all fit into fast memory at once.
"With naïve algorithm, to compute each element of the result, we gonna need to fetch S elements from both matrices. The output matrix has S2 elements, therefore the total count of loaded elements is 2S3." [Penny Xu's blog]
This breaks down as 2 [vectors, one from each matrix] × S [the size of those vectors] × S2 [the number of elements that are the result of all these dot products, that is, the size of the output matrix].
"With 32×32 tiling, to compute each 32×32 tile of the result, we gonna need to fetch S/32 tiles from both matrices. The output size in tiles is (S/32)2, the total count of loaded [elements] is 2*(S/32)3. Each 32×32 tile contains 322 elements, the total count of loaded elements is therefore (322)*2*(S/32)3 = (2/32)*S3. Therefore, the tiling reduced global memory bandwidth by the factor of 32, which is a huge performance win." [ibid]
In other words, having broken the result matrix down into a grid of (S/32)2 tiles, and each tile requiring 2*(S/32) tile loads, the total number of tile loads is the first number times the second, that is, 2*(S/32)3.
Note that the resulting matrix is tiled also. So, if we're doing the matrix multiplication C=AB, the working memory needed at any moment is one tile each of A, B and C.
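A minimal NumPy sketch of the idea (a GPU kernel would do the same thing with shared memory; here the slicing just makes the tile structure explicit). The function name and tile size are illustrative.

```python
import numpy as np

def tiled_matmul(A, B, tile=32):
    """C = A @ B computed one tile at a time.
    Assumes square S x S matrices with S a multiple of `tile`."""
    S = A.shape[0]
    C = np.zeros((S, S))
    for i in range(0, S, tile):           # tile rows of C
        for j in range(0, S, tile):       # tile columns of C
            acc = np.zeros((tile, tile))  # the one C tile kept resident
            for k in range(0, S, tile):
                # Only one tile of A and one tile of B are "loaded" per step.
                acc += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
            C[i:i+tile, j:j+tile] = acc
    return C
```

Each C tile is visited once and accumulates S/32 partial products, which is where the 2*(S/32)3 tile-load count above comes from.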
There's a further optimization. If we're looking for the maximum value (which is very common in neural nets, where the softmax function needs the row maximum for numerical stability), we only need to keep one running value per row/column as we sweep across the tiles.
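The running-maximum trick can be sketched like this. The function name is mine; the point is that between tiles we carry only one value per row, never a whole tile of intermediate results.

```python
import numpy as np

def rowwise_max_tiled(A, tile=32):
    """Row maxima of A, computed one column-tile at a time.
    Only a single running value per row survives between tiles."""
    S = A.shape[0]
    running = np.full(S, -np.inf)
    for j in range(0, A.shape[1], tile):
        # Fold this tile's per-row maxima into the running maxima.
        running = np.maximum(running, A[:, j:j+tile].max(axis=1))
    return running
```

The same incremental pattern (a running max plus a running sum of exponentials) is what lets softmax itself be computed tile by tile.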
[1] Machine Learning with PyTorch and Scikit-Learn.