We use the embedding function, g, to learn the embeddings of the support set. We use a bidirectional LSTM as g.
We can define g as follows:
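import tensorflow as tf
from tensorflow.contrib import rnn  # TensorFlow 1.x

def g(X):
    # X: list of tensors, one per support-set item
    # forward cell
    forward_cell = rnn.BasicLSTMCell(32)
    # backward cell
    backward_cell = rnn.BasicLSTMCell(32)
    # bidirectional LSTM; each output concatenates the forward and backward outputs (width 64)
    outputs, state_forward, state_backward = rnn.static_bidirectional_rnn(forward_cell, backward_cell, X, dtype=tf.float32)
    # skip connection: add the inputs to the LSTM outputs
    return tf.add(tf.stack(X), tf.stack(outputs))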
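
To make the data flow in g easier to follow, here is a minimal, dependency-free sketch of the same pattern: run a recurrent cell over the support set in both directions, concatenate the per-step outputs, and add the input back as a skip connection. The names toy_cell, run, and g_sketch are hypothetical, and a simple tanh recurrence stands in for the LSTM cell, so this only illustrates the shapes and the wiring, not the learned embedding itself.

```python
import math

def toy_cell(x, h):
    # Hypothetical stand-in for an LSTM cell (not a real LSTM):
    # mixes the mean of the input with the previous hidden state.
    m = sum(x) / len(x)
    return [math.tanh(m + hj) for hj in h]

def run(cell, seq, units):
    # Unroll the cell over the sequence, starting from a zero state.
    h = [0.0] * units
    outputs = []
    for x in seq:
        h = cell(x, h)
        outputs.append(h)
    return outputs

def g_sketch(seq, units):
    # Forward pass over the sequence.
    fwd = run(toy_cell, seq, units)
    # Backward pass: process in reverse, then restore the original order.
    bwd = run(toy_cell, seq[::-1], units)[::-1]
    # Concatenate forward and backward outputs per step (width 2 * units).
    outputs = [f + b for f, b in zip(fwd, bwd)]
    # The skip connection needs the input width to match the output width.
    if any(len(x) != 2 * units for x in seq):
        raise ValueError("input width must equal 2 * units")
    return [[xi + oi for xi, oi in zip(x, o)] for x, o in zip(seq, outputs)]

# Three support-set items with 4 features each, and 2 units per direction.
support = [[0.1] * 4, [0.2] * 4, [0.3] * 4]
embedded = g_sketch(support, units=2)
```

The same constraint applies to the TensorFlow version: with 32 units per direction the concatenated output has width 64, so the tensors in X are expected to have a matching feature width for the addition to line up.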