Yet another way to debug TensorFlow models is to insert conditional asserts. The tf.Assert() function takes a condition and, if the condition is false, prints the given list of tensors and throws tf.errors.InvalidArgumentError.
tf.Assert(condition, data, summarize=None, name=None)
- Unlike the output of the tf.Print() function, an assert operation does not fall in the path of the graph. To make sure that the tf.Assert() operation gets executed, we need to add it as a control dependency. For example, let us define an assertion to check that all the inputs are non-negative:
assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x, 0)), [x])
- Add assert_op to the dependencies at the time of defining the model, as follows:
with tf.control_dependencies([assert_op]):
    # x is input layer
    layer = x
    # add hidden layers
    for i in range(num_layers):
        layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i])
    # add output layer
    layer...
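To see the assertion fire, the two steps above can be put together in a small end-to-end sketch. This is only an illustration under a few assumptions: it imports the TensorFlow 1.x graph API through tensorflow.compat.v1 so that it also runs on a TensorFlow 2.x installation, it replaces the hidden layers with a single fixed weight matrix, and the helper names build_graph and run_model are hypothetical, not part of TensorFlow.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: run in graph mode, as TensorFlow 1.x does by default.
tf.disable_eager_execution()


def build_graph():
    """Build a tiny graph whose output depends on a non-negativity assert."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
        # Hypothetical stand-in for the model weights.
        w = tf.constant([[1.0], [1.0]])
        # Fails if any element of x is negative; prints up to 4 entries of x.
        assert_op = tf.Assert(
            tf.reduce_all(tf.greater_equal(x, 0)), [x], summarize=4)
        # Every op created in this block waits for assert_op to run first.
        with tf.control_dependencies([assert_op]):
            y = tf.matmul(x, w)
    return graph, x, y


def run_model(inputs):
    """Return the model output, or None if the assertion fails."""
    graph, x, y = build_graph()
    with tf.Session(graph=graph) as sess:
        try:
            return sess.run(y, feed_dict={x: inputs})
        except tf.errors.InvalidArgumentError:
            return None


valid = run_model(np.array([[1.0, 2.0]], dtype=np.float32))
invalid = run_model(np.array([[-1.0, 2.0]], dtype=np.float32))
print(valid, invalid)
```

With the non-negative input the session returns the matmul result, while the negative input triggers the assertion and the helper returns None. If the matmul were created outside the tf.control_dependencies block, nothing fetched would depend on assert_op and the check would never run.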