Now, let's see how to build a matching network in TensorFlow step by step; the complete code is shown at the end. Note that the code relies on `tf.contrib`, so it requires TensorFlow 1.x.
First, we import the libraries:
import tensorflow as tf
slim = tf.contrib.slim
rnn = tf.contrib.rnn
Now, we define a class called Matching_network, where we define our network:
# Matching network: holds the input placeholders and the graph-building methods defined below.
class Matching_network():
We define the __init__ method, where we create the placeholders for the support set and the query set:
def __init__(self, lr, n_way, k_shot, batch_size=32, processing_steps=10):
    """Store the hyperparameters and create the support/query placeholders.

    Args:
        lr: learning rate used by the Adam optimizer in train().
        n_way: number of classes in each episode.
        k_shot: number of support images per class.
        batch_size: number of episodes per batch.
        processing_steps: number of attention-LSTM steps performed by f().
    """
    # The other methods read these as self.lr, self.n_way, self.processing_steps, ...
    self.lr = lr
    self.n_way = n_way
    self.k_shot = k_shot
    self.batch_size = batch_size
    self.processing_steps = processing_steps
    # placeholder for support set: (batch, n_way * k_shot, 28, 28, 1) images and their labels
    self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
    self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot])
    # placeholder for query set: (batch, 28, 28, 1) images and their labels
    self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    self.query_label = tf.placeholder(tf.int32, [None])
Our support set and query set consist of images. Before feeding these raw images to the embedding functions, we first extract features from them using a convolutional network. We then feed the extracted features of the support set and the query set to the embedding functions g and f, respectively.
So, we define a function called image_encoder, which extracts features from an image. Our image encoder is a four-layer convolutional network: each layer is a 3 x 3 convolution with 64 filters and batch normalization, followed by 2 x 2 max pooling:
def image_encoder(self, image):
    """Encode a batch of 28x28x1 images into 64-dimensional feature vectors.

    Four blocks of 3x3 convolution (64 filters, batch norm) + 2x2 max pooling
    shrink the spatial size 28 -> 14 -> 7 -> 3 -> 1, so the flattened output
    is (batch, 64).
    """
    # AUTO_REUSE plus fixed layer names make every call share one set of
    # weights, so support and query images are encoded by the same network.
    with tf.variable_scope('image_encoder', reuse=tf.AUTO_REUSE):
        with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
            # conv1
            net = slim.conv2d(image, scope='conv1')
            net = slim.max_pool2d(net, [2, 2])
            # conv2
            net = slim.conv2d(net, scope='conv2')
            net = slim.max_pool2d(net, [2, 2])
            # conv3
            net = slim.conv2d(net, scope='conv3')
            net = slim.max_pool2d(net, [2, 2])
            # conv4
            net = slim.conv2d(net, scope='conv4')
            net = slim.max_pool2d(net, [2, 2])
    return tf.reshape(net, [-1, 1 * 1 * 64])
Now we define our embedding functions, f and g. We have already seen how they work in the Embedding function section, so we can define them directly as follows:
# embedding function for extracting support set embeddings
def g(self, x_i):
    """Embed the support set with a bidirectional LSTM plus a skip connection.

    x_i is the list of per-image encoder features. The forward and backward
    LSTMs have 32 units each; their concatenated outputs are added to the
    stacked input features.
    """
    fw_cell, bw_cell = rnn.BasicLSTMCell(32), rnn.BasicLSTMCell(32)
    # only the per-step outputs are needed; the final states are discarded
    bi_outputs = rnn.static_bidirectional_rnn(fw_cell, bw_cell, x_i, dtype=tf.float32)[0]
    stacked_inputs = tf.stack(x_i)
    return tf.add(stacked_inputs, tf.stack(bi_outputs))
