- Start the Spark shell:
$ spark-shell
- Perform the required imports:
scala> import org.apache.spark.ml.classification.NaiveBayes
scala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- Load the data from S3 into a DataFrame:
scala> val data = spark.read.format("libsvm").load("s3a://sparkcookbook/patientdata")
- Split the data into training and test datasets:
scala> val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
- Train the model with the training dataset:
scala> val model = new NaiveBayes().fit(trainingData)
- Run predictions on the test dataset:
scala> val predictions = model.transform(testData)
- Evaluate the accuracy:
scala> val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
scala> val accuracy = evaluator.evaluate(predictions)
Here, the accuracy is only 55 percent, which suggests that Naive Bayes is not the best algorithm for this dataset. Because randomSplit is called without a seed, the exact figure will vary from run to run.
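To make the reported number concrete: the "accuracy" metric is the fraction of test rows whose predicted label equals the true label. The following is a minimal plain-Scala sketch of that calculation, not the evaluator's actual implementation. The helper name accuracyOf and the sample pairs are hypothetical and are not taken from the patient dataset.

```scala
// Hypothetical helper: fraction of (prediction, label) pairs that match.
// Returns 0.0 for an empty input to avoid dividing by zero.
def accuracyOf(pairs: Seq[(Double, Double)]): Double =
  if (pairs.isEmpty) 0.0
  else pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size

// Made-up (prediction, label) pairs for illustration only.
val sample = Seq((1.0, 1.0), (0.0, 1.0), (1.0, 1.0), (0.0, 0.0))

// Three of the four pairs match.
val sampleAccuracy = accuracyOf(sample)
```

In the recipe itself, the same ratio is computed over the prediction and label columns of the predictions DataFrame, so an accuracy of 0.55 means that just over half of the test rows were classified correctly.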