I've been applying HuggingFace's ModernBERT to sequences of medical codes in the hope that we can identify people who may have type 2 diabetes before even their doctors do. The results have been pretty good and met expectations (the same as an American study; this is UK data).
But what surprised me is that a very simple approach performed as well as the more complicated one recommended by our American-based colleagues. Their recommendation was to use transfer learning on their model. This simply didn't work, as their model used a medical encoding system called ICD-10 whereas the British more commonly use the richer SNOMED. The theory was that, with fine-tuning, the model would learn the new codes. It didn't.
As an aside, mapping one encoding to another is a non-trivial task. Avoid it if you can.
Instead, we pretrained BERT ourselves, exclusively on our data, using a masked language model (MLM) objective. This is where random tokens are masked out of a sequence and the model is asked to guess them. Having built the model, we then finetuned it on the medical code sequences of people with and without T2D to see whether BERT could classify them.
Note that in this case, there is no pretrained BERT. We build it entirely from just our data. A sequence of medical codes is a very simple 'language' so you don't need massive compute resources to train a model. A single Tesla V100 with 16 GB of VRAM was sufficient.
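To make that concrete, here's a minimal sketch of the from-scratch MLM pretraining using the Transformers Trainer. It assumes `tokenizer` is the custom tokenizer trained on the code corpus (see the tokenizer section below); the file name, output directory and hyperparameters are illustrative rather than our actual settings:

```python
from datasets import load_dataset
from transformers import (BertConfig, BertForMaskedLM,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

# `tokenizer` is assumed to be a custom tokenizer trained on the code corpus.
# "code_corpus.txt" is an illustrative file with one code sequence per line.
dataset = load_dataset("text", data_files={"train": "code_corpus.txt"})["train"]
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True, remove_columns=["text"])

config = BertConfig(vocab_size=tokenizer.vocab_size, num_hidden_layers=6)
model = BertForMaskedLM(config)   # randomly initialised: no pretrained checkpoint at all

collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="bert-codes-mlm",
                           num_train_epochs=3,
                           per_device_train_batch_size=32),
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()
```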
Although this achieved good results, it took literally days. The surprising thing was that when we (perhaps naively) treated each sequence of codes as a sentence and used a plain BertForSequenceClassification object to train the model in just one step (no separate pretraining and finetuning) we got the same results.
Anyway, here are some notes I made along the way:
The neural net
num_hidden_layers impacts accuracy and the model's ability to generalize. Its default is 12. Recommended values depend on the use case: fewer than 6 is good for resource-constrained environments; 6-12 for fine-tuning; 12-24 for large pretraining.
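For example, the layer count is just a field on BertConfig (a sketch; the vocabulary size is left at BERT's default purely for illustration):

```python
from transformers import BertConfig, BertForMaskedLM

print(BertConfig().num_hidden_layers)        # 12 by default

# A slimmer 6-layer model for a resource-constrained setup.
small_config = BertConfig(num_hidden_layers=6)
small_model = BertForMaskedLM(small_config)
print(sum(p.numel() for p in small_model.parameters()))   # noticeably fewer parameters
```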
ChatGPT suggests for BERT classification, final loss should be 0.2-0.4. "BERT can overfit fast, especially with small datasets."
per_device_train_batch_size is purely a runtime optimization: it trades GPU memory for throughput, and as long as the effective batch size stays the same (e.g. via gradient accumulation) it doesn't change what the model learns.
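A small illustration, assuming you're using TrainingArguments: halving the per-device batch size while doubling gradient_accumulation_steps keeps the effective batch size the same, so the choice is really about fitting into GPU memory:

```python
from transformers import TrainingArguments

# Both configurations see an effective batch of 32 sequences per optimiser step;
# the per-device value mainly decides how much GPU memory each forward pass needs.
args_big_gpu = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
)
args_small_gpu = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,   # 8 x 4 = 32
)
```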
When running inference, call model.eval() and put all access to the model in a with torch.no_grad() block.
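Something like this (a sketch; `model` and `tokenizer` are assumed to be the finetuned classifier and its tokenizer, and the code string is a placeholder):

```python
import torch

# `model` and `tokenizer` are the finetuned classifier and its tokenizer (assumed).
model.eval()                                   # disables dropout etc.
with torch.no_grad():                          # no gradients tracked or stored
    inputs = tokenizer("code_a code_b code_c", return_tensors="pt")
    logits = model(**inputs).logits
    predicted_class = logits.argmax(dim=-1).item()
```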
Note that BERT does not improve itself by being told the evaluation labels; it only uses them to calculate the validation loss.
Pretraining
It's important to note that when pretraining BERT (even if the end goal is classification), it's not told what the answers are. Instead, it's fed token sequences where some tokens are masked and it tries to guess what the missing tokens are. This is the accuracy metric it spits out every so often.
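You can see the mechanics with the data collator that does the masking. This example uses a stock tokenizer purely for illustration; in practice it would be the custom code tokenizer:

```python
from transformers import BertTokenizerFast, DataCollatorForLanguageModeling

# Stock tokenizer used only to demonstrate what the MLM objective sees.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)

encoding = tokenizer("a short example sentence to be masked")
batch = collator([encoding])
print(batch["input_ids"])   # ~15% of positions replaced with tokenizer.mask_token_id
print(batch["labels"])      # original ids at the masked positions, -100 (ignored) elsewhere
```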
Finetuning
When BERT is training, you'll see Training and Validation loss. Of the two, validation is the more important. Training loss tells you how well the model deals with data it has already seen. The model is not trained on the evaluation set, so validation loss indicates how well it will fare on data it has not been trained on. Consequently, a falling training loss and a flat (or rising) validation loss indicates that the model is overfitting. This can be addressed with dropout and/or weight_decay.
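In Transformers terms, the extra dropout goes into the model config and the weight decay into the training arguments. A sketch (the values are illustrative, and "bert-codes-mlm" is the hypothetical checkpoint from the pretraining sketch above):

```python
from transformers import BertForSequenceClassification, TrainingArguments

model = BertForSequenceClassification.from_pretrained(
    "bert-codes-mlm",                     # hypothetical MLM-pretrained checkpoint
    num_labels=2,
    hidden_dropout_prob=0.2,              # default is 0.1
    attention_probs_dropout_prob=0.2,     # default is 0.1
)
args = TrainingArguments(output_dir="bert-codes-clf", weight_decay=0.01)
```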
The conventional way of training BERT is to use BertForMaskedLM then BertForSequenceClassification for fine-tuning. The first takes sentences from the training data, randomly blanks out some words and guesses what they might be. However, we found that for small data sets that are not natural language (sequences of medical codes), we could get good performance by just using BertForSequenceClassification, which does all the backpropagation in a fraction of the time.
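The single-step version looks roughly like this sketch: each patient's codes are joined into one space-separated string with a 0/1 label, and a randomly initialised BertForSequenceClassification is trained on them directly (the data and names here are stand-ins, and `tokenizer` is again the custom tokenizer):

```python
from datasets import Dataset
from transformers import (BertConfig, BertForSequenceClassification,
                          Trainer, TrainingArguments)

# Stand-in data: one space-separated code sequence per patient, with a 0/1 label.
data = Dataset.from_dict({
    "text": ["code_a code_b code_c", "code_d code_e code_f"],
    "label": [1, 0],
})
data = data.map(
    lambda batch: tokenizer(batch["text"], truncation=True,
                            padding="max_length", max_length=128),
    batched=True)

config = BertConfig(vocab_size=tokenizer.vocab_size,
                    num_hidden_layers=6, num_labels=2)
model = BertForSequenceClassification(config)   # no pretrained checkpoint

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="bert-codes-clf", num_train_epochs=3),
    train_dataset=data,
)
trainer.train()
```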
When calling BertForSequenceClassification.from_pretrained you might see:
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ... and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Apparently, this is OK if you've been given a model that you wish to use for classification. The classifier.* weights are the last layer of the neural net: we take a model that has already been trained (possibly on masking, but not on classification) and add a new classification layer on top, which is why those weights are newly initialized.
The special CLS token represents a summary of the string that's being encoded. The bert.pooler.* is a single, final dense layer applied to this token's hidden state that's used in classification (it is not used at all in masking).
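You can verify what the pooler does with a stock checkpoint; it's just a dense layer plus a tanh applied to the hidden state of the first ([CLS]) token:

```python
import torch
from transformers import BertModel, BertTokenizerFast

# Stock checkpoint used only to demonstrate the pooler mechanics.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

outputs = model(**tokenizer("a sentence to summarise", return_tensors="pt"))
cls_hidden = outputs.last_hidden_state[:, 0]                 # hidden state of [CLS]
manual_pooled = torch.tanh(model.pooler.dense(cls_hidden))   # what the pooler computes
assert torch.allclose(manual_pooled, outputs.pooler_output, atol=1e-6)
```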
The Tokenizer(s)
Typically, you'll see tokenizers from both the Tokenizers and Transformers libraries. They're both from HuggingFace so it's not surprising that you can tokenize with one, save it, and then load the data with a tokenizer from the other.
Note that the tokenizers do not save the words as vectors; that embedding is done in the neural net. Instead, they offer different approaches to splitting text - for instance, balancing a vocabulary that can cover words it has not seen yet against one that is so big it makes training inefficient.
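As a sketch of how the two libraries fit together: train a WordPiece vocabulary on the raw code sequences with Tokenizers, then wrap it so Transformers can use it (the corpus file name and vocabulary size are illustrative):

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

tok = Tokenizer(models.WordPiece(unk_token="[UNK]"))
tok.pre_tokenizer = pre_tokenizers.Whitespace()
tok.train(files=["code_corpus.txt"],                      # illustrative corpus file
          trainer=trainers.WordPieceTrainer(
              vocab_size=30_000,
              special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]))

# Add [CLS]/[SEP] around each sequence, as BERT expects.
tok.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    special_tokens=[("[CLS]", tok.token_to_id("[CLS]")),
                    ("[SEP]", tok.token_to_id("[SEP]"))])

# Wrap it so it can be passed straight to Transformers models and Trainers.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tok,
    unk_token="[UNK]", cls_token="[CLS]", sep_token="[SEP]",
    pad_token="[PAD]", mask_token="[MASK]")
```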
In the Transformers library, you'll see some tokenizer names appended with Fast. These really are much faster to run as they're Rust implementations.
Almost as an aside, we tried to build our own tokenizer vocabulary by hand as we wanted to stick as closely to the methodology used by our American colleagues as we could. This is perfectly possible (just point tokenizers.models.WordPiece.from_file at your vocab.txt file) but we found that when we prepended all the codes with their classification system in both vocab.txt and the actual data we were tokenizing, accuracy was an unrealistically high 0.999 from the get-go. Why this was the case remains a mystery. Needless to say, we backed out that change. [Addendum: it's still a mystery why, but using just the 30k most common whole words as tokens rather than the >100k we used previously resulted in good accuracy].
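For reference, the hand-built vocabulary route looks roughly like this (vocab.txt is your own file, one token per line):

```python
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import PreTrainedTokenizerFast

# vocab.txt is the hand-built vocabulary file, one token per line.
wp = models.WordPiece.from_file("vocab.txt", unk_token="[UNK]")
tok = Tokenizer(wp)
tok.pre_tokenizer = pre_tokenizers.Whitespace()

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tok,
    unk_token="[UNK]", cls_token="[CLS]", sep_token="[SEP]",
    pad_token="[PAD]", mask_token="[MASK]")
```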
PyTorch
Backing the Transformers implementation of BERT is PyTorch. Browsing through the code you'll see references to unsqueeze. This is an interesting fella that can best be described by this StackOverflow answer. Basically, it can take a 2-d matrix and turn it into a 3-d tensor, with the argument dictating whether the elements end up in the x-, y- or z-plane. From a programming point of view, you'll see that the matrix's elements become nested one layer deeper in more square brackets; the exact configuration of those brackets depends upon which plane the elements have been unsqueezed into.
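A quick demonstration:

```python
import torch

# unsqueeze adds a dimension of size 1 at the given position, so a 2-d matrix
# becomes a 3-d tensor; the argument controls where the new axis is inserted.
m = torch.tensor([[1, 2], [3, 4]])    # shape (2, 2)
print(m.unsqueeze(0).shape)            # torch.Size([1, 2, 2])
print(m.unsqueeze(1).shape)            # torch.Size([2, 1, 2])
print(m.unsqueeze(2).shape)            # torch.Size([2, 2, 1])
```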
This diagram from the SO answer is great: