First, we import the libraries:
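```python
# tf.contrib (and therefore slim and rnn below) exists only in TensorFlow 1.x;
# it was removed in TensorFlow 2.x.
import tensorflow as tf

slim = tf.contrib.slim
rnn = tf.contrib.rnn
```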
Now, we define a class called `Matching_network`, where we define our network:
We define the `__init__` method, where we create the input placeholders for the support set and the query set:
```python
class Matching_network:

    def __init__(self, lr, n_way, k_shot, batch_size=32):

        # Placeholders for the support set: a batch of episodes, each holding
        # n_way * k_shot grayscale 28x28 images and their integer labels
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot])

        # Placeholders for the query set: one 28x28 image and its label per episode
        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.query_label = tf.placeholder(tf.int32, [None])
```
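To make the placeholder shapes concrete, here is a minimal NumPy sketch of sampling a batch of N-way K-shot episodes that would fit them. The function `sample_episode`, the relabelling of each episode's chosen classes to `0` through `n_way - 1`, and the synthetic random dataset are hypothetical illustrations, not part of the original code. The sketch also assumes every class has at least `k_shot + 1` examples, so the query image is never a copy of a support image.

```python
import numpy as np


def sample_episode(images, labels, n_way, k_shot, batch_size, rng=None):
    """Hypothetical helper: sample `batch_size` N-way K-shot episodes.

    images: (num_examples, 28, 28, 1) float array
    labels: (num_examples,) integer class ids
    Assumes each class has at least k_shot + 1 examples.
    """
    rng = np.random.default_rng() if rng is None else rng
    classes = np.unique(labels)
    support_x = np.zeros((batch_size, n_way * k_shot, 28, 28, 1), dtype=np.float32)
    support_y = np.zeros((batch_size, n_way * k_shot), dtype=np.int32)
    query_x = np.zeros((batch_size, 28, 28, 1), dtype=np.float32)
    query_y = np.zeros((batch_size,), dtype=np.int32)

    for b in range(batch_size):
        chosen = rng.choice(classes, size=n_way, replace=False)
        q_label = int(rng.integers(n_way))  # which episode class the query comes from
        for new_label, cls in enumerate(chosen):
            idx = np.flatnonzero(labels == cls)
            # Draw one extra example for the query class so the query is not in the support set
            n_pick = k_shot + 1 if new_label == q_label else k_shot
            picks = rng.choice(idx, size=n_pick, replace=False)
            start = new_label * k_shot
            support_x[b, start:start + k_shot] = images[picks[:k_shot]]
            support_y[b, start:start + k_shot] = new_label  # relabel to 0..n_way-1
            if new_label == q_label:
                query_x[b] = images[picks[k_shot]]
                query_y[b] = q_label
    return support_x, support_y, query_x, query_y


# Usage with a synthetic dataset: 20 classes, 10 random "images" each
rng = np.random.default_rng(0)
images = rng.random((200, 28, 28, 1), dtype=np.float32)
labels = np.repeat(np.arange(20), 10)
sx, sy, qx, qy = sample_episode(images, labels, n_way=5, k_shot=1, batch_size=32, rng=rng)
print(sx.shape, sy.shape, qx.shape, qy.shape)
# (32, 5, 28, 28, 1) (32, 5) (32, 28, 28, 1) (32,)
```

In a TensorFlow 1.x session, these four arrays would be passed through `feed_dict` to `support_set_image`, `support_set_label`, `query_image`, and `query_label` respectively; the leading `None` dimension of each placeholder absorbs the batch size.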
Suppose our support set and query set consist of images. Before feeding these raw images to the embedding function, we first extract their features...
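The section is cut off before it describes the feature extractor, so the following is only a minimal NumPy sketch of the general idea, not the network this section goes on to build. It turns a 28x28x1 image into a flat feature vector with a single convolution layer, a ReLU, and 2x2 max pooling. The function names, the single layer, the eight 3x3 filters, and the random (untrained) weights are all illustrative assumptions; in a real model the filters would be learned.

```python
import numpy as np


def conv2d_valid(image, kernels):
    """Illustrative 'valid' convolution (cross-correlation, as in deep learning).

    image: (H, W, C_in); kernels: (kh, kw, C_in, C_out)
    returns: (H - kh + 1, W - kw + 1, C_out)
    """
    kh, kw, c_in, c_out = kernels.shape
    h, w, _ = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1, c_out), dtype=image.dtype)
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = image[i:i + kh, j:j + kw, :]  # (kh, kw, C_in)
            # Sum over kh, kw, C_in for every output channel at once
            out[i, j] = np.tensordot(patch, kernels, axes=([0, 1, 2], [0, 1, 2]))
    return out


def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling over the spatial dimensions."""
    h, w, c = x.shape
    h2, w2 = h // 2, w // 2
    x = x[:h2 * 2, :w2 * 2, :]
    return x.reshape(h2, 2, w2, 2, c).max(axis=(1, 3))


def extract_features(image, kernels):
    """Hypothetical extractor: conv -> ReLU -> max pool -> flatten."""
    x = conv2d_valid(image, kernels)
    x = np.maximum(x, 0.0)  # ReLU
    x = max_pool_2x2(x)
    return x.reshape(-1)


rng = np.random.default_rng(0)
kernels = rng.standard_normal((3, 3, 1, 8)).astype(np.float32)  # random, untrained filters
image = rng.random((28, 28, 1), dtype=np.float32)
features = extract_features(image, kernels)
print(features.shape)  # (1352,): 28 -> 26 after conv, 26 -> 13 after pooling, 13 * 13 * 8
```

The same function would be applied to every support-set image and to the query image, so that the embedding function receives fixed-length feature vectors rather than raw pixels.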