Next, a function is created to perform inference on the test data. The model was saved as a checkpoint in the preceding step and is restored here. The placeholders for the input data and the dropout keep probability are defined, along with a saver object, as follows:
def inference(test_x1, max_sent_len, batch_size=1024):
    with tf.name_scope('Placeholders'):
        # Test sentences encoded as integer indices, plus the dropout
        # keep probability
        x_pls1 = tf.placeholder(tf.int32, shape=[None, max_sent_len])
        keep_prob = tf.placeholder(tf.float32)  # Dropout
    predict = model(x_pls1, keep_prob)
    # Saver to restore the weights from the most recent checkpoint
    saver = tf.train.Saver()
    ckpt_path = tf.train.latest_checkpoint('.')
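Note that tf.train.latest_checkpoint returns None when no checkpoint file exists in the given directory. A small guard, not part of the original code, can make that failure explicit; a minimal sketch:

    # Assumed addition: fail early if training has not produced a checkpoint yet
    if ckpt_path is None:
        raise FileNotFoundError('No checkpoint found in the current directory; '
                                'run the training step first.')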
Next, a session is created and the model is restored:
    with tf.Session() as sess:
        # Initialize all variables, then overwrite them with the values
        # saved in the checkpoint
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)
        print("Model restored.")
With the model loaded into the session, the data is passed in batches, and the predictions are stored:
        # Build the argmax op once, outside the loop, so that a new node is
        # not added to the graph on every iteration
        pred_op = tf.argmax(predict, 1)
        prediction = []
        for i in range(int(math.ceil(test_x1.shape[0] / batch_size))):
            start_idx = (i * batch_size) % test_x1.shape[0]
            prediction += sess.run(pred_op,
                                   feed_dict={x_pls1: test_x1[start_idx:start_idx + batch_size, :],
                                              keep_prob: 1.0}).tolist()  # keep_prob=1 disables dropout
        print(prediction)
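As written, the function only prints the predicted class indices. If they are needed downstream, for instance to write a submission file for the Quora question pairs test set, the natural extension is to end the function with return prediction instead. The following is a minimal sketch under that assumption; the file name and column headers are illustrative, not part of the original pipeline:

import csv

# Assumes inference() has been changed to end with 'return prediction'
predictions = inference(test_x1, max_sent_len)

with open('predictions.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['test_id', 'is_duplicate'])  # hypothetical column names
    writer.writerows(enumerate(predictions))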
Next, all of the functions are called in turn, to preprocess the data, train the model, and perform inference on the test data:
train_x1, train_x2, train_y, val_x1, val_x2, val_y, test_x1, test_x2, max_sent_len, char_map = pre_process()
train(train_x1, train_x2, train_y, val_x1, val_x2, val_y, max_sent_len, char_map, 100, 1024)
inference(test_x1, max_sent_len)
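One TensorFlow 1.x caveat when train and inference run back to back in the same process: the default graph still contains the training ops, so rebuilding the model inside inference creates a second set of variables with suffixed names, which the Saver cannot match against the checkpoint. A common remedy, assumed here rather than taken from the original code, is to reset the graph between the two calls:

# Assumed addition between train() and inference(): discard the training
# graph so that the variables rebuilt in inference() match the names
# stored in the checkpoint
tf.reset_default_graph()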
Once the training starts, you can see the progress and the results, as follows:
Validation Epoch 25, Overall loss = 0.51399 and accuracy of 1
Epoch 26, Overall loss = 0.19037 and accuracy of 0.889
Epoch 27, Overall loss = 0.15886 and accuracy of 1
Epoch 28, Overall loss = 0.15363 and accuracy of 1
Epoch 29, Overall loss = 0.098042 and accuracy of 1
Epoch 30, Overall loss = 0.10002 and accuracy of 1
Tensor("Placeholders/Placeholder_2:0", shape=(?,), dtype=int64)
After 30 epochs, the model achieves 100% accuracy on the validation data. We have now seen how to train a model to detect duplicates, using the Quora question pairs dataset as an example.