LSTM using the iris dataset

Continuing with the LSTM architecture for RNNs introduced in Chapter 6, Recurrent and Convolutional Neural Networks, we now process the iris dataset with the mxnet package. Its training functions expect all inputs and outputs to be numeric. LSTMs are particularly useful for processing text sequences, but here we will train a model on the iris dataset. The input values are Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width. The output variable is Species, which is converted to a numeric value between one and three. The iris dataset was described in detail in Chapter 4, Perceptron Neural Network Modeling – Basic Models:

#################################################################
### Chapter 7 - Neural Networks with R - Use cases #########
### Prediction using LSTM on IRIS dataset #########
#################################################################

## Load the mxnet package (installation is required only once)
library("mxnet")

data(iris)

x = iris[1:5!=5,-5]
y = as.integer(iris$Species)[1:5!=5]

train.x = data.matrix(x)
train.y = y

test.x = data.matrix(iris[1:5==5,-5])
test.y = as.integer(iris$Species)[1:5==5]

model <- mx.mlp(train.x, train.y, hidden_node=10, out_node=3, out_activation="softmax",
                num.round=20, array.batch.size=15, learning.rate=0.07, momentum=0.9,
                eval.metric=mx.metric.accuracy)

preds = predict(model, test.x)
pred.label = max.col(t(preds))

test.y
pred.label
#################################################################

The program requires the mxnet package, which must be installed first. mxnet for R is available for both CPUs and GPUs on Linux, macOS, and Windows.

We describe the installation procedure only for the CPU version on Windows machines. For other architectures, refer to the installation instructions at the following URL: https://mxnet.incubator.apache.org/get_started/install.html.

To install the CPU version of mxnet, we use the prebuilt binary package. We can install it directly from the R console with the following code:

cran <- getOption("repos")
cran["dmlc"] <- "https://s3-us-west-2.amazonaws.com/apache-mxnet/R/CRAN/"
options(repos = cran)
install.packages("mxnet")

The following packages are installed:

package ‘bindr’ successfully unpacked and MD5 sums checked
package ‘brew’ successfully unpacked and MD5 sums checked
package ‘assertthat’ successfully unpacked and MD5 sums checked
package ‘bindrcpp’ successfully unpacked and MD5 sums checked
package ‘glue’ successfully unpacked and MD5 sums checked
package ‘pkgconfig’ successfully unpacked and MD5 sums checked
package ‘BH’ successfully unpacked and MD5 sums checked
package ‘plogr’ successfully unpacked and MD5 sums checked
package ‘yaml’ successfully unpacked and MD5 sums checked
package ‘irlba’ successfully unpacked and MD5 sums checked
package ‘hms’ successfully unpacked and MD5 sums checked
package ‘XML’ successfully unpacked and MD5 sums checked
package ‘Rook’ successfully unpacked and MD5 sums checked
package ‘tidyselect’ successfully unpacked and MD5 sums checked
package ‘gridExtra’ successfully unpacked and MD5 sums checked
package ‘dplyr’ successfully unpacked and MD5 sums checked
package ‘downloader’ successfully unpacked and MD5 sums checked
package ‘htmltools’ successfully unpacked and MD5 sums checked
package ‘htmlwidgets’ successfully unpacked and MD5 sums checked
package ‘igraph’ successfully unpacked and MD5 sums checked
package ‘influenceR’ successfully unpacked and MD5 sums checked
package ‘purrr’ successfully unpacked and MD5 sums checked
package ‘readr’ successfully unpacked and MD5 sums checked
package ‘rstudioapi’ successfully unpacked and MD5 sums checked
package ‘rgexf’ successfully unpacked and MD5 sums checked
package ‘tidyr’ successfully unpacked and MD5 sums checked
package ‘viridis’ successfully unpacked and MD5 sums checked
package ‘DiagrammeR’ successfully unpacked and MD5 sums checked
package ‘visNetwork’ successfully unpacked and MD5 sums checked
package ‘data.table’ successfully unpacked and MD5 sums checked
package ‘mxnet’ successfully unpacked and MD5 sums checked

As you can see, installing the mxnet package also installs several packages it depends on, so we now have everything we need to proceed. We load the library, which provides the model-training functions used in the rest of this example:

library("mxnet")

In the following code, the built-in iris dataset is loaded, and x and y are set to the independent variables and the target variable, respectively. The Species variable is converted to an integer between one and three:

data(iris)
x = iris[1:5!=5,-5]
y = as.integer(iris$Species)[1:5!=5]

To explain the row selection, consider the following line:

x = iris[1:5!=5,-5]

Here, 1:5!=5 produces the logical vector TRUE, TRUE, TRUE, TRUE, FALSE. Because the iris dataset has 150 rows and five columns, R recycles this five-element vector across all the rows, so rows one to four of every group of five are selected and every fifth row (5, 10, 15, and so on) is omitted. The -5 index also drops the fifth column, Species. The result is 120 rows and four columns.
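The row selection described above can be checked with a few lines of base R. This is a minimal sketch; the names mask and x_check are introduced here only for illustration and do not appear in the original listing.

```r
# Logical mask: TRUE for positions 1 to 4, FALSE for position 5
mask <- 1:5 != 5
print(mask)

data(iris)

# R recycles the five-element mask across all 150 rows,
# and -5 drops the fifth column (Species)
x_check <- iris[mask, -5]
print(dim(x_check))      # 120 rows, 4 columns
print(names(x_check))    # the four measurement columns
```

Writing the mask out explicitly makes it easier to see that the split is deterministic: the same 120 rows are selected on every run.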

We now set the input and output:

train.x = data.matrix(x)
train.y = y

Then we build the test set by selecting only the rows we previously omitted:

test.x = data.matrix(iris[1:5==5,-5])
test.y = as.integer(iris$Species)[1:5==5]
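As a quick sanity check on this split, the sketch below confirms that the training and test rows do not overlap and that each species contributes the same number of test rows. It uses only base R; train_idx and test_idx are illustrative names, not part of the original listing.

```r
data(iris)

# Row numbers selected by the two recycled masks
train_idx <- which(rep(1:5 != 5, length.out = nrow(iris)))
test_idx  <- which(rep(1:5 == 5, length.out = nrow(iris)))

# The two sets are disjoint and together cover all 150 rows
print(length(intersect(train_idx, test_idx)))   # 0
print(length(train_idx) + length(test_idx))     # 150

# Each species contributes 10 rows to the test set
test_counts <- table(iris$Species[test_idx])
print(test_counts)
```

Because the iris rows are ordered by species in blocks of 50, taking every fifth row yields a test set that is balanced across the three classes.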

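The model is then trained on the input and output values. Note that, despite the section title, the code calls mx.mlp, which builds a feed-forward multilayer perceptron (here with one hidden layer of 10 nodes and a three-node softmax output), not an LSTM: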

model <- mx.mlp(train.x, train.y, hidden_node=10, out_node=3, out_activation="softmax",
                num.round=20, array.batch.size=15, learning.rate=0.07, momentum=0.9,
                eval.metric=mx.metric.accuracy)

Now we can make predictions:

preds = predict(model, test.x)
pred.label = max.col(t(preds))
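To see what max.col(t(preds)) does, the sketch below applies it to a small, made-up probability matrix. It assumes, as the use of t() in the code above suggests, that predict returns one column per sample and one row per class. The numbers in preds_demo are hypothetical and are not output from the model.

```r
# Hypothetical class probabilities: 3 classes (rows) x 4 samples (columns).
# matrix() fills column by column, so each line below is one sample.
preds_demo <- matrix(c(0.8, 0.1, 0.1,
                       0.2, 0.7, 0.1,
                       0.1, 0.2, 0.7,
                       0.6, 0.3, 0.1), nrow = 3)

# Transpose so that each row is a sample, then take the column index
# of the largest value in each row (a one-based class index)
pred_demo <- max.col(t(preds_demo))
print(pred_demo)   # 1 2 3 1
```

Note that the returned indices start at one, which matters when they are compared with labels that use a different numbering.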

Finally, we print the actual and predicted labels to evaluate the model's performance:

test.y
pred.label

Here are the results:

> test.y
[1] 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3
> pred.label
[1] 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3

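Comparing the test labels with the predictions shows that they agree only for the third class, virginica; every setosa sample (1) is predicted as class 2 and every versicolor sample (2) as class 3. This uniform shift by one suggests a label-encoding problem rather than a failure to separate the species: mxnet's softmax output expects class labels that start at zero, whereas the labels here run from one to three, and max.col returns one-based indices. In any case, the model as written needs to be improved, because its predictions are not at the level of those obtained with the models in the previous examples.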
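The comparison can be quantified directly from the two vectors printed above. The sketch below rebuilds them by hand (so it runs without mxnet), computes the accuracy, and tabulates actual against predicted labels. The last part shows a zero-based label encoding as one possible remedy; that it would fix the result is an assumption that has not been verified here, and y0 is an illustrative name.

```r
# The two vectors exactly as printed in the results above
test.y     <- rep(1:3, each = 10)
pred.label <- c(rep(2, 10), rep(3, 20))

# Fraction of test samples whose predicted label equals the actual label
acc <- mean(test.y == pred.label)
print(acc)                                   # 10 of 30 correct

# Confusion table: rows are actual classes, columns are predicted classes
print(table(actual = test.y, predicted = pred.label))

# If the predictions are shifted down by one, classes 1 and 2 line up,
# which is consistent with an off-by-one label encoding
acc_shifted <- mean(test.y == pred.label - 1)
print(acc_shifted)                           # 20 of 30 correct

# Assumed remedy (not verified here): encode the labels from zero
data(iris)
y0 <- as.integer(iris$Species) - 1L
print(range(y0))                             # 0 2
```

With zero-based training labels, predictions obtained from max.col would also need one subtracted before being compared with the labels.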
