In this section, we will learn how to have several plots in the same figure.
The key new method that allows multiple plots in the same figure is fig.subplot(nrows, ncols, plotIndex). This method, an overloaded version of the fig.subplot method we have been using up to now, both sets the number of rows and columns in the figure and returns a specific subplot. It takes three arguments:
- nrows: The number of rows of subplots in the figure
- ncols: The number of columns of subplots in the figure
- plotIndex: The index of the plot to return

Users familiar with MATLAB or matplotlib will note that the .subplot method closely mirrors the eponymous methods in these frameworks, with one difference: Breeze numbers plots from 0, whereas MATLAB and matplotlib start from 1. This might seem a little complex, so let's look at an example (you will find the code for this in BreezeDemo.scala):
```scala
import breeze.plot._

def subplotExample(): Unit = {
  val data = HWData.load
  val fig = new Figure("Subplot example")

  // upper subplot: plot index '0' refers to the first plot
  var plt = fig.subplot(2, 1, 0)
  plt += plot(data.heights, data.weights, '.')

  // lower subplot: plot index '1' refers to the second plot
  plt = fig.subplot(2, 1, 1)
  plt += plot(data.heights, data.reportedHeights, '.', colorcode = "black")

  fig.refresh
}
```
Running this example produces the following plot:
Now that we have a basic grasp of how to add several subplots to the same figure, let's do something a little more interesting. We will write a class to draw scatterplot matrices. These are useful for exploring correlations between different features.
If you are not familiar with scatterplot matrices, have a look at the figure at the end of this section for an idea of what we are constructing. The idea is to build a square matrix containing a scatter plot for each pair of features: element (i, j) in the matrix is a scatter plot of feature i against feature j. Since a scatter plot of a variable against itself is of limited use, one normally draws a histogram of each feature along the diagonal. Finally, since a scatter plot of feature i against feature j contains the same information as a scatter plot of feature j against feature i, one normally plots only the upper or the lower triangle of the matrix.
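The layout rule can be stated as a small function of the cell position. The sketch below is a hypothetical helper (the names MatrixLayout, cellKind and drawnCells are ours, not Breeze's) that classifies each cell (i, j) of an n x n matrix, keeping the lower triangle, which is the triangle the ScatterplotMatrix class in this section draws. For n features this gives n histograms and n(n-1)/2 scatter plots.

```scala
// Hypothetical helper, not part of Breeze or of the ScatterplotMatrix class:
// classifies each cell (i, j) of an n x n scatterplot matrix, keeping
// histograms on the diagonal and scatter plots in the lower triangle.
object MatrixLayout {

  sealed trait Cell
  case object Histogram extends Cell // feature i on its own
  case object Scatter extends Cell   // feature i against feature j
  case object Empty extends Cell     // redundant upper-triangle cell

  def cellKind(i: Int, j: Int): Cell =
    if (i == j) Histogram
    else if (j < i) Scatter
    else Empty

  /** All non-empty cells, visited row by row. */
  def drawnCells(n: Int): Seq[(Int, Int, Cell)] =
    for {
      i <- 0 until n
      j <- 0 until n
      kind = cellKind(i, j)
      if kind != Empty
    } yield (i, j, kind)

  def main(args: Array[String]): Unit = {
    val n = 3
    for (i <- 0 until n) {
      val row = (0 until n).map { j =>
        cellKind(i, j) match {
          case Histogram => "H"
          case Scatter   => "S"
          case Empty     => "."
        }
      }
      println(row.mkString(" "))
    }
  }
}
```

For three features, main prints the layout with H for histograms, S for scatter plots and . for empty cells:

```
H . .
S H .
S S H
```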
Let's start by writing functions for the individual plots. These will take a Plot object referencing the correct subplot and vectors of the data to plot:
```scala
import breeze.plot._
import breeze.linalg._

class ScatterplotMatrix(val fig: Figure) {

  /** Draw the histograms on the diagonal */
  private def plotHistogram(plt: Plot)(
      data: DenseVector[Double], label: String): Unit = {
    plt += hist(data)
    plt.xlabel = label
  }

  /** Draw the off-diagonal scatter plots */
  private def plotScatter(plt: Plot)(
      xdata: DenseVector[Double],
      ydata: DenseVector[Double],
      xlabel: String,
      ylabel: String): Unit = {
    plt += plot(xdata, ydata, '.')
    plt.xlabel = xlabel
    plt.ylabel = ylabel
  }

  ...
```
Notice the use of hist(data) to draw a histogram. The argument to hist must be a vector of data points; the hist method bins these values and represents them as a histogram.
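To make "binning" concrete, here is a sketch of equal-width binning in plain Scala. It illustrates the general idea only and is not Breeze's implementation; the helper binCounts and the sample heights are made up for this example.

```scala
// Illustrative equal-width binning, in the spirit of what a histogram
// function does. This is NOT Breeze's implementation; `binCounts` is a
// hypothetical helper written for this example.
object Binning {

  /** Counts how many values fall into each of `nbins` equal-width bins
    * spanning [min, max]. The maximum value is placed in the last bin.
    */
  def binCounts(data: Seq[Double], nbins: Int): Vector[Int] = {
    require(data.nonEmpty, "data must not be empty")
    require(nbins > 0, "nbins must be positive")
    val lo = data.min
    val hi = data.max
    val width = (hi - lo) / nbins
    val counts = Array.fill(nbins)(0)
    data.foreach { x =>
      // If all values are equal (width == 0), put everything in bin 0.
      val raw = if (width == 0.0) 0 else ((x - lo) / width).toInt
      val bin = math.min(raw, nbins - 1) // clamp x == max into the last bin
      counts(bin) += 1
    }
    counts.toVector
  }

  def main(args: Array[String]): Unit = {
    val heights = Seq(160.0, 165.0, 170.0, 172.0, 175.0, 180.0)
    println(binCounts(heights, 4)) // Vector(1, 1, 2, 2)
  }
}
```

With four bins of width 5 over [160, 180], the values 170 and 172 share the third bin, while 175 and the maximum, 180, share the last one.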
Now that we have the machinery for drawing individual plots, we just need to wire everything together. The tricky part is knowing how to select the correct subplot for a given row and column position in the matrix. We can select a single plot by calling fig.subplot(nrows, ncols, plotIndex), but translating from a (row, column) index pair to a single plotIndex is not obvious. The plots are numbered in increasing order, first from left to right, then from top to bottom (shown here for a figure with four columns):

```
0 1 2 3
4 5 6 7
...
```
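Because the numbering is row-major, the translation is a single multiply-and-add, and it can be inverted with integer division and remainder. The sketch below shows both directions; toPlotIndex and fromPlotIndex are hypothetical helper names used only for illustration, and they involve no Breeze calls.

```scala
// Row-major mapping between (row, column) positions and the single
// plotIndex expected by fig.subplot(nrows, ncols, plotIndex).
// `toPlotIndex` and `fromPlotIndex` are hypothetical helper names.
object GridIndex {

  /** (row, column) -> plotIndex: skip `irow` full rows, then move `icol` along. */
  def toPlotIndex(ncols: Int)(irow: Int, icol: Int): Int =
    irow * ncols + icol

  /** plotIndex -> (row, column): integer division gives the row, remainder the column. */
  def fromPlotIndex(ncols: Int)(plotIndex: Int): (Int, Int) =
    (plotIndex / ncols, plotIndex % ncols)

  def main(args: Array[String]): Unit = {
    val ncols = 4
    for (irow <- 0 until 2) {
      val indices = (0 until ncols).map(icol => toPlotIndex(ncols)(irow, icol))
      println(indices.mkString(" "))
    }
  }
}
```

With four columns, main prints the first two rows of the numbering, "0 1 2 3" and "4 5 6 7", and fromPlotIndex(4)(6) recovers the position (1, 2).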
Let's write a short function to select a plot at a (row, column) index pair:
```scala
  private def selectPlot(ncols: Int)(irow: Int, icol: Int): Plot = {
    fig.subplot(ncols, ncols, irow * ncols + icol)
  }
```
We are now in a position to draw the matrix plot itself:
```scala
  /** Draw a scatterplot matrix.
    *
    * This function draws a scatterplot matrix of the correlation
    * between each pair of columns in `featureMatrix`.
    *
    * @param featureMatrix A matrix of features, with each column
    *   representing a feature.
    * @param labels Names of the features.
    */
  def plotFeatures(featureMatrix: DenseMatrix[Double], labels: List[String]): Unit = {
    val ncols = featureMatrix.cols
    require(ncols == labels.size,
      "Number of columns in feature matrix " +
      "must match length of labels")
    fig.clear
    // set up an ncols x ncols grid of subplots
    fig.subplot(ncols, ncols, 0)

    (0 until ncols) foreach { irow =>
      // histogram of feature irow on the diagonal
      val p = selectPlot(ncols)(irow, irow)
      plotHistogram(p)(featureMatrix(::, irow), labels(irow))

      // scatter plots in the lower triangle (icol < irow)
      (0 until irow) foreach { icol =>
        val p = selectPlot(ncols)(irow, icol)
        plotScatter(p)(
          featureMatrix(::, irow),
          featureMatrix(::, icol),
          labels(irow),
          labels(icol)
        )
      }
    }
  }
}
```
Let's write an example for our class. We will use the height-weight data again:
```scala
import breeze.linalg._
import breeze.numerics._
import breeze.plot._

object ScatterplotMatrixDemo extends App {

  val data = HWData.load
  val m = new ScatterplotMatrix(Figure("Scatterplot matrix demo"))

  // Make a matrix with three columns: the height, weight and
  // reported weight data. Each vector becomes a 1 x n row matrix
  // via toDenseMatrix, and .t turns it into an n x 1 column.
  val featureMatrix = DenseMatrix.horzcat(
    data.heights.toDenseMatrix.t,
    data.weights.toDenseMatrix.t,
    data.reportedWeights.toDenseMatrix.t
  )

  m.plotFeatures(featureMatrix, List("height", "weight", "reportedWeights"))
}
```