- Start the Spark shell:
$ spark-shell
- Perform the required imports:
scala> import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
scala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- Load and parse the data:
scala> val data = spark.read.format("libsvm").load("s3a://sparkcookbook/patientdata")
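The libsvm data source produces a DataFrame with a numeric label column and a features vector column. As an optional sanity check before splitting (not part of the original recipe), you can print the schema and a few rows:
scala> data.printSchema()
scala> data.show(5)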
- Split the data into training and test datasets:
scala> val Array(training, test) = data.randomSplit(Array(0.7, 0.3))
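Note that randomSplit picks a new random seed on each run, so the exact 70/30 split differs between sessions. If you need a reproducible split, you can pass a seed explicitly (42 below is an arbitrary choice):
scala> val Array(training, test) = data.randomSplit(Array(0.7, 0.3), seed = 42)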
- Create a gradient-boosted tree classifier and set the maximum number of iterations to 10:
scala> val gbt = new GBTClassifier().setMaxIter(10)
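GBTClassifier exposes several other hyperparameters beyond the iteration count. As a sketch, the variant below also sets the maximum tree depth and the learning rate explicitly; the values used here (a depth of 5 and a step size of 0.1) match Spark's documented defaults, so it should behave the same as the line above:
scala> val gbt = new GBTClassifier().setMaxIter(10).setMaxDepth(5).setStepSize(0.1)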
- Train the model:
scala> val model = gbt.fit(training)
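If you are curious about what was learned, the fitted GBTClassificationModel can be inspected before evaluating it. These optional lines show the number of trees in the ensemble, the relative feature importances (available in Spark 2.x), and a full textual dump of the trees:
scala> model.trees.length
scala> model.featureImportances
scala> println(model.toDebugString)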
- Evaluate the model on the test instances and compute the test error:
scala> val predictions = model.transform(test)
scala> val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
scala> val accuracy = evaluator.evaluate(predictions)
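To see the result, print the accuracy along with the derived test error, and optionally compare a few predictions against the true labels:
scala> println(s"Accuracy = $accuracy, test error = ${1.0 - accuracy}")
scala> predictions.select("prediction", "label").show(5)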
In this case, the accuracy of the model is 75 percent, which is almost the same as what we got with a random forest.