Computing primary and digit capsules

Now we will compute the primary capsules, which extract the basic features, and the digit capsules, which recognize the digits.

Start the TensorFlow Graph:

graph = tf.Graph()
with graph.as_default() as g:

Define the placeholders for input and output:

    x = tf.placeholder(tf.float32, [batch_size, 784])
    y = tf.placeholder(tf.float32, [batch_size, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])

Perform the convolution operation and get the convolutional input:

    with tf.name_scope('convolutional_input'):
        input_data = tf.contrib.layers.conv2d(inputs=x_image, num_outputs=256, kernel_size=9, padding='valid')
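
To see what shape this convolution produces, the following minimal sketch applies the usual output-size formula for a convolution with 'valid' padding. The helper name valid_conv_output_size is hypothetical and is not part of TensorFlow; the stride of 1 for the first convolution is assumed to be the layer's default, since the call above does not set it. The same formula is then applied to the stride-2 convolution used for the primary capsules.

```python
def valid_conv_output_size(input_size, kernel_size, stride=1):
    """Spatial output size of a convolution with 'valid' (no) padding."""
    return (input_size - kernel_size) // stride + 1

# 28x28 input image, 9x9 kernel, stride 1 (assumed default for the first convolution)
conv_size = valid_conv_output_size(28, 9, stride=1)

# 9x9 kernel with stride 2, as used by each primary-capsule convolution
capsule_size = valid_conv_output_size(conv_size, 9, stride=2)

print(conv_size, capsule_size)  # 20 6
```

Under these assumptions, input_data has a spatial size of 20 x 20 with 256 channels, and each primary-capsule convolution reduces it to 6 x 6.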

Compute the primary capsules that extract the basic features, such as edges. First, compute the capsules using the convolution operation as follows:

    capsules = []

    for i in range(8):
        with tf.name_scope('capsules_' + str(i)):
            # convolution operation
            output = tf.contrib.layers.conv2d(inputs=input_data, num_outputs=32, kernel_size=9, stride=2, padding='valid')

            # reshape the output
            output = tf.reshape(output, [batch_size, -1, 1, 1])

            # store the output, which is one capsule component, in the capsules list
            capsules.append(output)

Concatenate all the capsules to form the primary capsules:

    primary_capsule = tf.concat(capsules, axis=2)
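
The reshape and concatenation can be easier to follow with concrete shapes. The following is a rough NumPy sketch, not the TensorFlow graph itself: random arrays stand in for the convolution outputs, the batch_size of 4 is a hypothetical value, and the 6 x 6 x 32 output shape is an assumption derived from a 20 x 20 input, a 9 x 9 kernel, and a stride of 2.

```python
import numpy as np

batch_size = 4  # hypothetical value, for illustration only

capsules = []
for i in range(8):
    # stand-in for one convolution output of shape [batch, 6, 6, 32]
    output = np.random.rand(batch_size, 6, 6, 32).astype(np.float32)

    # flatten the spatial and channel dimensions: 6 * 6 * 32 = 1152
    output = output.reshape(batch_size, -1, 1, 1)
    capsules.append(output)

# stack the 8 outputs along axis 2, giving 1152 capsules of 8 dimensions each
primary_capsule = np.concatenate(capsules, axis=2)
print(primary_capsule.shape)  # (4, 1152, 8, 1)
```

Each of the 1,152 rows along axis 1 is one primary capsule, and axis 2 holds its 8 components.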

Apply the squash function to the primary capsules so that the length of each capsule vector lies between 0 and 1 and can be interpreted as a probability:

    primary_capsule = squash(primary_capsule)
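
The squash function itself is defined elsewhere and is not reproduced here. As a rough illustration only, the following NumPy sketch implements the standard squashing nonlinearity from the capsule network literature. The name squash_np, the epsilon term, and the choice of axis=-2 as the capsule dimension are assumptions made for this example.

```python
import numpy as np

def squash_np(s, axis=-2, epsilon=1e-9):
    """Shrink each vector along `axis` so that its length lies in [0, 1)."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    # scale = ||s||^2 / (1 + ||s||^2), applied to the unit vector s / ||s||
    scale = squared_norm / (1.0 + squared_norm) / np.sqrt(squared_norm + epsilon)
    return scale * s

# one capsule with components (3, 4), so its length is 5
s = np.array([3.0, 4.0]).reshape(1, 1, 2, 1)
v = squash_np(s)

# length after squashing: 25 / 26, about 0.96
length = np.linalg.norm(v, axis=-2).item()
print(round(length, 4))
```

Long vectors are squashed to a length just below 1, and short vectors shrink toward 0, while the direction of the vector is preserved.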

Compute the digit capsules using a dynamic-routing algorithm as follows:

    with tf.name_scope('dynamic_routing'):
        # reshape the primary capsules
        outputs = tf.reshape(primary_capsule, shape=(batch_size, -1, 1, primary_capsule.shape[-2].value, 1))

        # initialize bij with 0s
        bij = tf.constant(np.zeros([1, primary_capsule.shape[1].value, 10, 1, 1], dtype=np.float32))

        # compute the digit capsules using the dynamic routing algorithm, which takes
        # the reshaped primary capsules and bij as inputs and returns the activity vectors
        digit_capsules = dynamic_routing(outputs, bij)

    # remove the dimension of size 1 at axis 1
    digit_capsules = tf.squeeze(digit_capsules, axis=1)
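
The dynamic_routing function is defined elsewhere and is not shown in this section. To give an idea of what routing does, the following is a simplified NumPy sketch of routing by agreement for a single example. It is not the function used above: it starts from prediction vectors u_hat that are assumed to be already computed, and the names routing_by_agreement, squash_np, and softmax, the three iterations, and the 16-dimensional digit capsules are all assumptions made for illustration.

```python
import numpy as np

def squash_np(s, axis=-1, epsilon=1e-9):
    """Shrink each vector along `axis` so that its length lies in [0, 1)."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    return squared_norm / (1.0 + squared_norm) * s / np.sqrt(squared_norm + epsilon)

def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def routing_by_agreement(u_hat, num_iterations=3):
    """u_hat: prediction vectors of shape [num_in, num_out, dim]."""
    num_in, num_out, dim = u_hat.shape
    b = np.zeros((num_in, num_out))  # routing logits, initialized with 0s
    for _ in range(num_iterations):
        # coupling coefficients: each input capsule distributes itself over the outputs
        c = softmax(b, axis=1)
        # weighted sum of the predictions for each output capsule
        s = np.sum(c[:, :, None] * u_hat, axis=0)
        # activity vector of each output capsule
        v = squash_np(s)
        # increase the logits where a prediction agrees with the output (dot product)
        b = b + np.sum(u_hat * v[None, :, :], axis=-1)
    return v

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(1152, 10, 16))  # 1152 primary capsules, 10 digits, 16 dims (assumed)
v = routing_by_agreement(u_hat)
print(v.shape)  # (10, 16)
```

In this sketch, the length of each of the 10 output vectors plays the role of the probability that the corresponding digit is present.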