Thursday, December 21, 2017

Decision Trees and Spark

What they are

"We start at the root and split the data on the feature that results in the largest information gain... In an iterative process, we can then repeat this splitting procedure at each child node until the leaves are pure. This means that the samples at each node all belong to the same class. In practice, this can result in a very deep tree with many nodes, which can easily lead to overfitting. Thus, we typically prune the tree by setting a limit for the maximal depth of the tree." [1]
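To make the quoted procedure concrete, here is a toy, in-memory sketch in plain Scala of greedy splitting on information gain, with a maximum depth acting as the pruning limit. It is not how Spark builds trees (Spark, for instance, bins continuous features instead of trying every value). The names ToyTree, Leaf and Node, the use of entropy, and trying every observed value as a threshold are all my own assumptions.

```scala
// Toy depth-limited decision tree (illustration only, not Spark's implementation).
// Assumptions: numeric features, Int class labels, binary "x(f) <= t" splits, entropy impurity.
sealed trait Tree
final case class Leaf(prediction: Int) extends Tree
final case class Node(feature: Int, threshold: Double, left: Tree, right: Tree) extends Tree

object ToyTree {
  // Shannon entropy (base 2) of a collection of labels.
  def entropy(labels: Seq[Int]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map { g =>
      val p = g.size / n
      -p * math.log(p) / math.log(2)
    }.sum
  }

  // Most frequent label; used whenever we stop splitting.
  def majority(labels: Seq[Int]): Int =
    labels.groupBy(identity).maxBy(_._2.size)._1

  def build(rows: Seq[(Array[Double], Int)], depth: Int, maxDepth: Int): Tree = {
    val labels = rows.map(_._2)
    // Stop when the node is pure or the depth limit (our "pruning") is reached.
    if (labels.distinct.size <= 1 || depth >= maxDepth) Leaf(majority(labels))
    else {
      val parent = entropy(labels)
      val n = rows.size.toDouble
      // Candidate splits: every feature, every observed value as a threshold.
      val candidates = for {
        f <- rows.head._1.indices
        t <- rows.map(_._1(f)).distinct
        (l, r) = rows.partition(_._1(f) <= t)
        if l.nonEmpty && r.nonEmpty
      } yield {
        // Information gain = parent entropy minus size-weighted child entropy.
        val gain = parent -
          (l.size / n) * entropy(l.map(_._2)) -
          (r.size / n) * entropy(r.map(_._2))
        (gain, f, t, l, r)
      }
      if (candidates.isEmpty) Leaf(majority(labels))
      else {
        val (_, f, t, l, r) = candidates.maxBy(_._1) // largest information gain wins
        Node(f, t, build(l, depth + 1, maxDepth), build(r, depth + 1, maxDepth))
      }
    }
  }

  // Walk from the root to a leaf.
  def predict(tree: Tree, x: Array[Double]): Int = tree match {
    case Leaf(p)          => p
    case Node(f, t, l, r) => if (x(f) <= t) predict(l, x) else predict(r, x)
  }
}
```

Lowering maxDepth is the "prune by limiting depth" knob from the quote: the tree simply stops splitting and predicts the majority class once the limit is hit.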

Why they're important

When approaching a new classification problem, a decision tree is often one of the first models a data scientist tries.

"As for learning algorithms Logistic Regression and Decision Trees / Random Forests are standard first-round approaches that are fairly interpretable (less so for RF) and do well on many problems" (Quora). Interpretability can be especially important when legislation (like the forthcoming GDPR) gives consumers the right to know why they were or were not targeted for a product.

Why they work

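Looking at Spark's decision tree algorithm, you see that each split is chosen to maximize information gain: it makes the child nodes purer than their parent by decreasing the impurity (here, entropy), not increasing it. Entropy isn't the only impurity measure; Spark's DecisionTreeClassifier also supports Gini impurity (its default), chosen with setImpurity("entropy") or setImpurity("gini"). But you get the idea.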
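A small plain-Scala sketch of the idea, computing entropy, Gini impurity and the information gain of one candidate split from per-class counts. The Impurity object and its methods are hypothetical names for illustration, not Spark's API.

```scala
// Impurity measures computed from per-class counts (illustration only).
object Impurity {
  // Entropy in bits; empty classes contribute nothing.
  def entropy(counts: Seq[Double]): Double = {
    val total = counts.sum
    counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2)
    }.sum
  }

  // Gini impurity: 1 minus the sum of squared class proportions.
  def gini(counts: Seq[Double]): Double = {
    val total = counts.sum
    1.0 - counts.map { c => val p = c / total; p * p }.sum
  }

  // Information gain of splitting a parent into left/right children:
  // parent impurity minus the size-weighted impurity of the children.
  def gain(impurity: Seq[Double] => Double, left: Seq[Double], right: Seq[Double]): Double = {
    val parent = left.zip(right).map { case (a, b) => a + b }
    val (nl, nr, n) = (left.sum, right.sum, parent.sum)
    impurity(parent) - (nl / n) * impurity(left) - (nr / n) * impurity(right)
  }
}
```

Splitting a 5/5 node into 4/1 and 1/4 children lowers the impurity under both measures, so the gain is positive (about 0.278 bits with entropy, 0.18 with Gini); a perfect 5/0 and 0/5 split gains the full 1 bit.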

Spark's Implementation

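Spark has a nice example here, but it loads example data in LIBSVM format. What is this? Each line is a label followed by sparse index:value pairs (with 1-based indices), and Spark reads it into a DataFrame with a label column and a features vector column. So it's basically T -> Vector, where T is our label type. Instead of using the example LIBSVM data, let's take my data set of bank accounts and write: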
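For reference, a minimal sketch of reading one LIBSVM line into a label and a dense feature array, assuming the usual "label index:value ..." layout with 1-based indices. LibSvm.parseLine is a hypothetical helper for illustration; Spark's own libsvm reader does this for you.

```scala
// Parse one LIBSVM line into (label, dense features). Illustration only.
object LibSvm {
  // numFeatures is supplied by the caller; unlisted indices stay 0.0.
  def parseLine(line: String, numFeatures: Int): (Double, Array[Double]) = {
    val tokens = line.trim.split("\\s+")
    val features = new Array[Double](numFeatures)
    tokens.tail.foreach { tok =>
      val Array(idx, value) = tok.split(":")
      features(idx.toInt - 1) = value.toDouble // 1-based in the file, 0-based here
    }
    (tokens.head.toDouble, features)
  }
}
```

For example, LibSvm.parseLine("1 1:0.5 3:2.0", 4) gives the label 1.0 and the features 0.5, 0.0, 2.0, 0.0.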

import org.apache.spark.ml.linalg.DenseVector
val data = ds.map(r => (asDouble(r.histo), new DenseVector(Array(asDouble(r.acc), asDouble(r.name), asDouble(r.number), asDouble(r.value)))))

where ds is of type Dataset[Record] and Record is given by:

case class Record(acc: String, histo: String, name: String, value: Double, number: Int)

and asDouble is just a function that turns anything into a double. It is just test code after all.
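The post doesn't show asDouble, so here is a hypothetical version in plain Scala, alongside a Record-to-(label, features) conversion that mirrors the map over ds without needing Spark. The behaviour for strings (parse if numeric, otherwise fall back to hashCode), the null handling, and the feature order acc, name, number, value are my assumptions.

```scala
import scala.util.Try

case class Record(acc: String, histo: String, name: String, value: Double, number: Int)

object Features {
  // Hypothetical asDouble: "turns anything into a double", as quick test code would.
  def asDouble(x: Any): Double = x match {
    case null      => 0.0 // arbitrary choice for missing values
    case d: Double => d
    case i: Int    => i.toDouble
    case l: Long   => l.toDouble
    case s: String => Try(s.trim.toDouble).getOrElse(s.hashCode.toDouble) // numeric text, else a hash
    case other     => other.hashCode.toDouble
  }

  // histo is the label; the other four fields become features 0..3.
  def toLabelAndFeatures(r: Record): (Double, Array[Double]) =
    (asDouble(r.histo),
     Array(asDouble(r.acc), asDouble(r.name), asDouble(r.number), asDouble(r.value)))
}
```

Hashing strings like this turns categories into arbitrary numbers with a meaningless order, which is fine for playing but not something a tree's "<=" thresholds can exploit sensibly.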

Now it's in a form I can feed into Spark using the code here, which fits a Pipeline to the training data and gives us a PipelineModel. (The pipeline is the label indexer, feature indexer, decision tree classifier and label converter, in that order.)
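In Spark's example the label indexer and label converter are StringIndexer and IndexToString stages. As far as I know, StringIndexer by default numbers labels by descending frequency (most frequent gets 0.0), which is why the indexedLabel column further down differs from the raw label in _1. The plain-Scala sketch below mimics that behaviour; LabelIndexing is a hypothetical name and the alphabetical tie-break is my assumption.

```scala
// Frequency-ordered label indexing, in the spirit of StringIndexer/IndexToString.
object LabelIndexing {
  // Most frequent label -> 0.0, next -> 1.0, ...; ties broken by the label itself.
  def fit(labels: Seq[String]): Map[String, Double] =
    labels.groupBy(identity).toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .zipWithIndex
      .map { case ((label, _), i) => label -> i.toDouble }
      .toMap

  // Inverse mapping, like the label converter at the end of the pipeline.
  def inverse(index: Map[String, Double]): Map[Double, String] = index.map(_.swap)
}
```

With labels b, a, b, c, b, a the sketch assigns b -> 0.0, a -> 1.0 and c -> 2.0, and inverse maps 1.0 back to a.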

I extract the tree model from the fitted pipeline (it's the third stage, index 2):

scala> val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]

and I can actually see what the tree looks like in a very human-readable form:

scala> treeModel.toDebugString
res40: String =
"DecisionTreeClassificationModel (uid=dtc_f122057dc098) of depth 5 with 49 nodes
  If (feature 1 <= -1.940284966E9)
   If (feature 1 <= -2.016375591E9)
    Predict: 0.0
   Else (feature 1 > -2.016375591E9)
    If (feature 2 <= 50.0)
     If (feature 1 <= -2.004972479E9)
      If (feature 3 <= 5.8760448E7)
       Predict: 1.0

[Note that ml.classification.DecisionTreeClassifier is DataFrame oriented while the older mllib.tree.DecisionTree is RDD-based]

Now, the Model can transform the test data and we can see how effective our tree is:

scala> val transformed = model.transform(testData).cache()
|            _1|                  _2|indexedLabel|     indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
|-1.477067101E9|[-2.146512264E9,2...|         3.0|[-2.146512264E9,2...|[326.0,0.0,0.0,27...|[0.00116486398605...|       3.0|-1.477067101E9|
|-1.477067101E9|[-2.14139878E9,20...|         3.0|[-2.14139878E9,20...|[326.0,0.0,0.0,27...|[0.00116486398605...|       3.0|-1.477067101E9|
|-1.477067101E9|[-2.13437752E9,20...|         3.0|[-2.13437752E9,20...|[326.0,0.0,0.0,27...|[0.00116486398605...|       3.0|-1.477067101E9|

where a typical row of this DataFrame looks like:

res46: org.apache.spark.sql.Row = [-1.477067101E9,[-2.146512264E9,2099.0,1569.0,48563.0],3.0,[-2.146512264E9,2099.0,1569.0,48563.0],[326.0,0.0,0.0,279535.0],[0.0011648639860502177,0.0,0.0,0.9988351360139498],3.0,-1.477067101E9]

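The documentation is useful in interpreting this (see here), but suffice it to say that rawPrediction holds the number of training instances of each class in the leaf the data point falls into, and probability is that same vector normalized to sum to 1 (see StackExchange). The prediction is the index of the largest entry (3.0 here), and the label converter maps it back to the original label in predictedLabel.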
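Using the counts from the row above, a quick plain-Scala check of that relationship. LeafScores is a hypothetical helper, not part of Spark.

```scala
// Relating a tree's rawPrediction (per-class leaf counts) to probability and prediction.
object LeafScores {
  // Normalize the counts so they sum to 1.
  def probability(raw: Array[Double]): Array[Double] = {
    val total = raw.sum
    raw.map(_ / total)
  }

  // The predicted (indexed) class is the one with the largest count.
  def predictIndex(raw: Array[Double]): Int = raw.indices.maxBy(i => raw(i))
}
```

Feeding it Array(326.0, 0.0, 0.0, 279535.0) reproduces the probability vector in the row (326 / 279861 is about 0.00116) and picks index 3, matching prediction 3.0.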

Note that:

scala> ds.count()
res1: Long = 3680675
scala> transformed.count()
res1: Long = 1194488

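So the test set holds about a third of the rows (1194488 of 3680675, roughly 32%), presumably from the random train/test split. Note that transformed has one row per test record, not one per leaf: the tree itself is small (depth 5, 49 nodes, according to toDebugString). Row counts alone can't tell us whether we're overfitting (see [1]); comparing training error with test error would. But, again, this is just me playing.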

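Finally, we can use a MulticlassClassificationEvaluator to score the model on the test data. One caveat: unless you call .setMetricName("accuracy") (as Spark's example does), the evaluator's default metric is weighted F1 ("f1"), so as written the code below reports 1 - F1 rather than 1 - accuracy.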

scala> val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel")
scala> val accuracy = evaluator.evaluate(transformed)
scala> println("Test Error = " + (1.0 - accuracy))
Test Error = 0.1842248695787883
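To see why the metric name matters, here is a small plain-Scala sketch comparing accuracy with a weighted F1, computed the way I understand Spark's "f1" metric to work: per-class F1 weighted by each class's share of the true labels. The Metrics object and the toy (label, prediction) pairs are made up for illustration.

```scala
// Accuracy vs. weighted F1 over (label, prediction) pairs. Illustration only.
object Metrics {
  // Fraction of rows whose prediction equals the label.
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    pairs.count { case (label, pred) => label == pred }.toDouble / pairs.size

  // Per-class F1, weighted by how often each class occurs as a true label.
  def weightedF1(pairs: Seq[(Double, Double)]): Double = {
    val n = pairs.size.toDouble
    pairs.map(_._1).distinct.map { c =>
      val tp        = pairs.count { case (l, p) => l == c && p == c }.toDouble
      val predicted = pairs.count(_._2 == c).toDouble
      val actual    = pairs.count(_._1 == c).toDouble
      val precision = if (predicted == 0) 0.0 else tp / predicted
      val recall    = tp / actual
      val f1 = if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
      (actual / n) * f1
    }.sum
  }
}
```

On the pairs (0,0), (0,0), (0,1), (1,1) accuracy is 0.75 but weighted F1 is about 0.767, so "1 - metric" gives a different "test error" depending on which metric the evaluator uses.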

Hmm, pretty bad but we were just playing.

[1] Python Machine Learning
