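We use the embedding function g to learn the embeddings of the support set. We use a bidirectional LSTM as our embedding function, so each support-set embedding is computed with the rest of the support set as context, and the LSTM outputs are then added back onto the original embeddings. We can define the embedding function as follows: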
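import tensorflow as tf
from tensorflow.contrib import rnn  # tf.contrib exists only in TensorFlow 1.x
def g(self, x_i):
    # x_i: list of [batch_size, 64] tensors, one per support-set element (64 = 2 * 32, so the addition below matches shapes)
    forward_cell = rnn.BasicLSTMCell(32)
    backward_cell = rnn.BasicLSTMCell(32)
    # returns (outputs, forward_state, backward_state); each output concatenates both directions: [batch_size, 64]
    outputs, state_forward, state_backward = rnn.static_bidirectional_rnn(forward_cell, backward_cell, x_i, dtype=tf.float32)
    # skip connection: add the bidirectional LSTM outputs back onto the input embeddings
    return tf.add(tf.stack(x_i), tf.stack(outputs))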
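To see what g computes without TensorFlow 1.x, here is a minimal NumPy sketch of the same idea, assuming a support set of 64-dimensional embeddings: a forward LSTM and a backward LSTM each run over the sequence of support-set embeddings, their 32-unit outputs are concatenated into 64 dimensions, and the result is added to the input. The helper names (`lstm_pass`, `init_params`, `g_embed`) are hypothetical, and the weights are random and untrained. The gate math is the standard LSTM with a forget bias of 1.0, similar in spirit to `BasicLSTMCell` but not a reproduction of its internal weight layout.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_pass(xs, W, U, b, hidden, forget_bias=1.0):
    """Run a single-direction LSTM over a list of [batch, input_dim] arrays."""
    batch = xs[0].shape[0]
    h = np.zeros((batch, hidden))
    c = np.zeros((batch, hidden))
    outputs = []
    for x in xs:
        gates = x @ W + h @ U + b                 # [batch, 4 * hidden]
        i, j, f, o = np.split(gates, 4, axis=1)   # input, candidate, forget, output
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return outputs

def init_params(rng, input_dim, hidden):
    """Hypothetical random (untrained) weights for one LSTM direction."""
    W = rng.normal(scale=0.1, size=(input_dim, 4 * hidden))
    U = rng.normal(scale=0.1, size=(hidden, 4 * hidden))
    b = np.zeros(4 * hidden)
    return W, U, b

def g_embed(x_i, fw_params, bw_params, hidden=32):
    """Bidirectional LSTM over the support set plus a skip connection."""
    fw = lstm_pass(x_i, *fw_params, hidden)
    # run the backward LSTM on the reversed sequence, then restore the order
    bw = lstm_pass(x_i[::-1], *bw_params, hidden)[::-1]
    outputs = [np.concatenate([f, b], axis=1) for f, b in zip(fw, bw)]
    # [seq_len, batch, 64] + [seq_len, batch, 2 * hidden]
    return np.stack(x_i) + np.stack(outputs)

rng = np.random.default_rng(0)
seq_len, batch, dim = 5, 2, 64
x_i = [rng.normal(size=(batch, dim)) for _ in range(seq_len)]
fw_params = init_params(rng, dim, 32)
bw_params = init_params(rng, dim, 32)
emb = g_embed(x_i, fw_params, bw_params)
print(emb.shape)  # (5, 2, 64)
```

Because the backward direction reads the sequence from the end, the embedding of the first support-set element already depends on the last one; this is what lets each embedding take the whole support set into account, while the skip connection keeps the original embedding in the output.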