Step 6 - LSTM network construction

Now, let's set up an LSTM model with the preceding parameters and structure:

val model = LSTMNetworkConstructor.setupModel(nSteps, nInput, nHidden, nClasses, batchSize, ctx = ctx) 
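These parameters come from the earlier steps. As a minimal sketch of plausible values (128 time steps, 28 hidden units, and 6 classes are the values used in this section; nInput and batchSize here are illustrative assumptions):

// Sketch of the hyperparameters; nInput and batchSize are illustrative guesses
val nSteps = 128        // time steps per window; one LSTM cell per step
val nInput = 9          // features per time step (assumed value)
val nHidden = 28        // hidden units per LSTM cell
val nClasses = 6        // activity labels to predict
val batchSize = 64      // mini-batch size (assumed value)
val ctx = Context.cpu() // MXNet context; use Context.gpu(0) if a GPU is available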

In the preceding call, setupModel() is the method that does the trick, while getSymbol(), which we will see later on, actually constructs the LSTM cells. setupModel() accepts the sequence length, number of inputs, number of hidden units, number of labels, batch size, number of LSTM layers, dropout ratio, and the MXNet context, and constructs an LSTM model of the type defined by the case class LSTMModel:

case class LSTMModel(exec: Executor, symbol: Symbol, data: NDArray, label: NDArray,
                     argsDict: Map[String, NDArray], gradDict: Map[String, NDArray])

Now here's the implementation of setupModel():

def setupModel(seqLen: Int, nInput: Int, numHidden: Int, numLabel: Int, batchSize: Int,
               numLstmLayer: Int = 1, dropout: Float = 0f, ctx: Context = Context.cpu()): LSTMModel = {
    // get the symbolic model
    val sym = LSTMNetworkConstructor.getSymbol(seqLen, numHidden, numLabel, numLstmLayer = numLstmLayer)
    val argNames = sym.listArguments()
    val auxNames = sym.listAuxiliaryStates()
    // define the initial states and bind them to the model
    val initC = for (l <- 0 until numLstmLayer) yield (s"l${l}_init_c", (batchSize, numHidden))
    val initH = for (l <- 0 until numLstmLayer) yield (s"l${l}_init_h", (batchSize, numHidden))
    val initStates = (initC ++ initH).map(x => x._1 -> Shape(x._2._1, x._2._2)).toMap
    val dataShapes = Map("data" -> Shape(batchSize, seqLen, nInput)) ++ initStates
    val (argShapes, outShapes, auxShapes) = sym.inferShape(dataShapes)
    val initializer = new Uniform(0.1f)
    val argsDict = argNames.zip(argShapes).map { case (name, shape) =>
        val nda = NDArray.zeros(shape, ctx)
        if (!dataShapes.contains(name) && name != "softmax_label") {
            initializer(name, nda)
        }
        name -> nda
    }.toMap
    val argsGradDict = argNames.zip(argShapes)
        .filter(x => x._1 != "softmax_label" && x._1 != "data")
        .map(x => x._1 -> NDArray.zeros(x._2, ctx)).toMap
    val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
    val exec = sym.bind(ctx, argsDict, argsGradDict, "write", auxDict, null, null)
    val data = argsDict("data")
    val label = argsDict("softmax_label")
    LSTMModel(exec, sym, data, label, argsDict, argsGradDict)
}
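The returned LSTMModel is fully bound and ready for training. As a minimal, hypothetical smoke test (trainX and trainY are placeholder Array[Float] batches; the actual training loop comes in a later step), the standard MXNet Scala Executor calls would look like this:

// Hypothetical usage sketch, not the chapter's training loop
model.data.set(trainX)                // copy one batch of windows into the bound input
model.label.set(trainY)               // copy the matching labels
model.exec.forward(isTrain = true)    // forward pass through the unrolled network
model.exec.backward()                 // backward pass fills model.gradDict with gradients
val probs = model.exec.outputs(0)     // softmax output, shape (batchSize, nClasses)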

In the preceding method, we obtained a symbolic model for the deep RNN using the getSymbol() method, which is shown as follows. I have provided detailed comments, which should be enough to understand the workflow of the code:

  private def getSymbol(seqLen: Int, numHidden: Int, numLabel: Int, numLstmLayer: Int = 1,
                        dropout: Float = 0f): Symbol = {
    // symbolic training data and label variables
    var inputX = Symbol.Variable("data")
    val inputY = Symbol.Variable("softmax_label")
    // the initial parameters and cells
    var paramCells = Array[LSTMParam]()
    var lastStates = Array[LSTMState]()
    // numLstmLayer is 1
    for (i <- 0 until numLstmLayer) {
      paramCells = paramCells :+ LSTMParam(i2hWeight = Symbol.Variable(s"l${i}_i2h_weight"),
        i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
        h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
        h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
      lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c"),
        h = Symbol.Variable(s"l${i}_init_h"))
    }
    assert(lastStates.length == numLstmLayer)
    // split the (batchSize, seqLen, nInput) input into seqLen slices, one per time step
    val lstmInputs = Symbol.SliceChannel()(inputX)(Map("axis" -> 1, "num_outputs" -> seqLen,
      "squeeze_axis" -> 1))
    var hiddenAll = Array[Symbol]()
    var dpRatio = 0f
    var hidden: Symbol = null
    // for each one of the 128 inputs, create an LSTM cell
    for (seqIdx <- 0 until seqLen) {
      hidden = lstmInputs.get(seqIdx)
      // stack LSTMs; numLstmLayer is 1, so this loop executes only once
      for (i <- 0 until numLstmLayer) {
        if (i == 0) dpRatio = 0f else dpRatio = dropout
        val nextState = lstmCell(numHidden, inData = hidden, prevState = lastStates(i),
          param = paramCells(i), seqIdx = seqIdx, layerIdx = i, dropout = dpRatio)
        hidden = nextState.h       // the cell's output feeds the next layer/output
        lastStates(i) = nextState  // carry the state to the next time step
      }
      // adding dropout before softmax has no effect here, since dropout is 0f by default
      if (dropout > 0f)
        hidden = Symbol.Dropout()()(Map("data" -> hidden, "p" -> dropout))
      // store the LSTM cell's output layer
      hiddenAll = hiddenAll :+ hidden
    }

In summary, the algorithm unrolls 128 LSTM cells, one per time step, then combines the outputs of all 128 cells and feeds the result to the output layer. Note that reduce(_ + _) sums the hidden outputs element-wise rather than concatenating them. Let's combine the cells' outputs:

val finalOut = hiddenAll.reduce(_+_) 
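For intuition, reduce(_ + _) is an ordinary Scala fold with +. On plain numbers it behaves as follows; applied to Symbols, the same fold builds a graph node that sums all 128 hidden outputs element-wise:

val xs = List(1, 2, 3, 4)
val total = xs.reduce(_ + _) // 10: ((1 + 2) + 3) + 4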

Then we connect it to a fully connected output layer that corresponds to the six labels:

 val fc = Symbol.FullyConnected()()(Map("data" -> finalOut, "num_hidden" -> numLabel)) 
 //softmax activation against the label 
 Symbol.SoftmaxOutput()()(Map("data" -> fc, "label" -> inputY)) 

In the preceding code segment, LSTMState and LSTMParam are two case classes: the former defines the state of each LSTM cell, and the latter accepts the parameters needed to construct an LSTM cell:

final case class LSTMState(c: Symbol, h: Symbol)
final case class LSTMParam(i2hWeight: Symbol, i2hBias: Symbol, h2hWeight: Symbol, h2hBias: Symbol)

Now it's time to discuss the most important step, which is LSTM cell construction. We will describe it using some diagrams, with the legend shown in the following figure:

Figure 11: Legend used to describe the LSTM cell in the following figures

The repeating module in an LSTM contains four interacting layers as shown in the following figure:

Figure 12: Inside an LSTM cell; that is, the repeating module in an LSTM contains four interacting layers

An LSTM cell is defined by its states and parameters, as defined by the preceding two case classes:

  • LSTM state: c is the cell state (its memory knowledge), used during training, and h is the output of the cell
  • LSTM parameters, optimized by the training algorithm:
      • i2hWeight: Input-to-hidden weight
      • i2hBias: Input-to-hidden bias
      • h2hWeight: Hidden-to-hidden weight
      • h2hBias: Hidden-to-hidden bias
  • i2h: A fully connected layer (NN) for the input data
  • h2h: A fully connected layer (NN) from the previous output h
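For reference, these are the standard LSTM equations that the following code implements, written with the same names the code uses (⊙ denotes element-wise multiplication, x_t is the input at time step t, and h_(t-1) and c_(t-1) are the previous output and cell state):

    ingate      = sigmoid(W_i · x_t + U_i · h_(t-1) + b_i)
    inTransform = tanh(W_g · x_t + U_g · h_(t-1) + b_g)
    forgetGate  = sigmoid(W_f · x_t + U_f · h_(t-1) + b_f)
    outGate     = sigmoid(W_o · x_t + U_o · h_(t-1) + b_o)
    nextC       = forgetGate ⊙ c_(t-1) + ingate ⊙ inTransform
    nextH       = outGate ⊙ tanh(nextC)

Here, the four W matrices are stored together in i2hWeight, the four U matrices in h2hWeight, and the biases in i2hBias and h2hBias, which is why the code computes all four gate pre-activations in one shot and slices them afterwards.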

In the code, these two fully connected layers are created, summed, and then sliced into four parts, one per gate. First, let's add a hidden layer of size numHidden * 4 (numHidden is set to 28) that takes the input data:

val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
    "weight" -> param.i2hWeight, "bias" -> param.i2hBias, "num_hidden" -> numHidden * 4))

Then we add a hidden layer of size numHidden * 4 (numHidden set to 28) that takes as input the previous output of the cell:

val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
    "weight" -> param.h2hWeight, "bias" -> param.h2hBias, "num_hidden" -> numHidden * 4))

Now let's add them together element-wise (this is a sum, not a concatenation, so the result still has numHidden * 4 units):

val gates = i2h + h2h 

Then let's slice gates into four parts, one per gate: with numHidden = 28, gates has numHidden * 4 = 112 units, and SliceChannel splits it into four slices of 28 units each:

val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(gates)(Map("num_outputs" -> 4))


Now the activation for the forget gate is represented by the following code:

val forgetGate = Symbol.Activation()()(Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid")) 

We can see this in the following figure:

Figure 13: Forget gate in an LSTM cell

Now, the activations for the in gate and the in transform are represented by the following code:

val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))   
val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh")) 

We can also see this in Figure 14:

Figure 14: In gate and transform gate in an LSTM cell

The next cell state is defined by the following code; the forget gate decides how much of the previous memory prevState.c to keep, while the in gate decides how much of the new candidate inTransform to write:

val nextC = (forgetGate * prevState.c) + (ingate * inTransform) 

The preceding code can be represented by the following figure too:

Figure 15: The next (transited) cell state in an LSTM cell

Finally, the output gate and the cell's output can be represented by the following code; the output gate is computed from the fourth slice and then scales the tanh of the new cell state:

val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))

The preceding code can be represented by the following figure too:

Figure 16: Output gate in an LSTM cell

Too much of a mouthful? No worries, here I have provided the full code for this method:

  // LSTM Cell symbol
  private def lstmCell(numHidden: Int, inData: Symbol, prevState: LSTMState, param: LSTMParam,
                       seqIdx: Int, layerIdx: Int, dropout: Float = 0f): LSTMState = {
    val inDataa = {
      if (dropout > 0f) Symbol.Dropout()()(Map("data" -> inData, "p" -> dropout))
      else inData
    }
    // add a hidden layer of size numHidden * 4 (numHidden set to 28) that takes the input data
    val i2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_i2h")()(Map("data" -> inDataa,
      "weight" -> param.i2hWeight, "bias" -> param.i2hBias, "num_hidden" -> numHidden * 4))
    // add a hidden layer of size numHidden * 4 that takes the previous output of the cell
    val h2h = Symbol.FullyConnected(s"t${seqIdx}_l${layerIdx}_h2h")()(Map("data" -> prevState.h,
      "weight" -> param.h2hWeight, "bias" -> param.h2hBias, "num_hidden" -> numHidden * 4))
    // sum them element-wise
    val gates = i2h + h2h
    // slice the result into 4 parts, one per gate
    val sliceGates = Symbol.SliceChannel(s"t${seqIdx}_l${layerIdx}_slice")(gates)(Map("num_outputs" -> 4))
    // compute the gates
    val ingate = Symbol.Activation()()(Map("data" -> sliceGates.get(0), "act_type" -> "sigmoid"))
    val inTransform = Symbol.Activation()()(Map("data" -> sliceGates.get(1), "act_type" -> "tanh"))
    val forgetGate = Symbol.Activation()()(Map("data" -> sliceGates.get(2), "act_type" -> "sigmoid"))
    val outGate = Symbol.Activation()()(Map("data" -> sliceGates.get(3), "act_type" -> "sigmoid"))
    // get the new cell state and the output
    val nextC = (forgetGate * prevState.c) + (ingate * inTransform)
    val nextH = outGate * Symbol.Activation()()(Map("data" -> nextC, "act_type" -> "tanh"))
    LSTMState(c = nextC, h = nextH)
  }
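One design choice worth noting: instead of four separate fully connected layers per gate, the cell computes all four gate pre-activations in a single FullyConnected of width numHidden * 4 per direction and then slices the result, replacing four small matrix multiplications with one larger, more efficient one. A quick sanity check of the dimensions (a sketch, using the numHidden = 28 from this section):

// Dimension sanity check for the gate slicing (illustrative only)
val numHidden = 28
val gateWidth = numHidden * 4 // 112 pre-activation units produced by i2h and h2h
val perGate = gateWidth / 4   // SliceChannel yields 4 slices of 28 units each
assert(perGate == numHidden)  // each gate sees exactly numHidden units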