Query set embedding function

We use the embedding function f for learning the embedding of our query point, x̂. We use an LSTM as our encoding function. Along with x̂ as the input, we also pass the embeddings of our support set, which is g(x), and one more parameter, K, which defines the number of processing steps. Let's see how to compute the query set embeddings step by step. First, we initialize our LSTM cell:

cell = rnn.BasicLSTMCell(64)
prev_state = cell.zero_state(self.batch_size, tf.float32)

Then, for each of the K processing steps, we do the following:

for step in xrange(self.processing_steps):

We calculate the embeddings of the query set, XHat, by feeding it to the LSTM cell:

    output, state = cell(XHat, prev_state)

    h_k = tf.add(output, XHat)

Now, we perform softmax attention over the support set embeddings, g_embedding. This helps us to downweight support set elements that are not relevant to the query:

    content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))
    r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)
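
To see what the shapes look like in this step, here is a small, self-contained sketch. The support set size of 3 and the batch size of 2 are made-up values for illustration; the embedding size of 64 matches the LSTM cell used above:

import tensorflow as tf

# Hypothetical shapes: 3 support examples, batch of 2, 64-dimensional embeddings.
g_embedding = tf.random_normal([3, 2, 64])   # support set embeddings, g(x)
c_state = tf.random_normal([2, 64])          # stands in for prev_state[1]

# The element-wise product broadcasts [2, 64] against [3, 2, 64],
# producing a [3, 2, 64] tensor of attention scores.
content_based_attention = tf.nn.softmax(tf.multiply(c_state, g_embedding))

# Summing over axis 0 collapses the support dimension, so r_k has
# shape [2, 64], the same shape as the query embedding h_k.
r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)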

We update prev_state and repeat these steps for the number of processing steps, K:

    prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

The complete code for computing f_embeddings is as follows:

    def f(self, XHat, g_embedding):
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32)

        for step in xrange(self.processing_steps):
            output, state = cell(XHat, prev_state)

            h_k = tf.add(output, XHat)

            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))

            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)

            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        return output
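
As a rough usage sketch, the following shows how f might be wired up end to end. It assumes Python 2 and TensorFlow 1.x (matching the xrange and rnn.BasicLSTMCell usage above); the class name MatchingNetwork, the batch size of 32, the support set size of 25, and the number of processing steps are illustrative assumptions, not values from the original code:

import tensorflow as tf
from tensorflow.contrib import rnn

class MatchingNetwork(object):
    # Hypothetical wrapper class; f only needs batch_size and processing_steps.
    def __init__(self, batch_size=32, processing_steps=10):
        self.batch_size = batch_size
        self.processing_steps = processing_steps

    def f(self, XHat, g_embedding):
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32)

        for step in xrange(self.processing_steps):
            output, state = cell(XHat, prev_state)
            h_k = tf.add(output, XHat)
            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))
            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)
            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        return output

# Assumed shapes: XHat is [batch_size, 64] (the query embeddings) and
# g_embedding is [num_support, batch_size, 64] (the support set embeddings),
# so the attention broadcasts over the support dimension.
model = MatchingNetwork(batch_size=32, processing_steps=10)
XHat = tf.placeholder(tf.float32, [32, 64])
g_embedding = tf.placeholder(tf.float32, [25, 32, 64])
f_embedding = model.f(XHat, g_embedding)   # shape [32, 64]

Note that f returns the LSTM output of the final processing step; the attention read-out r_k feeds back into the computation through the updated prev_state.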