Unfortunately, as of version 1.5.2, the functionality for evaluating model quality in the pipeline API remains limited. Logistic regression does output a summary containing several evaluation metrics (available through the summary attribute on the trained model), but these are calculated on the training set. In general, we want to evaluate the performance of the model both on the training set and on a separate test set. We will therefore dive down to the underlying MLlib layer to access evaluation metrics.
MLlib provides a module, org.apache.spark.mllib.evaluation, with a set of classes for assessing the quality of a model. We will use the BinaryClassificationMetrics class here, since spam classification is a binary classification problem. Other evaluation classes provide metrics for multi-class models, regression models, and ranking models.
As in the previous section, we will illustrate the concepts in the shell, but you will find analogous code in the ROC.scala script in the code examples for this chapter. We will use breeze-viz to plot curves, so, when starting the shell, we must ensure that the relevant libraries are on the classpath. We will use SBT assembly, as described in Chapter 10, Distributed Batch Processing with Spark (specifically, the Building and running standalone programs section), to create a JAR with the required dependencies. We will then pass this JAR to the Spark shell, allowing us to import breeze-viz. Let's write a build.sbt file that declares a dependency on breeze-viz:
// build.sbt

name := "spam_filter"

scalaVersion := "2.10.5"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.5.2" % "provided",
  "org.apache.spark" %% "spark-mllib" % "1.5.2" % "provided",
  "org.scalanlp" %% "breeze" % "0.11.2",
  "org.scalanlp" %% "breeze-viz" % "0.11.2",
  "org.scalanlp" %% "breeze-natives" % "0.11.2"
)
Package the dependencies into a jar with:
$ sbt assembly
This will create a jar called spam_filter-assembly-0.1-SNAPSHOT.jar in the target/scala-2.10/ directory. To include this jar in the Spark shell, restart the shell with the --jars command line argument:
$ spark-shell --jars=target/scala-2.10/spam_filter-assembly-0.1-SNAPSHOT.jar
To verify that the packaging worked correctly, try to import breeze.plot:
scala> import breeze.plot._
import breeze.plot._
Let's load the test set, with predictions, which we created in the previous section and saved as a parquet file:
scala> val testDFWithPredictions = sqlContext.read.parquet(
  "transformedTest.parquet")
testDFWithPredictions: org.apache.spark.sql.DataFrame = [fileName: string, label: double, prediction: double, probability: vector]
The BinaryClassificationMetrics object expects an RDD[(Double, Double)] of pairs of scores (the probability assigned by the classifier that a particular e-mail is spam) and labels (whether an e-mail is actually spam). We can extract this RDD from our DataFrame:
scala> import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vector

scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row

scala> val scoresLabels = testDFWithPredictions.select("probability", "label").map {
  case Row(probability: Vector, label: Double) => (probability(1), label)
}
scoresLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[3] at map at <console>:23

scala> scoresLabels.take(5).foreach(println)
(0.9999999967713409,1.0)
(0.9999983827108793,1.0)
(0.9982059900606365,1.0)
(0.9999790713978142,1.0)
(0.9999999999999272,1.0)
We can now construct the BinaryClassificationMetrics instance:
scala> import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

scala> val bm = new BinaryClassificationMetrics(scoresLabels)
bm: org.apache.spark.mllib.evaluation.BinaryClassificationMetrics = org.apache.spark.mllib.evaluation.BinaryClassificationMetrics@254ed9ba
The BinaryClassificationMetrics instance contains many useful metrics for evaluating the performance of a classification model. We will look at the receiver operating characteristic (ROC) curve.
ROC Curves
Imagine gradually decreasing, from 1.0, the probability threshold at which we assume a particular e-mail is spam. Clearly, when the threshold is set to 1.0, no e-mails will get classified as spam. This means that there will be no false positives (ham messages which we incorrectly classify as spam), but it also means that there will be no true positives (spam messages that we correctly identify as spam): all spam e-mails will be incorrectly identified as ham.
As we gradually lower the probability threshold at which we assume a particular e-mail is spam, our spam filter will, hopefully, start identifying a large fraction of e-mails as spam. The vast majority of these will, if our algorithm is well designed, be real spam. Thus, our rate of true positives increases. As we lower the threshold further, we start classifying messages that we are less sure about as spam. This will increase the number of messages correctly identified as spam, but it will also increase the number of false positives.
The ROC curve plots, for each threshold value, the fraction of true positives against the fraction of false positives. In the best case, the true positive rate is 1 for any nonzero false positive rate: this happens when all spam messages are given a score of 1.0 and all ham messages a score of 0.0. By contrast, the worst case is the diagonal P(true positive) = P(false positive), which occurs when our algorithm does no better than random guessing. In general, ROC curves fall somewhere in between, forming a convex shell above the diagonal. The deeper this shell, the better our algorithm.
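To make the construction concrete, here is a minimal, self-contained sketch of how ROC points can be computed by hand. The (score, label) pairs below are invented purely for illustration; in the chapter itself, MLlib's .roc method does this work for us on the real test set:

```scala
// Hypothetical (score, label) pairs: score is the classifier's
// probability of spam; label is 1.0 for spam, 0.0 for ham.
val scoresLabels = List(
  (0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0),
  (0.55, 0.0), (0.4, 1.0), (0.3, 0.0), (0.2, 0.0)
)

val nPositives = scoresLabels.count { _._2 == 1.0 }
val nNegatives = scoresLabels.size - nPositives

// For a given threshold, classify as spam every message whose score
// is at least the threshold, and compute the fractions of false
// and true positives.
def rocPoint(threshold: Double): (Double, Double) = {
  val predictedPositive = scoresLabels.filter { _._1 >= threshold }
  val truePositives = predictedPositive.count { _._2 == 1.0 }
  val falsePositives = predictedPositive.size - truePositives
  (falsePositives.toDouble / nNegatives, truePositives.toDouble / nPositives)
}

// Sweep thresholds from high to low: the curve runs from (0, 0),
// where nothing is classified as spam, to (1, 1), where everything is.
val thresholds = scoresLabels.map { _._1 }.distinct.sorted.reverse
val roc = (1.1 :: thresholds).map(rocPoint)
```

Each pair in roc is a (false positive rate, true positive rate) point; lowering the threshold moves us monotonically up and to the right along the curve.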
Example ROC curves: (left) a model performing much better than random, where the curve reaches very high true positive rates for a low false positive rate; (middle) a model performing significantly better than random; (right) a model performing only marginally better than random, where the true positive rate is only marginally larger than the false positive rate for any given threshold, meaning that nearly half the examples are misclassified.
We can calculate an array of points on the ROC curve using the .roc method on our BinaryClassificationMetrics instance. This returns an RDD[(Double, Double)] of (false positive, true positive) fractions for each threshold value. We can collect this as an array:
scala> val rocArray = bm.roc.collect
rocArray: Array[(Double, Double)] = Array((0.0,0.0), (0.0,0.16793893129770993), ...
Of course, an array of numbers is not very enlightening, so let's plot the ROC curve with breeze-viz. We start by transforming our array of pairs into two arrays, one of false positives and one of true positives:
scala> val falsePositives = rocArray.map { _._1 }
falsePositives: Array[Double] = Array(0.0, 0.0, 0.0, 0.0, 0.0, ...

scala> val truePositives = rocArray.map { _._2 }
truePositives: Array[Double] = Array(0.0, 0.16793893129770993, 0.19083969465...
Let's plot these two arrays:
scala> import breeze.plot._
import breeze.plot._

scala> val f = Figure()
f: breeze.plot.Figure = breeze.plot.Figure@3aa746cd

scala> val p = f.subplot(0)
p: breeze.plot.Plot = breeze.plot.Plot@5ed1438a

scala> p += plot(falsePositives, truePositives)

scala> p.xlabel = "false positives"
p.xlabel: String = false positives

scala> p.ylabel = "true positives"
p.ylabel: String = true positives

scala> p.title = "ROC"
p.title: String = ROC

scala> f.refresh
The ROC curve hits 1.0 for a small value of x: that is, we retrieve all true positives at the cost of relatively few false positives. To visualize the curve more accurately, it is instructive to limit the range on the x-axis from 0 to 0.1.
scala> p.xlim = (0.0, 0.1)
p.xlim: (Double, Double) = (0.0,0.1)
We also need to tell breeze-viz to use appropriate tick spacing, which requires going down to the JFreeChart layer underlying breeze-viz:
scala> import org.jfree.chart.axis.NumberTickUnit
import org.jfree.chart.axis.NumberTickUnit

scala> p.xaxis.setTickUnit(new NumberTickUnit(0.01))

scala> p.yaxis.setTickUnit(new NumberTickUnit(0.1))
We can now save the graph:
scala> f.saveas("roc.png")
This produces the following graph, stored in roc.png:
By looking at the graph, we see that we can filter out 85% of spam without a single false positive. Of course, we would need a larger test set to really validate this assumption.
A graph is useful to really understand the behavior of a model. Sometimes, however, we just want to have a single measure of the quality of a model. The area under the ROC curve can be a good such metric:
scala> bm.areaUnderROC
res21: Double = 0.9983061235861147
This can be interpreted as follows: given any two messages randomly drawn from the test set, one of which is ham, and one of which is spam, there is a 99.8% probability that the model assigned a greater likelihood of spam to the spam message than to the ham message.
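This probabilistic interpretation is easy to verify directly on a small example. The sketch below (using invented scores, not the book's data) computes the fraction of (spam, ham) pairs in which the spam message receives the higher score, counting ties as half; this quantity equals the area under the ROC curve:

```scala
// Hypothetical (score, label) pairs, for illustration only.
val scoresLabels = List(
  (0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0),
  (0.55, 0.0), (0.4, 1.0), (0.3, 0.0), (0.2, 0.0)
)

// Separate the scores given to spam and to ham messages.
val spamScores = scoresLabels.collect { case (s, 1.0) => s }
val hamScores  = scoresLabels.collect { case (s, 0.0) => s }

// For every (spam, ham) pair, score 1.0 if the spam message got the
// higher score, 0.5 for a tie, and 0.0 otherwise.
val pairOutcomes = for (s <- spamScores; h <- hamScores) yield {
  if (s > h) 1.0 else if (s == h) 0.5 else 0.0
}

// The mean over all pairs is the area under the ROC curve.
val auc = pairOutcomes.sum / pairOutcomes.size
```

For the toy data above, thirteen of the sixteen (spam, ham) pairs are correctly ordered, giving an AUC of 0.8125.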
Other useful measures of model quality are the precision and recall for particular thresholds, and the F1 score. All of these are provided by the BinaryClassificationMetrics instance. The API documentation lists the methods available: https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.
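For instance, the precisionByThreshold, recallByThreshold, and fMeasureByThreshold methods return these quantities for each threshold value. The definitions themselves are simple; as a rough, self-contained illustration of what they compute at a single threshold, using invented (score, label) pairs rather than the book's data:

```scala
// Hypothetical (score, label) pairs, for illustration only.
val scoresLabels = List(
  (0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0),
  (0.55, 0.0), (0.4, 1.0), (0.3, 0.0), (0.2, 0.0)
)

// Precision: fraction of messages classified as spam that really are spam.
// Recall: fraction of real spam messages that we classified as spam.
// F1: harmonic mean of precision and recall.
def metricsAt(threshold: Double): (Double, Double, Double) = {
  val predictedPositive = scoresLabels.filter { _._1 >= threshold }
  val truePositives = predictedPositive.count { _._2 == 1.0 }
  val actualPositives = scoresLabels.count { _._2 == 1.0 }
  val precision = truePositives.toDouble / predictedPositive.size
  val recall = truePositives.toDouble / actualPositives
  val f1 = 2 * precision * recall / (precision + recall)
  (precision, recall, f1)
}

val (precision, recall, f1) = metricsAt(0.5)
```

At a threshold of 0.5 on this toy data, five messages are classified as spam, of which three really are, giving a precision of 0.6 and a recall of 0.75.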