Now we will compute the primary capsules, which extract the basic features, and the digit capsules, which recognize the digits.
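Start the TensorFlow graph. The code uses the TensorFlow 1.x API (tf.placeholder, tf.contrib) and assumes that batch_size, squash(), and dynamic_routing() have already been defined:
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default() as g: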
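Define the placeholders for the input (flattened 784-pixel images) and the output (one-hot labels over the 10 digits), and reshape the input into 28 x 28 single-channel images: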
    x = tf.placeholder(tf.float32, [batch_size, 784])
    y = tf.placeholder(tf.float32, [batch_size, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
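To make the placeholder shapes concrete, here is a small NumPy sketch of the data they expect. The random pixels, the labels, and batch_size = 4 are hypothetical stand-ins for a real MNIST batch; the point is only the one-hot layout of y and the row-major reshape that produces x_image.

```python
import numpy as np

# Hypothetical stand-ins for one MNIST batch: random pixels and made-up labels.
batch_size = 4
rng = np.random.default_rng(0)
flat_images = rng.random((batch_size, 784), dtype=np.float32)  # what x receives
labels = np.array([3, 0, 7, 9])

# One-hot encode: row i has a single 1 in column labels[i], the layout y expects.
one_hot = np.zeros((batch_size, 10), dtype=np.float32)
one_hot[np.arange(batch_size), labels] = 1.0

# Same reshape as x_image: -1 lets NumPy infer the batch dimension,
# and each row of 784 pixels becomes 28 rows of 28 pixels (row-major).
images = flat_images.reshape(-1, 28, 28, 1)
print(images.shape)             # (4, 28, 28, 1)
print(one_hot.argmax(axis=1))   # recovers the labels 3, 0, 7, 9
```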
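Apply a convolution with 256 filters of size 9 x 9 (stride 1, valid padding, and the layer's default ReLU activation); its output is the input to the primary capsules: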
    with tf.name_scope('convolutional_input'):
        input_data = tf.contrib.layers.conv2d(inputs=x_image, num_outputs=256, kernel_size=9, padding='valid')
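Because the convolutions here use valid (no) padding, each one shrinks the feature map. The short calculation below, using a hypothetical helper valid_conv_output_size, shows that this first convolution turns the 28 x 28 image into a 20 x 20 map, and that the stride-2, 9 x 9 convolutions used for the primary capsules then produce 6 x 6 x 32 = 1,152 values per convolution.

```python
def valid_conv_output_size(input_size, kernel_size, stride=1):
    """Spatial output size of a convolution with 'valid' (no) padding."""
    return (input_size - kernel_size) // stride + 1

# First convolution: 28x28 input, 9x9 kernel, stride 1.
conv1 = valid_conv_output_size(28, 9)
# Each primary-capsule convolution: 9x9 kernel, stride 2, on the 20x20 map.
primary = valid_conv_output_size(conv1, 9, stride=2)
# 32 output channels per convolution -> number of primary capsules.
num_primary_capsules = primary * primary * 32
print(conv1, primary, num_primary_capsules)  # 20 6 1152
```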
Compute the primary capsules, which extract basic features such as edges. Each of the eight convolutions below (32 filters of size 9 x 9, stride 2) supplies one component of every 8D primary-capsule vector:
    capsules = []
    for i in range(8):
        with tf.name_scope('capsules_' + str(i)):
            # 9x9 convolution with stride 2: [batch_size, 20, 20, 256] -> [batch_size, 6, 6, 32]
            output = tf.contrib.layers.conv2d(inputs=input_data, num_outputs=32, kernel_size=9, stride=2, padding='valid')
            # flatten to [batch_size, 1152, 1, 1]: one scalar per capsule
            output = tf.reshape(output, [batch_size, -1, 1, 1])
            # store this output; it becomes component i of every primary capsule
            capsules.append(output)
Concatenate the eight outputs along axis 2 to form the primary capsules, a tensor of shape [batch_size, 1152, 8, 1], that is, 1,152 capsules, each an 8D vector:
    primary_capsule = tf.concat(capsules, axis=2)
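The reshape-and-concatenate step can be checked with NumPy stand-ins. In this sketch, random arrays of shape [batch, 6, 6, 32] play the role of the eight convolution outputs (batch = 2 is an arbitrary choice); the final assertion confirms that component k of capsule n comes from convolution k at the same flattened position n.

```python
import numpy as np

batch_size = 2  # hypothetical small batch
rng = np.random.default_rng(0)

# Stand-ins for the eight convolution outputs, each of shape [batch, 6, 6, 32].
conv_outputs = [rng.standard_normal((batch_size, 6, 6, 32)) for _ in range(8)]

# As in the loop: reshape each output to [batch, 1152, 1, 1] ...
capsules = [out.reshape(batch_size, -1, 1, 1) for out in conv_outputs]
# ... then concatenate on axis 2, giving [batch, 1152, 8, 1].
primary_capsule = np.concatenate(capsules, axis=2)
print(primary_capsule.shape)  # (2, 1152, 8, 1)

# Component k of capsule n is convolution k's value at flattened position n.
n, k = 100, 5
assert primary_capsule[0, n, k, 0] == conv_outputs[k].reshape(batch_size, -1)[0, n]
```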
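Apply the squash function to the primary capsules. It keeps each capsule vector's direction but scales its length into the range 0 to 1, so the length can be read as the probability that the capsule's feature is present: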
    primary_capsule = squash(primary_capsule)
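The squash function itself is defined outside this section, so the following is a NumPy sketch of the standard capsule-network nonlinearity, v = (|s|^2 / (1 + |s|^2)) * s / |s|, not necessarily the exact implementation used here. It assumes the vector components lie along axis -2, matching the [batch_size, 1152, 8, 1] layout above, and adds a small epsilon (an assumption) to avoid dividing by zero.

```python
import numpy as np

def squash(s, axis=-2, epsilon=1e-9):
    """Squash along `axis`: keep the direction, map the length into [0, 1)."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm)
    # epsilon guards against division by zero for all-zero vectors
    return scale * s / np.sqrt(squared_norm + epsilon)

# Toy primary capsules with shape [batch=1, capsules=3, dim=8, 1].
s = np.zeros((1, 3, 8, 1))
s[0, 0, 0, 0] = 0.1   # short vector -> length near 0
s[0, 1, 0, 0] = 1.0   # unit vector  -> length 0.5
s[0, 2, :, 0] = 10.0  # long vector  -> length near 1
v = squash(s)
lengths = np.linalg.norm(v, axis=-2).ravel()
print(lengths.round(3))  # roughly 0.01, 0.5 and 0.999
```

Short vectors are pushed toward zero and long ones toward one, which is what lets a capsule's length act as a probability.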
Compute the 10 digit capsules, one per digit class, using the dynamic routing algorithm:
    with tf.name_scope('dynamic_routing'):
        # reshape the primary capsules to [batch_size, 1152, 1, 8, 1]
        outputs = tf.reshape(primary_capsule, shape=(batch_size, -1, 1, primary_capsule.shape[-2].value, 1))
        # initialize the routing logits bij with zeros: one per (primary capsule, digit capsule) pair
        bij = tf.constant(np.zeros([1, primary_capsule.shape[1].value, 10, 1, 1], dtype=np.float32))
        # compute the digit capsules with the dynamic routing algorithm, which takes
        # the reshaped primary capsules and bij as inputs and returns the activity vectors
        digit_capsules = dynamic_routing(outputs, bij)
        # remove the size-1 dimension at axis 1
        digit_capsules = tf.squeeze(digit_capsules, axis=1)
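The dynamic_routing function called above is not shown in this section, so the following is an independent NumPy sketch of routing by agreement as described in the original capsule network paper (Sabour, Frosst, and Hinton, 2017), not a reproduction of that function. It handles one example at a time; the random transformation matrices W, the 16D digit-capsule size, and the three routing iterations are assumptions for illustration. The logits b start at zero, just like bij above.

```python
import numpy as np

def squash(s, axis=-1, epsilon=1e-9):
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return squared_norm / (1.0 + squared_norm) * s / np.sqrt(squared_norm + epsilon)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def route(u_hat, num_iterations=3):
    """Routing by agreement for one example.

    u_hat[i, j] is primary capsule i's prediction for digit capsule j,
    shape [num_primary, num_digit, digit_dim]. Returns [num_digit, digit_dim].
    """
    num_primary, num_digit, _ = u_hat.shape
    b = np.zeros((num_primary, num_digit))       # routing logits b_ij start at 0
    for _ in range(num_iterations):
        c = softmax(b, axis=1)                    # coupling c_ij sums to 1 over digits j
        s = np.einsum('ij,ijd->jd', c, u_hat)     # s_j = sum_i c_ij * u_hat_j|i
        v = squash(s)                             # v_j = squash(s_j)
        b = b + np.einsum('ijd,jd->ij', u_hat, v) # agreement: b_ij += u_hat_j|i . v_j
    return v

# Hypothetical inputs: 1152 squashed 8D primary capsules and random weights W.
rng = np.random.default_rng(0)
num_primary, num_digit, in_dim, out_dim = 1152, 10, 8, 16
u = squash(rng.standard_normal((num_primary, in_dim)))
W = rng.standard_normal((num_primary, num_digit, out_dim, in_dim)) * 0.05
u_hat = np.einsum('ijdk,ik->ijd', W, u)           # u_hat_j|i = W_ij u_i
v = route(u_hat)
lengths = np.linalg.norm(v, axis=-1)
print(v.shape, lengths.argmax())
```

Each iteration raises the coupling between a primary capsule and the digit capsules its prediction agrees with, so primary capsules that predict consistently dominate the final output. The longest of the 10 resulting vectors indicates the most likely digit.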