Interpretability of LLMs is important in many real-world situations. For instance, there is an EU law that says that if a bank's model refuses a loan, the client has the right to know why. But neural nets are notoriously hard to interpret.
There are a few techniques:
Activation Patching: this is basically A/B testing for a neural net. The activations from a clean prompt are transplanted into the forward pass of a corrupted prompt, and we check whether that undoes the effect of the corruption.
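To make the idea concrete, here is a minimal sketch of activation patching on a toy two-layer network. The weights, inputs and class name are hypothetical and have nothing to do with GPULlama3. We cache the hidden activations from the clean input, rerun the corrupted input with one hidden unit overwritten by its clean value, and see how far the output moves back.

```java
// Hypothetical toy model for illustrating activation patching (not from any real codebase).
final class ActivationPatching {
    // Hypothetical weights: 2 inputs -> 2 ReLU hidden units -> 1 output.
    static final double[][] W1 = {{1.0, -1.0}, {0.5, 2.0}};
    static final double[] W2 = {1.0, -0.5};

    // Layer 1: hidden = relu(W1 · x)
    static double[] hidden(double[] x) {
        double[] h = new double[W1.length];
        for (int i = 0; i < W1.length; i++) {
            double s = 0;
            for (int j = 0; j < x.length; j++) s += W1[i][j] * x[j];
            h[i] = Math.max(0.0, s);
        }
        return h;
    }

    // Layer 2: output = W2 · hidden
    static double output(double[] h) {
        double s = 0;
        for (int i = 0; i < h.length; i++) s += W2[i] * h[i];
        return s;
    }

    // Run the corrupted input, but replace hidden unit `unit` with its value from the clean run.
    static double patchedOutput(double[] clean, double[] corrupted, int unit) {
        double[] cleanHidden = hidden(clean); // cache activations from the clean run
        double[] h = hidden(corrupted);
        h[unit] = cleanHidden[unit];          // transplant a single activation
        return output(h);
    }

    public static void main(String[] args) {
        double[] clean = {1.0, 0.0};
        double[] corrupted = {0.0, 1.0};
        System.out.println("clean:     " + output(hidden(clean)));
        System.out.println("corrupted: " + output(hidden(corrupted)));
        for (int u = 0; u < W1.length; u++) {
            System.out.println("patch unit " + u + ": " + patchedOutput(clean, corrupted, u));
        }
    }
}
```

The units whose patch moves the corrupted output furthest back toward the clean output are the ones the behaviour depends on.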
Logit lens: take the residual stream vector (that is, the vector that acts as the model's working memory) at each layer and project it onto the vocabulary with the model's final normalization and unembedding (output) matrix to get logits. This way, we can see at which layer an output word (probably) starts to form.
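A minimal logit-lens sketch, assuming a Llama-style final RMSNorm (with its weights set to 1 for simplicity) followed by an unembedding matrix. The three-word vocabulary, the unembedding matrix and the per-layer residual vectors are all hypothetical toy values.

```java
// Hypothetical toy logit lens: project intermediate residual vectors onto a tiny vocabulary.
final class LogitLens {
    static final String[] VOCAB = {"cat", "dog", "car"};
    // Hypothetical unembedding matrix: one row per vocabulary word, hidden size 2.
    static final double[][] UNEMBED = {{1.0, 0.0}, {0.0, 1.0}, {-1.0, -1.0}};

    // RMS normalization, standing in for the model's final norm (learned weights assumed to be 1).
    static double[] rmsNorm(double[] x) {
        double ss = 0;
        for (double v : x) ss += v * v;
        double inv = 1.0 / Math.sqrt(ss / x.length + 1e-5);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] * inv;
        return out;
    }

    // Normalize, multiply by the unembedding matrix, and return the highest-scoring word.
    static String topWord(double[] residual) {
        double[] x = rmsNorm(residual);
        int best = 0;
        double bestLogit = Double.NEGATIVE_INFINITY;
        for (int w = 0; w < UNEMBED.length; w++) {
            double logit = 0;
            for (int i = 0; i < x.length; i++) logit += UNEMBED[w][i] * x[i];
            if (logit > bestLogit) { bestLogit = logit; best = w; }
        }
        return VOCAB[best];
    }

    public static void main(String[] args) {
        // Hypothetical residual stream after each of three layers.
        double[][] perLayer = {{-0.2, -0.3}, {0.4, 0.1}, {0.2, 0.9}};
        for (int layer = 0; layer < perLayer.length; layer++) {
            System.out.println("layer " + layer + ": " + topWord(perLayer[layer]));
        }
    }
}
```

Watching the top word change from layer to layer ("car", then "cat", then "dog" for these toy vectors) is exactly the kind of signal the logit lens is used for.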
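Sparse Autoencoders: an autoencoder is trained to reconstruct the model's internal activations (its input and its target output are the same activation vector), with a penalty that keeps most of its hidden features at zero, so the few prominent features that do fire become clear.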
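Below is a sketch of the forward pass and training loss of a sparse autoencoder in one common formulation: a ReLU encoder, a linear decoder, and a loss of squared reconstruction error plus an L1 penalty on the features. The toy weights and the penalty strength are hypothetical, and the training loop (gradient descent on this loss) is omitted.

```java
import java.util.Arrays;

// Hypothetical sparse autoencoder over model activations (one common formulation, not a specific library).
final class SparseAutoencoder {
    final double[][] wEnc; // encoder weights [features][dModel]
    final double[] bEnc;   // encoder bias [features]
    final double[][] wDec; // decoder weights [dModel][features]
    final double[] bDec;   // decoder bias [dModel]
    final double lambda;   // strength of the L1 sparsity penalty

    SparseAutoencoder(double[][] wEnc, double[] bEnc, double[][] wDec, double[] bDec, double lambda) {
        this.wEnc = wEnc; this.bEnc = bEnc; this.wDec = wDec; this.bDec = bDec; this.lambda = lambda;
    }

    // Encoder: f = relu(wEnc · (x - bDec) + bEnc). Training pushes most entries to zero.
    double[] encode(double[] x) {
        double[] f = new double[bEnc.length];
        for (int k = 0; k < f.length; k++) {
            double s = bEnc[k];
            for (int i = 0; i < x.length; i++) s += wEnc[k][i] * (x[i] - bDec[i]);
            f[k] = Math.max(0.0, s);
        }
        return f;
    }

    // Decoder: xHat = wDec · f + bDec, a reconstruction of the original activation.
    double[] decode(double[] f) {
        double[] xHat = bDec.clone();
        for (int i = 0; i < xHat.length; i++)
            for (int k = 0; k < f.length; k++) xHat[i] += wDec[i][k] * f[k];
        return xHat;
    }

    // Loss: squared reconstruction error plus lambda times the L1 norm of the features.
    double loss(double[] x) {
        double[] f = encode(x);
        double[] xHat = decode(f);
        double mse = 0, l1 = 0;
        for (int i = 0; i < x.length; i++) mse += (x[i] - xHat[i]) * (x[i] - xHat[i]);
        for (double v : f) l1 += Math.abs(v);
        return mse + lambda * l1;
    }

    public static void main(String[] args) {
        // Hypothetical toy weights: 2-dimensional activations, 4 candidate features.
        SparseAutoencoder sae = new SparseAutoencoder(
            new double[][] {{1, 0}, {0, 1}, {-1, 0}, {0, -1}}, new double[] {-0.1, -0.1, -0.1, -0.1},
            new double[][] {{1, 0, -1, 0}, {0, 1, 0, -1}}, new double[] {0, 0}, 0.1);
        double[] activation = {0.5, 0.0};
        System.out.println(Arrays.toString(sae.encode(activation))); // only the first feature fires
        System.out.println(sae.loss(activation));
    }
}
```

The negative encoder bias acts as a threshold: a feature only fires when the activation points clearly in its direction, which is what keeps the code sparse.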
The neural net itself
Most of the Java code that iterates over the tokens lives in InferenceEngine.generateTokensGPULlama. In turn, this delegates to the Tornado graphs that are executed in TornadoVMMasterPlan.tornadoVMForwardExecuteLayered. Note that the positionHolder is an IntArray that has a single element.
However, the real heavy lifting is done in LlamaFP16FFNLayers.setupSingleFFNLayer and the functions in TransformerComputeKernelsLayered that it uses. Note that it doesn't call them. Instead, it builds them up into a graph of lazily evaluated tasks that are executed by the code mentioned above.
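The build-now, run-later pattern can be illustrated in plain Java. The LazyGraph class below is hypothetical and is NOT TornadoVM's API (TornadoVM has its own task-graph and execution-plan classes); it only shows that registering a task records it without running it, and that everything runs later, in order, when the graph is executed.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of deferred execution; this is NOT TornadoVM's API.
final class LazyGraph {
    private final List<String> names = new ArrayList<>();
    private final List<Runnable> tasks = new ArrayList<>();

    // Record a task; nothing runs yet.
    LazyGraph task(String name, Runnable body) {
        names.add(name);
        tasks.add(body);
        return this; // fluent style, so tasks can be chained
    }

    // Run every recorded task in the order it was added.
    void execute() {
        for (Runnable t : tasks) t.run();
    }

    List<String> taskNames() { return List.copyOf(names); }

    public static void main(String[] args) {
        float[] x = {3f, 4f};
        float[] scratch = new float[1];
        LazyGraph g = new LazyGraph()
            .task("sumSquares", () -> scratch[0] = x[0] * x[0] + x[1] * x[1])
            .task("invRms", () -> scratch[0] = (float) (1.0 / Math.sqrt(scratch[0] / x.length)));
        System.out.println("before execute: " + scratch[0]); // still 0.0: nothing has run
        g.execute();
        System.out.println("after execute:  " + scratch[0]);
    }
}
```

Deferring execution like this lets a runtime see the whole sequence of kernels up front, which is what allows it to plan device transfers and launches for the entire layer at once.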
Note that values whose names end with Cache hold state accumulated from all the previous tokens (for example, the key and value vectors computed at every earlier position).
The functions that are called are (in this order):
reductionOneBlockWithLayer
reductionFinalNormalization
fusedQKVMatmulX
ropeRotationWithCacheCopy
processHeadsFlashAttention
processHeadsParallel
matrixVectorGenericWithResidual
reductionOneBlockWithLayer
reductionFinalNormalization
fusedRmsNormFFNGateUp
matrixVectorGenericWithResidual
The italicised functions configure the attention.
Let's go through them:
reductionOneBlockWithLayer starts computing the inverse of the root mean square (RMS) of the input. Each work group reduces its own slice of the input to a partial result.
reductionFinalNormalization combines the partial results from all the work groups into a single value.
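The two-stage reduction can be simulated on the CPU as below. This is a sketch under assumptions: the method names are hypothetical, each "work group" is just a slice of the array, and how the work is actually split between the two GPULlama3 kernels (and which epsilon they use) may differ.

```java
// Hypothetical CPU simulation of a two-stage RMS reduction (not the actual GPULlama3 kernels).
final class RmsReduction {
    // Stage 1, one value per work group: each group sums the squares of its slice of x.
    static float[] partialSumsOfSquares(float[] x, int groupSize) {
        int groups = (x.length + groupSize - 1) / groupSize;
        float[] partial = new float[groups];
        for (int g = 0; g < groups; g++) {
            float sum = 0f;
            for (int i = g * groupSize; i < Math.min(x.length, (g + 1) * groupSize); i++) {
                sum += x[i] * x[i];
            }
            partial[g] = sum;
        }
        return partial;
    }

    // Stage 2: combine the partial sums into the single value 1 / sqrt(mean(x^2) + eps).
    static float inverseRms(float[] partial, int size, float eps) {
        float total = 0f;
        for (float p : partial) total += p;
        return (float) (1.0 / Math.sqrt(total / size + eps));
    }

    // RMSNorm itself: scale each element by the inverse RMS and a learned per-element weight.
    static float[] rmsNorm(float[] x, float[] weight, int groupSize, float eps) {
        float inv = inverseRms(partialSumsOfSquares(x, groupSize), x.length, eps);
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) out[i] = weight[i] * x[i] * inv;
        return out;
    }

    public static void main(String[] args) {
        float[] x = {3f, 4f, 0f, 0f};
        float[] ones = {1f, 1f, 1f, 1f};
        System.out.println(java.util.Arrays.toString(rmsNorm(x, ones, 2, 1e-5f)));
    }
}
```

Splitting the sum this way means no single work group has to see the whole vector; only the small array of partial sums has to be combined at the end.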
fusedQKVMatmulX multiplies the weight matrices wq, wk, and wv by the vector x, putting the results in q, k, and v respectively.
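A sketch of what "fused" means here, assuming one output row per GPU thread: a single loop covers the rows of all three matrices, so one launch produces q, k and v. The class and method names are hypothetical. Note that k and v may be shorter than q, as with grouped-query attention.

```java
// Hypothetical CPU sketch of a fused Q/K/V matrix-vector product (not the actual kernel).
final class FusedQkv {
    // Returns {q, k, v}. Each iteration of the outer loop stands in for one GPU thread.
    static float[][] fusedQkv(float[][] wq, float[][] wk, float[][] wv, float[] x) {
        float[] q = new float[wq.length], k = new float[wk.length], v = new float[wv.length];
        int total = q.length + k.length + v.length;
        for (int row = 0; row < total; row++) {
            float[][] w; float[] out; int r;
            // Work out which matrix (and which row of it) this global row index belongs to.
            if (row < q.length) { w = wq; out = q; r = row; }
            else if (row < q.length + k.length) { w = wk; out = k; r = row - q.length; }
            else { w = wv; out = v; r = row - q.length - k.length; }
            float sum = 0f;
            for (int c = 0; c < x.length; c++) sum += w[r][c] * x[c]; // dot product of row r with x
            out[r] = sum;
        }
        return new float[][] {q, k, v};
    }

    public static void main(String[] args) {
        float[][] qkv = fusedQkv(
            new float[][] {{1f, 0f}, {0f, 1f}}, // wq: 2 x 2
            new float[][] {{1f, 1f}},           // wk: 1 x 2 (smaller, as with grouped-query attention)
            new float[][] {{2f, 0f}},           // wv: 1 x 2
            new float[] {3f, 4f});
        System.out.println(java.util.Arrays.deepToString(qkv));
    }
}
```

Fusing saves two kernel launches and lets all three products share the same read of x.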
ropeRotationWithCacheCopy applies a standard two-dimensional rotation to pairs of elements in q and k. The angle we rotate by is a function of both the token's position and the pair's offset within its attention head (which is derived from the global thread ID modulo the head size). Note that we're imposing a geometry on q and k that doesn't exist before this point.
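A minimal sketch of the rotation, assuming adjacent elements are paired (some implementations instead pair element i with element i + headSize/2) and that the frequency falls off geometrically with the offset within the head. The base is a parameter here (10000 in the original RoPE paper); the class name is hypothetical.

```java
// Hypothetical sketch of rotary position embedding applied in place to q or k.
final class Rope {
    // Rotate each adjacent pair (v[i], v[i+1]). The angle depends on the token position
    // and on the pair's offset within its head (i % headSize).
    static void rotate(float[] v, int position, int headSize, float base) {
        for (int i = 0; i + 1 < v.length; i += 2) {
            int headDim = i % headSize;                                  // offset within the head
            double freq = 1.0 / Math.pow(base, headDim / (double) headSize);
            double angle = position * freq;
            double c = Math.cos(angle), s = Math.sin(angle);
            float x0 = v[i], x1 = v[i + 1];
            v[i] = (float) (x0 * c - x1 * s);                            // standard 2-D rotation
            v[i + 1] = (float) (x0 * s + x1 * c);
        }
    }

    public static void main(String[] args) {
        float[] q = {1f, 0f, 3f, 4f};
        rotate(q, 5, 4, 10000f);
        System.out.println(java.util.Arrays.toString(q));
    }
}
```

Because each pair is only rotated, its length is unchanged; what changes is the angle between q and k, which ends up depending on the difference between their positions.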
processHeadsFlashAttention calculates $\mathrm{softmax}(QK^\top/\sqrt{d})\,V$, where $d$ is the head size. It uses tiling [previous post], as this is a massive matrix operation.
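Tiling attention only works because softmax can be computed incrementally. The sketch below shows this "online softmax" trick for one head and one query on the CPU, using the standard $1/\sqrt{d}$ scaling: keys and values are visited one tile at a time while only a running maximum, a running normalizer and a running weighted sum are kept. It is a hypothetical illustration, not the processHeadsFlashAttention kernel, and it includes a naive version for comparison.

```java
// Hypothetical single-head, single-query attention: naive vs tiled with an online softmax.
final class OnlineSoftmaxAttention {
    static double dot(float[] a, float[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    // Reference: softmax over all scores q·k_t / sqrt(d), then the weighted sum of the v_t.
    static float[] naive(float[] q, float[][] keys, float[][] values) {
        int n = keys.length, d = q.length;
        double[] scores = new double[n];
        double max = Double.NEGATIVE_INFINITY;
        for (int t = 0; t < n; t++) {
            scores[t] = dot(q, keys[t]) / Math.sqrt(d);
            max = Math.max(max, scores[t]);
        }
        double sum = 0;
        for (int t = 0; t < n; t++) { scores[t] = Math.exp(scores[t] - max); sum += scores[t]; }
        float[] out = new float[d];
        for (int t = 0; t < n; t++)
            for (int i = 0; i < d; i++) out[i] += (float) (scores[t] / sum * values[t][i]);
        return out;
    }

    // Tiled: never holds all the scores at once. Earlier contributions are rescaled
    // whenever a tile raises the running maximum.
    static float[] tiled(float[] q, float[][] keys, float[][] values, int tileSize) {
        int n = keys.length, d = q.length;
        double runningMax = Double.NEGATIVE_INFINITY, runningSum = 0;
        double[] acc = new double[d];
        for (int start = 0; start < n; start += tileSize) {
            int end = Math.min(n, start + tileSize);
            double[] s = new double[end - start];
            double tileMax = Double.NEGATIVE_INFINITY;
            for (int t = start; t < end; t++) {
                s[t - start] = dot(q, keys[t]) / Math.sqrt(d);
                tileMax = Math.max(tileMax, s[t - start]);
            }
            double newMax = Math.max(runningMax, tileMax);
            double rescale = Math.exp(runningMax - newMax); // 0 on the first tile
            runningSum *= rescale;
            for (int i = 0; i < d; i++) acc[i] *= rescale;
            for (int t = start; t < end; t++) {
                double w = Math.exp(s[t - start] - newMax);
                runningSum += w;
                for (int i = 0; i < d; i++) acc[i] += w * values[t][i];
            }
            runningMax = newMax;
        }
        float[] out = new float[d];
        for (int i = 0; i < d; i++) out[i] = (float) (acc[i] / runningSum);
        return out;
    }

    public static void main(String[] args) {
        float[] q = {0.5f, -1f};
        float[][] keys = {{1f, 0f}, {0f, 1f}, {2f, 2f}, {-1f, 0.5f}, {0.3f, 0.3f}};
        float[][] values = {{1f, 2f}, {3f, 4f}, {5f, 6f}, {7f, 8f}, {9f, 10f}};
        System.out.println(java.util.Arrays.toString(naive(q, keys, values)));
        System.out.println(java.util.Arrays.toString(tiled(q, keys, values, 2)));
    }
}
```

Both methods give the same result up to rounding; the tiled one only ever needs one tile of keys and values at a time, which is what lets a GPU keep each tile in fast local memory.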
This is done for all the FFN layers plus similar calls in LogitsFP16Layer.setupLogitsTaskGraph for the logit layer. Here, the functions are (in this order):
reductionOneBlockWithLayer
reductionFinalNormalization
mapContextWithQuantizeLogits
matrixVectorGeneric
Note that the functions in the first list are called once for each layer (there appear to be 16 layers for Llama), while the logits functions run once at the end.
There are more layers but this is enough for my first post on the subject. More in Part 2.
RoPE
When the model is loaded, RoPE.precomputeFreqsCis is called just before the weights are loaded.
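Presumably the point of precomputing is to turn the per-token rotation into table lookups. Here is a sketch of such tables, assuming adjacent-element pairs and a geometric frequency schedule; the class, constructor and method names are hypothetical, and the real RoPE.precomputeFreqsCis may differ (for example, it may also apply extra frequency scaling that this sketch omits).

```java
// Hypothetical precomputed RoPE tables: cos/sin of position * freq_j for every position and pair.
final class RopeTables {
    final float[][] cos; // [position][pair]
    final float[][] sin; // [position][pair]

    RopeTables(int contextLength, int headSize, double base) {
        int pairs = headSize / 2;
        cos = new float[contextLength][pairs];
        sin = new float[contextLength][pairs];
        for (int j = 0; j < pairs; j++) {
            double freq = 1.0 / Math.pow(base, (2.0 * j) / headSize); // lower frequency for later pairs
            for (int pos = 0; pos < contextLength; pos++) {
                double angle = pos * freq;
                cos[pos][j] = (float) Math.cos(angle);
                sin[pos][j] = (float) Math.sin(angle);
            }
        }
    }

    // Rotate one head's slice of q or k (starting at `offset`) using table lookups only.
    void rotateHead(float[] v, int offset, int headSize, int pos) {
        for (int j = 0; j < headSize / 2; j++) {
            int i = offset + 2 * j;
            float c = cos[pos][j], s = sin[pos][j];
            float x0 = v[i], x1 = v[i + 1];
            v[i] = x0 * c - x1 * s;
            v[i + 1] = x0 * s + x1 * c;
        }
    }

    public static void main(String[] args) {
        RopeTables tables = new RopeTables(8, 4, 10000.0);
        float[] k = {1f, 0f, 1f, 0f};
        tables.rotateHead(k, 0, 4, 3);
        System.out.println(java.util.Arrays.toString(k));
    }
}
```

The tables cost contextLength × headSize / 2 entries each, paid once at load time, in exchange for no pow, cos or sin calls while generating tokens.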
We do a 2-d rotation on pairs of elements of the vector, imposing a geometry that was not there before. The choice of a 2-d rotation is arbitrary, but 2 is the smallest dimension in which a non-trivial rotation exists, and it's computationally the cheapest dimension to rotate in, as the number of elements in a rotation matrix is $d^2$ (where $d$ is the number of dimensions we're working in).
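To make the cost argument concrete, assume adjacent elements are paired, and write $p$ for the token position, $b$ for the RoPE base and $n$ for the length of the vector being rotated (symbols introduced here for illustration). Each pair $j$ is rotated by its own $2 \times 2$ matrix:

```latex
\begin{pmatrix} x'_{2j} \\ x'_{2j+1} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_j & -\sin\theta_j \\ \sin\theta_j & \cos\theta_j \end{pmatrix}
\begin{pmatrix} x_{2j} \\ x_{2j+1} \end{pmatrix},
\qquad
\theta_j = p \, b^{-2j/n},
\qquad j = 0, \dots, \tfrac{n}{2} - 1

\text{cost} = \tfrac{n}{2} \cdot 2^2 = 2n \text{ multiplications}
\quad \text{vs.} \quad n^2 \text{ for one dense } n \times n \text{ rotation}
\qquad (n = 64:\; 128 \text{ vs. } 4096)
```

In other words, the rotation matrix for the whole vector is block diagonal, and only the $n/2$ small blocks ever need to be computed.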