LSTM for image processing

Let's imagine we want to perform handwriting recognition. At each time step, we receive a new column of pixel data. Is it the end of a letter? If so, which one? Is it the end of a word? Is it punctuation? A recurrent network can answer all these questions, because it keeps a state that summarizes the columns it has already seen.

For our test example, we will go back to our handwritten-digit dataset (MNIST, 10 classes) and use LSTMs instead of convolution layers, feeding each 28×28 image to the network one row of pixels at a time.
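To make the idea of an image as a sequence concrete, here is a minimal NumPy sketch (not part of the original example) showing how a flattened 784-pixel image becomes 28 time steps of 28 pixels each. The helper `image_to_sequence` is hypothetical.

```python
import numpy as np

def image_to_sequence(flat_image, time_steps=28, n_input=28):
    """Reshape a flattened image into time_steps rows of n_input pixels each."""
    return np.asarray(flat_image).reshape(time_steps, n_input)

# A toy 28x28 "image" whose pixels are numbered 0..783, so each row is easy to recognise
flat = np.arange(28 * 28)
seq = image_to_sequence(flat)

# The recurrent network reads the image one row per time step, from top to bottom
for t in range(3):
    print("time step", t, "first pixels:", seq[t, :4])
```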

We use similar hyperparameters:

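import tensorflow as tf
from tensorflow.contrib import rnn  # TensorFlow 1.x only; tf.contrib was removed in 2.0

n_input = 28           # rows of 28 pixels
time_steps = 28        # unrolled through 28 time steps (our images are (28,28))
num_units = 128        # hidden LSTM units (assumed; the original value is not shown)
learning_rate = 0.001  # learning rate for adam (assumed: Adam's default)
n_classes = 10         # digits 0-9
batch_size = 128       # inferred from the training log below

n_epochs = 10
step = 100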
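To get a feel for what these hyperparameters mean in terms of model size, the trainable parameters can be counted by hand. This sketch assumes the `BasicLSTMCell` layout (one kernel applied to the concatenated input and previous hidden state, producing all four gates, plus one bias vector) and assumes `num_units = 128`; the helper names are hypothetical.

```python
def lstm_param_count(n_input, num_units):
    # Kernel of shape (n_input + num_units, 4 * num_units) plus a bias of 4 * num_units
    return (n_input + num_units) * 4 * num_units + 4 * num_units

def dense_param_count(n_in, n_out):
    # Weight matrix (n_in, n_out) plus one bias per output unit
    return n_in * n_out + n_out

n_input, num_units, n_classes = 28, 128, 10  # num_units = 128 is an assumption
lstm_params = lstm_param_count(n_input, num_units)
dense_params = dense_param_count(num_units, n_classes)
print(lstm_params, dense_params, lstm_params + dense_params)
```

Most of the parameters sit in the recurrent kernel, whose size grows quadratically with `num_units` but only linearly with `n_input`.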

Setting up the training and testing data is almost the same as in our CNN example, except for the way we reshape the images: each image becomes a sequence of time_steps rows of n_input pixels.

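import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# The original used fetch_mldata('MNIST original'); mldata.org is gone and
# fetch_mldata was removed from scikit-learn, so we load MNIST from OpenML.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# one row of pixels per time step, scaled to [0, 1]
mnist.data = mnist.data.astype(np.float32).reshape(
    [-1, time_steps, n_input]) / 255.
mnist.num_examples = len(mnist.data)
mnist.labels = mnist.target.astype(np.int64)

X_train, X_test, y_train, y_test = train_test_split(
    mnist.data, mnist.labels, test_size=(1. / 7.))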
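The reshaping and the 6:1 train/test split can be checked without downloading MNIST. This NumPy sketch uses random stand-in data (700 fake images instead of 70,000) and a hand-written shuffle-and-slice split in place of `train_test_split`; both the data and the helper `split_train_test` are hypothetical.

```python
import numpy as np

time_steps, n_input = 28, 28
rng = np.random.default_rng(0)

# Stand-in for MNIST: 700 "images" of 784 pixels with values 0..255, plus labels 0..9
n_samples = 700
data = rng.integers(0, 256, size=(n_samples, time_steps * n_input)).astype(np.float32)
labels = rng.integers(0, 10, size=n_samples).astype(np.int64)

# Same preprocessing as the listing: one row of pixels per time step, scaled to [0, 1]
data = data.reshape([-1, time_steps, n_input]) / 255.

def split_train_test(X, y, test_fraction, rng):
    """Shuffle, then hold out a fraction of the samples for testing (hypothetical helper)."""
    perm = rng.permutation(len(X))
    n_test = int(round(len(X) * test_fraction))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

X_train, X_test, y_train, y_test = split_train_test(data, labels, 1. / 7., rng)
print(X_train.shape, X_test.shape)
```

With the full 70,000 MNIST images, the same 1/7 fraction gives 60,000 training and 10,000 test images.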

Let's quickly set up our network and its scaffolding:

x = tf.placeholder(tf.float32, [None, time_steps, n_input])
y = tf.placeholder(tf.int64, [None])

# processing the input tensor from [batch_size, time_steps, n_input]
# to "time_steps" number of [batch_size, n_input] tensors
inputs = tf.unstack(x, time_steps, 1)

# forget_bias is a float added to the forget gate (1.0 is the default)
lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1.0)
outputs, _ = rnn.static_rnn(lstm_layer, inputs, dtype=tf.float32)

# classify from the LSTM output at the last time step
prediction = tf.layers.dense(inputs=outputs[-1], units=n_classes)

loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=prediction, labels=y))
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(prediction, 1), y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
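For readers who want to see what `BasicLSTMCell`, `static_rnn` and the dense layer compute, here is a minimal NumPy sketch of the same forward pass with random, untrained weights. It assumes the `BasicLSTMCell` formulation as we understand it: one kernel applied to the concatenated input and previous hidden state, split into input, candidate, forget and output gates (in that order), with `forget_bias` added to the forget gate. The loss and accuracy mirror `sparse_softmax_cross_entropy_with_logits` and the `argmax` comparison. All function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, kernel, bias, num_units, forget_bias=1.0):
    """Run an LSTM over x of shape [batch, time_steps, n_input]; return the last output."""
    batch = x.shape[0]
    h = np.zeros((batch, num_units))  # hidden state (the cell output)
    c = np.zeros((batch, num_units))  # cell state
    for t in range(x.shape[1]):       # one step per image row, like unstack + static_rnn
        gates = np.concatenate([x[:, t, :], h], axis=1) @ kernel + bias
        i, j, f, o = np.split(gates, 4, axis=1)
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
    return h                          # corresponds to outputs[-1]

def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

rng = np.random.default_rng(0)
batch_size, time_steps, n_input, num_units, n_classes = 4, 28, 28, 128, 10
x = rng.random((batch_size, time_steps, n_input))
y = rng.integers(0, n_classes, size=batch_size)

kernel = rng.normal(scale=0.1, size=(n_input + num_units, 4 * num_units))
bias = np.zeros(4 * num_units)
w_out = rng.normal(scale=0.1, size=(num_units, n_classes))
b_out = np.zeros(n_classes)

last_output = lstm_forward(x, kernel, bias, num_units)
prediction = last_output @ w_out + b_out  # the dense layer's logits
loss = sparse_softmax_cross_entropy(prediction, y)
accuracy = np.mean(np.argmax(prediction, axis=1) == y)
print(prediction.shape, loss, accuracy)
```

With small random weights the logits are close to zero, so the loss starts near ln(10) ≈ 2.3, which is in line with the first loss printed during training.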

We are now ready to train:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        permut = np.random.permutation(len(X_train))
        print("epoch: %i" % epoch)
        for j in range(0, len(X_train), batch_size):
            if j % step == 0:
                print("  batch: %i" % j)

            # one optimization step on a shuffled mini-batch
            batch = permut[j:j+batch_size]
            Xs = X_train[batch]
            Ys = y_train[batch]
            sess.run(opt, feed_dict={x: Xs, y: Ys})

            if j % step == 0:
                # accuracy and loss on the batch just used for training
                acc, los = sess.run([accuracy, loss], feed_dict={x: Xs, y: Ys})
                print("  accuracy %f" % acc)
                print("  loss %f" % los)

Training prints progress like the following (many intermediate lines omitted):

epoch: 0
  batch: 0
  accuracy 0.195312
  loss 2.275624

  batch: 3200
  accuracy 0.484375
  loss 1.514501

  ...

  batch: 54400
  accuracy 0.992188
  loss 0.022468

  batch: 57600
  accuracy 1.000000
  loss 0.007411
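The batch numbers in this log follow from the training loop: `j` advances in steps of `batch_size` and progress is printed only when `j % step == 0`, so only offsets that are multiples of both 128 and 100 appear. The check below assumes `batch_size = 128` and 60,000 training images; the batch size is inferred, since the printed accuracies are all multiples of 1/128 (for example 0.195312 is 25/128).

```python
import math

batch_size, step, n_train = 128, 100, 60000  # batch_size = 128 is inferred from the log

# Start offsets j of the mini-batches, as in range(0, len(X_train), batch_size)
offsets = range(0, n_train, batch_size)
reported = [j for j in offsets if j % step == 0]

# j must be a multiple of both 128 and 100, i.e. of their least common multiple
period = batch_size * step // math.gcd(batch_size, step)
print(period, reported[:3], reported[-2:])

# Batch accuracies are k / 128 for an integer k; "%f" reproduces the logged values
print(["%f" % (k / batch_size) for k in (25, 62, 127)])
```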

The accuracy is high here as well, but note that it is measured on the training batch that was just used for the update; we leave it to the reader to check the accuracy on the test samples.
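One way to run that check is to evaluate the test set in mini-batches and weight each batch by its size, because the last batch is usually smaller than the others. This sketch is framework-agnostic: `batch_accuracy` is any function returning the fraction of correct predictions on one batch, and the toy "model" `always_zero_accuracy` is hypothetical.

```python
import numpy as np

def evaluate_accuracy(batch_accuracy, X, y, batch_size=128):
    """Accuracy over a whole dataset, computed one mini-batch at a time.

    batch_accuracy(Xs, Ys) must return the fraction of correct predictions on one batch.
    Each batch is weighted by its size, so a smaller final batch is not over-counted.
    """
    n_correct = 0.0
    for start in range(0, len(X), batch_size):
        Xs, Ys = X[start:start + batch_size], y[start:start + batch_size]
        n_correct += batch_accuracy(Xs, Ys) * len(Xs)
    return n_correct / len(X)

def always_zero_accuracy(Xs, Ys):
    """Hypothetical stand-in model that predicts class 0 for every image."""
    return np.mean(Ys == 0)

# Toy data: 300 blank images, the first 200 labelled 0 and the rest labelled 1
X_demo = np.zeros((300, 28, 28))
y_demo = np.array([0] * 200 + [1] * 100)
test_acc = evaluate_accuracy(always_zero_accuracy, X_demo, y_demo)
print(test_acc)
```

Here a naive average of the three per-batch accuracies would give about 0.52 instead of 2/3. With the TensorFlow graph above, inside the `with tf.Session() as sess:` block, `batch_accuracy` could be `lambda Xs, Ys: sess.run(accuracy, feed_dict={x: Xs, y: Ys})`, applied to `X_test` and `y_test`.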

