Monday, April 9, 2018

Cross validation in Spark

What it is

"k-fold cross-validation relies on quarantining subsets of the training data during the learning process... k-fold CV begins by randomly splitting the data into k disjoint subsets, called folds (typical choices for k are 5, 10, or 20). For each fold, a model is trained on all the data except the data from that fold and is subsequently used to generate predictions for the data from that fold.  After all k-folds are cycled through, the predictions for each fold are aggregated and compared to the true target variable to assess accuracy" [1]


Spark's "CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3 folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing" (from the documentation).
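To make the (training, test) pairs concrete, here is a minimal sketch in plain Scala, with no Spark involved. It is not how Spark implements its split; the object name KFoldSketch, the round-robin dealing and the default seed are all assumptions made for illustration.

```scala
import scala.util.Random

object KFoldSketch {
  /** Randomly partition `data` into `k` disjoint folds and return one
    * (training, test) pair per fold. Illustrative only, not Spark's code. */
  def pairs[A](data: Seq[A], k: Int, seed: Long = 42L): Seq[(Seq[A], Seq[A])] = {
    require(k >= 2 && k <= data.size, "need 2 <= k <= data.size")
    val shuffled = new Random(seed).shuffle(data)
    // Deal the shuffled elements round-robin so fold sizes differ by at most one.
    val folds: Seq[Seq[A]] =
      shuffled.zipWithIndex.groupBy(_._2 % k).toSeq.sortBy(_._1).map(_._2.map(_._1))
    folds.indices.map { i =>
      val test  = folds(i)                                         // the held-out fold
      val train = folds.indices.filter(_ != i).flatMap(j => folds(j)) // every other fold
      (train, test)
    }
  }
}
```

With k=3 and nine rows, KFoldSketch.pairs(1 to 9, 3) gives three pairs, each with six training rows and three test rows, matching the 2/3 and 1/3 proportions in the quote.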


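import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val nb        = new NaiveBayes("nb")
val pipeline  = new Pipeline().setStages(Array(tokenizer, remover, ngram, hashingTF, idf, nb))
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(LABEL).setPredictionCol("prediction").setMetricName("accuracy")
// 6 x 6 = 36 parameter combinations, each fitted once per fold
val paramGrid = new ParamGridBuilder().addGrid(nb.smoothing, Array(100.0, 10.0, 1.0, 0.1, 0.01, 0.001)).addGrid(idf.minDocFreq, Array(1, 2, 4, 8, 16, 32)).build()
val cv        = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(5)
val fitted    = cv.fit(df) // df is an assumed name for the training DataFrame
val metrics   = fitted.avgMetrics // one averaged accuracy per grid entry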

where tokenizer, remover, ngram, hashingTF and idf are instances of Spark's Tokenizer, StopWordsRemover, NGram, HashingTF and IDF.
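For completeness, the feature stages might be wired together roughly like this. The column names ("text", "tokens" and so on) and the FeatureStages wrapper are assumptions of mine, not the ones used for these results; the only constraint is that each stage reads the previous stage's output and the last one writes to the column NaiveBayes reads ("features" by default).

```scala
import org.apache.spark.ml.feature.{HashingTF, IDF, NGram, StopWordsRemover, Tokenizer}

object FeatureStages {
  // All column names below are hypothetical, chosen only so the stages chain together.
  val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("tokens")
  val remover   = new StopWordsRemover().setInputCol("tokens").setOutputCol("filtered")
  // n = 2 (bigrams) is the NGram setting mentioned for the first set of results.
  val ngram     = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")
  val hashingTF = new HashingTF().setInputCol("ngrams").setOutputCol("tf")
  // "features" is the default input column of NaiveBayes.
  val idf       = new IDF().setInputCol("tf").setOutputCol("features")
}
```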

Running this on the Subject text of the 20 Newsgroup data set yielded optimized hyperparameters of minDocFreq = 1 (a term only has to appear in one document to be kept) and smoothing = 0.1, leading to 77.1% accuracy.

Running this on all the text of the 20 Newsgroup data yielded values of 10.0 for smoothing and 4 for minDocFreq giving an optimized accuracy of nearly 88%.

Those results in tabular form:

data          smoothing  minDocFreq  accuracy
subject only  0.1        1           77.1%
all text      10.0       4           87.9%
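One way to read off the winning combination is to pair each grid entry with its averaged metric and take the largest. The helper below is a sketch of that idea in plain Scala; the name BestParams is made up, and it assumes that a larger metric is better, which holds for accuracy.

```scala
object BestParams {
  /** Pair each grid entry with its averaged metric and return the best pair.
    * Assumes larger is better (true for accuracy, not for error metrics). */
  def best[A](grid: Seq[A], avgMetrics: Seq[Double]): (A, Double) = {
    require(grid.nonEmpty && grid.size == avgMetrics.size, "grid and metrics must align")
    grid.zip(avgMetrics).maxBy(_._2)
  }
}
```

Against a fitted CrossValidatorModel this could be called as BestParams.best(fitted.getEstimatorParamMaps.toSeq, fitted.avgMetrics.toSeq), since avgMetrics is reported in the same order as the parameter maps.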

Interestingly, the range of results across all smoothing values was typically less than 6%, but the range of results across all minDocFreq values was as much as 60%. For this data and this model at least, the rather unexceptional conclusion is that you can increase accuracy more by improving feature engineering than by tuning the model.
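A comparison like this can be computed by grouping the grid results on one parameter and taking max minus min of the metric within each group. The sketch below is one reading of "range", namely the spread along one parameter while the other is held fixed; the Result case class and the MetricSpread name are illustrative assumptions.

```scala
object MetricSpread {
  /** One grid point: the two hyperparameter values and the metric they produced. */
  final case class Result(smoothing: Double, minDocFreq: Int, metric: Double)

  /** For each value of the parameter held fixed (selected by `fixed`), the
    * spread (max - min) of the metric as the other parameter varies. */
  def spreadPer[K](results: Seq[Result])(fixed: Result => K): Map[K, Double] =
    results.groupBy(fixed).map { case (k, rs) =>
      val ms = rs.map(_.metric)
      k -> (ms.max - ms.min)
    }
}
```

Grouping by _.minDocFreq gives the spread caused by varying smoothing, and grouping by _.smoothing gives the spread caused by varying minDocFreq.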

(Note: NGram.n was set to 2 for the results above. After some more CV, I found it was best set to 1. Then, the "subject only" accuracy was 85.1% and the "all text" accuracy was 89.4%).


Happily, Spark has parallel cross validation as of 2.3.0. See CrossValidator.setParallelism(...) - it has a @Since("2.3.0"), as does the same setter on TrainValidationSplit. This should improve performance. Using 10 executors with 30 GB of memory and 2 cores each, CV on the full data set could take 20 minutes or so.
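Turning it on is a one-line change. In this sketch the value 4 is arbitrary and the ParallelCv wrapper is only there to keep the snippet self-contained; the default parallelism is 1, which means the models are fitted one after another.

```scala
import org.apache.spark.ml.tuning.CrossValidator

object ParallelCv {
  // 4 is an arbitrary illustrative value, not a setting taken from the post.
  // It is the number of models that may be fitted concurrently.
  val cv = new CrossValidator().setNumFolds(5).setParallelism(4)
}
```

The estimator, evaluator and parameter grid are set exactly as before; a sensible value depends on how much spare capacity the cluster has.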


For logs on what TrainValidationSplit is doing, run:

scala> sc.setLogLevel("DEBUG")

This can be irritating so change it back to ERROR when you're done.
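If you tend to forget the second step, a small wrapper can do the restore for you. This is a sketch: withLogLevel and LogLevels are names I made up, and it restores to ERROR by default because there is no getter on the context to read the previous level from.

```scala
object LogLevels {
  /** Run `body` with the log level set to `level`, then set it to `restoreTo`,
    * even if `body` throws. `setLevel` is whatever actually changes the level. */
  def withLogLevel[A](setLevel: String => Unit, level: String, restoreTo: String = "ERROR")(body: => A): A = {
    setLevel(level)
    try body
    finally setLevel(restoreTo)
  }
}
```

In the shell this would look something like LogLevels.withLogLevel(sc.setLogLevel(_), "DEBUG") { cv.fit(df) }, where df is again an assumed name for the training data.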

[1] Real World Machine Learning (sample here).