# embedding function for extracting query set embeddings
def f(self, XHat, g_embedding):
    """Embed the query with an attention LSTM that reads the support embeddings.

    At each of self.processing_steps steps:
        h_hat, c = LSTM(XHat, (c_prev, h_prev + r_prev))
        h_k      = h_hat + XHat
        a_i      = softmax over support items i of (h_k . g_i)
        r_k      = sum_i a_i * g_i
    The read-out r_k is fed back through the LSTM's hidden-state slot.

    Args:
        XHat: query features from image_encoder, (batch, 64).
        g_embedding: support embeddings from g, (n_support, batch, 64).

    Returns:
        The query embedding h_k after the last step, (batch, 64).
    """
    cell = rnn.BasicLSTMCell(64)
    # size the initial state from the actual query batch, not a fixed batch_size
    prev_state = cell.zero_state(tf.shape(XHat)[0], tf.float32)
    h_k = XHat  # defined even when processing_steps == 0
    for _ in range(self.processing_steps):
        output, state = cell(XHat, prev_state)
        h_k = tf.add(output, XHat)
        # dot product of h_k with every support embedding -> (n_support, batch)
        scores = tf.reduce_sum(tf.multiply(tf.expand_dims(h_k, 0), g_embedding), axis=2)
        # attention is a distribution over support items (axis 0), not over features
        content_based_attention = tf.nn.softmax(scores, 0)
        r_k = tf.reduce_sum(tf.multiply(tf.expand_dims(content_based_attention, 2), g_embedding), axis=0)
        prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))
    return h_k
Now, we define a function called cosine_similarity for computing the cosine similarity between the support set and query set embeddings:
def cosine_similarity(self, target, support_set):
    """Cosine similarity between each query and each support embedding.

    Args:
        target: query embeddings, (batch, dim).
        support_set: support embeddings, (n_support, batch, dim).

    Returns:
        Tensor of shape (batch, n_support).
    """
    # both sides must be unit-length for the dot product to be a cosine
    target_normed = tf.nn.l2_normalize(target, 1)
    support_normed = tf.nn.l2_normalize(support_set, 2)
    # (n_support, batch, dim) * (1, batch, dim) summed over dim -> (n_support, batch)
    similarity = tf.reduce_sum(tf.multiply(support_normed, tf.expand_dims(target_normed, 0)), axis=2)
    # explicit transpose instead of a bare squeeze keeps size-1 batch/support axes
    return tf.transpose(similarity)
