What it is
"k-fold cross-validation relies on quarantining subsets of the training data during the learning process... k-fold CV begins by randomly splitting the data into k disjoint subsets, called folds (typical choices for k are 5, 10, or 20). For each fold, a model is trained on all the data except the data from that fold and is subsequently used to generate predictions for the data from that fold. After all k-folds are cycled through, the predictions for each fold are aggregated and compared to the true target variable to assess accuracy" [1]
Spark
Spark's "CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3 folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing" (from the documentation).
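As a rough illustration of the splitting that both quotes describe, here is a minimal plain-Scala sketch with no Spark dependency. The names KFoldSketch, folds and trainTestPairs are made up for this example, and this is not how Spark's CrossValidator is implemented internally; it only shows the shape of the (training, test) pairs.

```scala
import scala.util.Random

object KFoldSketch {
  /** Randomly partitions `data` into `k` disjoint folds of near-equal size. */
  def folds[A](data: Seq[A], k: Int, seed: Long): Seq[Seq[A]] = {
    require(k >= 2 && k <= data.size, "need 2 <= k <= data.size")
    val shuffled = new Random(seed).shuffle(data)
    // Element i of the shuffled data goes to fold i % k,
    // so fold sizes differ by at most one.
    shuffled.zipWithIndex
      .groupBy { case (_, i) => i % k }
      .toSeq
      .sortBy { case (foldId, _) => foldId }
      .map { case (_, members) => members.map(_._1) }
  }

  /** Pairs each held-out fold (test) with all the remaining folds (training). */
  def trainTestPairs[A](data: Seq[A], k: Int, seed: Long): Seq[(Seq[A], Seq[A])] = {
    val fs = folds(data, k, seed)
    fs.indices.map { i =>
      val train = fs.indices.filter(_ != i).flatMap(j => fs(j))
      (train, fs(i))
    }
  }
}
```

With 12 rows and k = 3, KFoldSketch.trainTestPairs(1 to 12, 3, seed = 42L) gives three pairs, each with 8 training rows and 4 test rows, matching the 2/3 and 1/3 proportions in the documentation quote.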
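import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
val nb = new NaiveBayes("nb") // "nb" is the stage's uid
val pipeline = new Pipeline().setStages(Array(tokenizer, remover, ngram, hashingTF, idf, nb))
// Score each fold by accuracy; LABEL is assumed to hold the label column's name
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(LABEL).setPredictionCol("prediction").setMetricName("accuracy")
// 6 x 6 = 36 hyperparameter combinations
val paramGrid = new ParamGridBuilder().addGrid(nb.smoothing, Array(100.0, 10.0, 1.0, 0.1, 0.01, 0.001)).addGrid(idf.minDocFreq, Array(1, 2, 4, 8, 16, 32)).build()
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(5)
val fitted = cv.fit(df)
// One fold-averaged accuracy per combination, in paramGrid order
val metrics = fitted.avgMetrics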
where tokenizer, remover, ngram, hashingTF and idf are instances of Spark's Tokenizer, StopWordsRemover, NGram, HashingTF and IDF.
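To see which combination produced which score, the grid can be paired with the averaged metrics. Below is a small sketch in plain Scala; GridRanking and rank are hypothetical names, and it assumes a larger metric is better, which is true for accuracy but not for every metric.

```scala
object GridRanking {
  /** Pairs each grid point with its averaged metric and sorts best first.
    * Assumes `metrics(i)` belongs to `grid(i)` and that larger is better. */
  def rank[P](grid: Seq[P], metrics: Seq[Double]): Seq[(P, Double)] = {
    require(grid.size == metrics.size, "one metric per grid point expected")
    grid.zip(metrics).sortBy { case (_, metric) => -metric }
  }
}
```

Assuming the fitted model from the code above, GridRanking.rank(fitted.getEstimatorParamMaps.toSeq, fitted.avgMetrics.toSeq) should list the 36 combinations best first, since avgMetrics is stored in the same order as the param maps.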
Running this on the Subject text of the 20 Newsgroups data set yielded optimized hyperparameters of minDocFreq = 1 (a term only needs to appear in one document to be kept) and a smoothing value of 0.1 for regularization, giving 77.1% accuracy.
Running this on all the text of the 20 Newsgroups data yielded values of 10.0 for smoothing and 4 for minDocFreq, giving an optimized accuracy of nearly 88%.
Those results in tabular form:
Corpus | smoothing | minDocFreq | Accuracy |
subject only | 0.1 | 1 | 77.1% |
all text | 10.0 | 4 | 87.9% |
Interestingly, accuracy typically varied by less than 6% across the smoothing values but by as much as 60% across the minDocFreq values. For this data and this model at least, the rather unexceptional conclusion is that you can gain more accuracy from improving feature engineering than from model tuning.
(Note: NGram.n was set to 2 for the runs above. After some more CV, I found it was best set to 1. With that, the "subject only" accuracy was 85.1% and the "all text" accuracy was 89.4%.)
Efficiency
Happily, Spark has supported parallel cross-validation since 2.3.0. See CrossValidator.setParallelism(...) and TrainValidationSplit.setParallelism(...); both are marked @Since("2.3.0"). This should improve performance. Using 10 executors with 30 GB of memory and 2 cores each, CV on the full data set could take 20 minutes or so.
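A minimal sketch of turning this on, assuming Spark 2.3.0 or later on the classpath. The value 4 is only an illustrative choice, not a setting from the runs above; the parameter controls how many models are fitted and evaluated concurrently.

```scala
import org.apache.spark.ml.tuning.CrossValidator

// Assumption: 4 concurrent model fits; tune this to what the cluster can bear.
val parallelCv = new CrossValidator()
  .setNumFolds(5)
  .setParallelism(4)
```

The estimator, evaluator and param grid would be set on parallelCv exactly as on cv earlier.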
Logging
For logs on what TrainValidationSplit is doing, run:
scala> sc.setLogLevel("DEBUG")
The DEBUG output is very noisy, so change the level back to ERROR when you're done.
[1] Real World Machine Learning.