Define the LSTM cell

Now we define a function called LSTM_cell, which implements the forward-propagation steps of the LSTM that we saw earlier. It takes the current input, the previous hidden state, and the previous cell state as inputs, and returns the current cell state and the current hidden state as outputs:

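```python
import tensorflow as tf

# U_*, W_* and b_* are the input-to-hidden weights, hidden-to-hidden weights
# and biases of each gate; they must be defined before LSTM_cell is called
def LSTM_cell(input, prev_hidden_state, prev_cell_state):
    # input gate: how much of the candidate values to write
    it = tf.sigmoid(tf.matmul(input, U_i) + tf.matmul(prev_hidden_state, W_i) + b_i)
    # forget gate: how much of the previous cell state to keep
    ft = tf.sigmoid(tf.matmul(input, U_f) + tf.matmul(prev_hidden_state, W_f) + b_f)
    # output gate: how much of the cell state to expose as the hidden state
    ot = tf.sigmoid(tf.matmul(input, U_o) + tf.matmul(prev_hidden_state, W_o) + b_o)
    # candidate values to add to the cell state
    gt = tf.tanh(tf.matmul(input, U_g) + tf.matmul(prev_hidden_state, W_g) + b_g)
    # current cell state: keep part of the old state, add part of the candidate
    ct = (prev_cell_state * ft) + (it * gt)
    # current hidden state
    ht = ot * tf.tanh(ct)
    return ct, ht
```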
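To check the arithmetic of these equations without building a TensorFlow graph, the same cell can be sketched in NumPy. This is an illustration under assumptions, not the book's code: the function name `lstm_cell_np`, the `params` dictionary layout, the layer sizes and the random initialization are all hypothetical; only the gate equations mirror LSTM_cell above.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell_np(x, h_prev, c_prev, params):
    """One LSTM step using the same equations as LSTM_cell (hypothetical helper).

    params maps a gate name ("i", "f", "o", "g") to a tuple (U, W, b) with
    U: (input_dim, hidden), W: (hidden, hidden), b: (hidden,).
    """
    def gate(name, activation):
        U, W, b = params[name]
        return activation(x @ U + h_prev @ W + b)

    i = gate("i", sigmoid)  # input gate
    f = gate("f", sigmoid)  # forget gate
    o = gate("o", sigmoid)  # output gate
    g = gate("g", np.tanh)  # candidate values
    c = c_prev * f + i * g  # current cell state
    h = o * np.tanh(c)      # current hidden state
    return c, h

# Hypothetical sizes and random weights, for illustration only
batch, input_dim, hidden = 2, 3, 4
rng = np.random.default_rng(0)
params = {
    name: (rng.normal(scale=0.1, size=(input_dim, hidden)),
           rng.normal(scale=0.1, size=(hidden, hidden)),
           np.zeros(hidden))
    for name in ("i", "f", "o", "g")
}

x = rng.normal(size=(batch, input_dim))
h0 = np.zeros((batch, hidden))
c0 = np.zeros((batch, hidden))
c1, h1 = lstm_cell_np(x, h0, c0, params)
print(c1.shape, h1.shape)  # (2, 4) (2, 4)

# Unroll over a short random sequence, carrying (c, h) from step to step
seq = rng.normal(size=(5, batch, input_dim))
c, h = c0, h0
for x_t in seq:
    c, h = lstm_cell_np(x_t, h, c, params)
```

Because the hidden state is a sigmoid output times a tanh output, every entry of `h` stays strictly between -1 and 1, while the cell state `c` is not bounded in the same way and can accumulate across time steps.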