Finally, we define a function called train, which builds the forward pass, the loss, and the training operation—let's see this step by step:
# Builds the encoders, embeddings, attention read-out, loss, and optimizer step.
def train(self, support_set_image, support_set_label, query_image):
First, we encode the features of support set images using our image encoder:
# one encoder call per support image (unstacked along the support axis)
support_set_image_encoded = list(map(self.image_encoder, tf.unstack(support_set_image, axis=1)))
Then, we will also encode the features of query set images using the image encoder:
query_image_encoded = self.image_encoder(image=query_image)
Next, we learn the embeddings of our support set using our embedding function, g:
g_embedding = self.g(x_i=support_set_image_encoded)
Similarly, we will also learn the embeddings of our query set using our embedding function, f:
f_embedding = self.f(XHat=query_image_encoded, g_embedding=g_embedding)
Now, we calculate the cosine similarity between these two embeddings:
embeddings_similarity = self.cosine_similarity(target=f_embedding, support_set=g_embedding)
Then, we perform softmax attention over this similarity:
attention = tf.nn.softmax(logits=embeddings_similarity)
We predict the query set labels by multiplying our attention matrix with the one-hot encoded support set labels:
# (batch, 1, n_support) x (batch, n_support, n_way) -> (batch, 1, n_way)
attention_row = tf.expand_dims(attention, 1)
support_one_hot = tf.one_hot(support_set_label, self.n_way)
y_hat = tf.matmul(attention_row, support_one_hot)
Next, we remove the singleton dimension of y_hat to get the class probabilities:
# squeeze only axis 1 so a batch of a single episode keeps its batch dimension
probabilities = tf.squeeze(y_hat, axis=[1])
We select the index that has the highest probability as a class of the query image:
# class index with the highest probability for each query image
predictions = tf.argmax(probabilities, 1)
Finally, we define our loss function; we use softmax cross-entropy as our loss function:
# probabilities already sum to 1, so passing their log as logits makes the
# internal softmax a no-op and this becomes the negative log-likelihood of
# the true query labels (the epsilon guards against log(0))
loss_function = tf.losses.sparse_softmax_cross_entropy(self.query_label, tf.log(probabilities + 1e-8))
We minimize our loss function using AdamOptimizer:
tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)
Now, we will see the final code of our matching network as a whole:
class Matching_network():
    """Matching network built with the TensorFlow 1.x graph API.

    Each episode has n_way * k_shot labelled 28x28x1 support images and a
    batch of 28x28x1 query images to classify into one of the n_way classes.
    """

    # initialize all the hyperparameters and placeholders
    def __init__(self, lr, n_way, k_shot, batch_size=32, processing_steps=10):
        """Store the hyperparameters and create the support/query placeholders.

        Args:
            lr: learning rate used by the Adam optimizer in train().
            n_way: number of classes in each episode.
            k_shot: number of support images per class.
            batch_size: number of episodes per batch.
            processing_steps: number of attention-LSTM steps performed by f().
        """
        # read later as self.lr, self.n_way, self.processing_steps, ...
        self.lr = lr
        self.n_way = n_way
        self.k_shot = k_shot
        self.batch_size = batch_size
        self.processing_steps = processing_steps
        # placeholder for support set: (batch, n_way * k_shot, 28, 28, 1) images and labels
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot])
        # placeholder for query set: (batch, 28, 28, 1) images and labels
        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.query_label = tf.placeholder(tf.int32, [None])

    # encoder function for extracting features from the image
    def image_encoder(self, image):
        """Encode a batch of 28x28x1 images into 64-dimensional feature vectors.

        Four blocks of 3x3 convolution (64 filters, batch norm) + 2x2 max
        pooling shrink the spatial size 28 -> 14 -> 7 -> 3 -> 1, so the
        flattened output is (batch, 64).
        """
        # AUTO_REUSE plus fixed layer names make every call share one set of
        # weights, so support and query images are encoded by the same network.
        with tf.variable_scope('image_encoder', reuse=tf.AUTO_REUSE):
            with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
                # conv1
                net = slim.conv2d(image, scope='conv1')
                net = slim.max_pool2d(net, [2, 2])
                # conv2
                net = slim.conv2d(net, scope='conv2')
                net = slim.max_pool2d(net, [2, 2])
                # conv3
                net = slim.conv2d(net, scope='conv3')
                net = slim.max_pool2d(net, [2, 2])
                # conv4
                net = slim.conv2d(net, scope='conv4')
                net = slim.max_pool2d(net, [2, 2])
        return tf.reshape(net, [-1, 1 * 1 * 64])

    # embedding function for extracting support set embeddings
    def g(self, x_i):
        """Embed the support set with a bidirectional LSTM plus a skip connection.

        Args:
            x_i: list of n_support tensors, each (batch, 64), from image_encoder.

        Returns:
            Tensor (n_support, batch, 64): the stacked input features plus the
            concatenated forward/backward LSTM outputs (32 + 32 units).
        """
        forward_cell = rnn.BasicLSTMCell(32)
        backward_cell = rnn.BasicLSTMCell(32)
        outputs, _, _ = rnn.static_bidirectional_rnn(forward_cell, backward_cell, x_i, dtype=tf.float32)
        return tf.add(tf.stack(x_i), tf.stack(outputs))

    # embedding function for extracting query set embeddings
    def f(self, XHat, g_embedding):
        """Embed the query with an attention LSTM that reads the support embeddings.

        At each of self.processing_steps steps:
            h_hat, c = LSTM(XHat, (c_prev, h_prev + r_prev))
            h_k      = h_hat + XHat
            a_i      = softmax over support items i of (h_k . g_i)
            r_k      = sum_i a_i * g_i
        The read-out r_k is fed back through the LSTM's hidden-state slot.

        Args:
            XHat: query features from image_encoder, (batch, 64).
            g_embedding: support embeddings from g, (n_support, batch, 64).

        Returns:
            The query embedding h_k after the last step, (batch, 64).
        """
        cell = rnn.BasicLSTMCell(64)
        # size the initial state from the actual query batch, not a fixed batch_size
        prev_state = cell.zero_state(tf.shape(XHat)[0], tf.float32)
        h_k = XHat  # defined even when processing_steps == 0
        for _ in range(self.processing_steps):
            output, state = cell(XHat, prev_state)
            h_k = tf.add(output, XHat)
            # dot product of h_k with every support embedding -> (n_support, batch)
            scores = tf.reduce_sum(tf.multiply(tf.expand_dims(h_k, 0), g_embedding), axis=2)
            # attention is a distribution over support items (axis 0), not over features
            content_based_attention = tf.nn.softmax(scores, 0)
            r_k = tf.reduce_sum(tf.multiply(tf.expand_dims(content_based_attention, 2), g_embedding), axis=0)
            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))
        return h_k

    # cosine similarity between support set and query set embeddings
    def cosine_similarity(self, target, support_set):
        """Cosine similarity between each query and each support embedding.

        Args:
            target: query embeddings, (batch, dim).
            support_set: support embeddings, (n_support, batch, dim).

        Returns:
            Tensor of shape (batch, n_support).
        """
        # both sides must be unit-length for the dot product to be a cosine
        target_normed = tf.nn.l2_normalize(target, 1)
        support_normed = tf.nn.l2_normalize(support_set, 2)
        # (n_support, batch, dim) * (1, batch, dim) summed over dim -> (n_support, batch)
        similarity = tf.reduce_sum(tf.multiply(support_normed, tf.expand_dims(target_normed, 0)), axis=2)
        # explicit transpose instead of a bare squeeze keeps size-1 batch/support axes
        return tf.transpose(similarity)

    def train(self, support_set_image, support_set_label, query_image, query_label=None):
        """Build the forward pass, the loss, and the Adam training op.

        Args:
            support_set_image: (batch, n_way * k_shot, 28, 28, 1) images.
            support_set_label: (batch, n_way * k_shot) integer labels in [0, n_way).
            query_image: (batch, 28, 28, 1) images.
            query_label: (batch,) integer labels; defaults to self.query_label.

        Returns:
            The training op. The intermediate graph nodes are also stored as
            self.probabilities, self.predictions, self.loss_op and self.train_op.
        """
        if query_label is None:
            query_label = self.query_label
        # encode the features of query set images using our image encoder
        query_image_encoded = self.image_encoder(query_image)
        # encode the features of support set images with the same (shared) encoder
        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]
        # generate support set embeddings using our embedding function g
        g_embedding = self.g(support_set_image_encoded)
        # generate query set embeddings using our embedding function f
        f_embedding = self.f(query_image_encoded, g_embedding)
        # calculate the cosine similarity between both of these embeddings -> (batch, n_support)
        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding)
        # perform attention over the embedding similarity
        attention = tf.nn.softmax(embeddings_similarity)
        # (batch, 1, n_support) x (batch, n_support, n_way) -> (batch, 1, n_way)
        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
        # squeeze only axis 1 so a batch of one episode keeps its batch dimension
        self.probabilities = tf.squeeze(y_hat, axis=[1])
        # select the index with the highest probability as the class of each query image
        self.predictions = tf.argmax(self.probabilities, 1)
        # probabilities already sum to 1, so their log used as logits makes the
        # internal softmax a no-op: this is the negative log-likelihood of the
        # true labels (epsilon guards against log(0))
        self.loss_op = tf.losses.sparse_softmax_cross_entropy(query_label, tf.log(self.probabilities + 1e-8))
        # run batch-norm moving-average updates together with each optimizer step
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)
        return self.train_op