Monday, December 24, 2018

Unbalanced


... or imbalanced (apparently, imbalance is the noun and unbalance is the verb. Live and learn).

Anyway, in a data scientist job interview I was asked how I would build a network intrusion detection system using machine learning. I replied that I could build them a system that was 99% accurate. Gasps from my interviewers. I said I could even give them 99.9% accuracy. No! I could give them 99.99% accuracy with ease.

I then admitted that I could do this by declaring all traffic to be fine (since much less than 0.01% of it was malicious), billing them and heading home. The issue here is imbalanced data.

This is a great summary of what to do in this situation. To recap:

1. If you have plenty of data, undersample: ignore a lot of the majority class. If you have too little, oversample: duplicate the minority class (see the sketch after this list).

2. Generate new data by "wiggling" the original (eg, the SMOTE algorithm).

3. "Decision trees often perform well on imbalanced datasets" [ibid].

4. Penalize mistakes in misclassifying the minority class.
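
To make the first two suggestions concrete, here is a rough sketch in Scala. This is not the code I actually used; the Sample class and the function names are invented, and the "wiggle" is only a crude imitation of SMOTE (real SMOTE interpolates towards k-nearest neighbours rather than random pairs):

import scala.util.Random

object Resampling {

  // A labelled observation: a feature vector plus a binary class (1 = intrusion, 0 = benign).
  case class Sample(features: Array[Double], label: Int)

  // Undersample: keep every minority sample but only a random fraction of the majority class.
  def undersampleMajority(data: Seq[Sample], keepFraction: Double, rng: Random): Seq[Sample] = {
    val (minority, majority) = data.partition(_.label == 1)
    minority ++ majority.filter(_ => rng.nextDouble() < keepFraction)
  }

  // Oversample: duplicate random minority samples until the two classes are the same size.
  def oversampleMinority(data: Seq[Sample], rng: Random): Seq[Sample] = {
    val (minority, majority) = data.partition(_.label == 1)
    val duplicates = Seq.fill(majority.size - minority.size)(minority(rng.nextInt(minority.size)))
    data ++ duplicates
  }

  // "Wiggle": manufacture new minority samples by interpolating between two random existing ones.
  def wiggle(data: Seq[Sample], n: Int, rng: Random): Seq[Sample] = {
    val minority = data.filter(_.label == 1)
    Seq.fill(n) {
      val a = minority(rng.nextInt(minority.size))
      val b = minority(rng.nextInt(minority.size))
      val t = rng.nextDouble()
      Sample(a.features.zip(b.features).map { case (x, y) => x + t * (y - x) }, 1)
    }
  }
}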

For the last suggestion (penalizing mistakes on the minority class), you can use a meta-algorithm with any model, but with ANNs you get it out of the box. In DL4J, I added something like this:

.lossFunction(new LossNegativeLogLikelihood(Nd4j.create(Array(0.01f, 1f))))

"Ideally, this drives the network away from the poor local optima of always predicting the most commonly ccuring class(es)." [1]

Did it improve my model? Well, first we have to abandon accuracy as our only metric. We introduce two new chaps, precision and recall.

Precision is the ratio of true positives to all the samples my model declared positive, correctly or not ("what fraction of alerts were correct?"). Mnemonic: PreCision's denominator is the sum of the Positive Column. Precision = TP / (TP + FP).

Recall is the ratio of true positives to all the positives in the data ("what fraction of intrusions did we correctly detect?"). Mnemonic: REcall deals with all the REal positives. Recall = TP / (TP + FN).

(There is also the F1 score, which is 2*precision*recall/(precision+recall). "It is also called the F Score or the F Measure. Put another way, the F1 score conveys the balance between the precision and the recall" - Dr Jason Brownlee. The formula is just the harmonic mean of precision and recall.)
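
There is nothing DL4J-specific about these metrics; written as code they are just the formulas above:

// The three metrics, straight from the counts in a binary confusion matrix.
def precision(tp: Long, fp: Long): Double = tp.toDouble / (tp + fp)
def recall(tp: Long, fn: Long): Double = tp.toDouble / (tp + fn)
def f1(tp: Long, fp: Long, fn: Long): Double = {
  val (p, r) = (precision(tp, fp), recall(tp, fn))
  2 * p * r / (p + r)
}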

What counts as a desirable result is very use-case dependent. But with weights of 0.005f and 1f, I got:

Accuracy:  0.9760666666666666
Precision: 0.3192534381139489
Recall:    0.9285714285714286

=========================Confusion Matrix=========================
     0     1
-------------
 28957   693 | 0 = 0
    25   325 | 1 = 1

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================

This was using faux data (6,000 observations of a time series, each 50 data points long, with a class imbalance of 100:1).
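
As a sanity check, plugging the cells of that confusion matrix back into the definitions above reproduces the reported figures:

// From the matrix: TN = 28957, FP = 693, FN = 25, TP = 325 (class 1 = intrusion).
val (tn, fp, fn, tp) = (28957.0, 693.0, 25.0, 325.0)

val accuracy = (tp + tn) / (tp + tn + fp + fn) // 0.976
val precision = tp / (tp + fp)                 // 0.319, roughly one alert in three is a real intrusion
val recall = tp / (tp + fn)                    // 0.929, about 93% of intrusions are flagged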

I got the job to build an intrusion detection system, and these results would make the customer very happy - although I can't stress enough that this is a neural net trained on totally fake data. But if we could detect 93% of intrusions with only one in three investigations yielding a genuine intrusion, the client would be overjoyed. Currently, the false positives are causing them to turn the squelch on their current system up too high, and consequently they're not spotting real intrusions.

[1] Deep Learning: A Practitioner's Approach, Josh Patterson and Adam Gibson, O'Reilly Media, 2017.

