TL;DR: You need to use Session.run() to get a Python boolean, but there are other ways to achieve the same result that might be more efficient.
It looks like you've already figured out how to get a boolean tensor from your value, but for the benefit of other readers, it would look something like this:
computed_val = ...
constant_val = tf.constant(37.0)
pred = tf.less(computed_val, constant_val) # N.B. Types of the two args must match
The next part is how to use it as a conditional. The simplest thing to do is to use a Python if statement, but to do that you must evaluate the tensor pred using Session.run():
sess = tf.Session()
if sess.run(pred):
    pass  # Do something.
else:
    pass  # Do something else.
One caveat about using a Python if statement is that you have to evaluate the whole expression up to pred, which makes it tricky to reuse intermediate values that have already been computed. I'd like to draw your attention to two other ways you can compute conditional expressions using TensorFlow, which don't require you to evaluate the predicate and get a Python value back.
The first way uses the tf.select() op to conditionally pass through values from two tensors passed as arguments:
pred = tf.placeholder(tf.bool) # Can be any computed boolean expression.
val_if_true = tf.constant(28.0)
val_if_false = tf.constant(12.0)
result = tf.select(pred, val_if_true, val_if_false)
sess = tf.Session()
sess.run(result, feed_dict={pred: True}) # ==> 28.0
sess.run(result, feed_dict={pred: False}) # ==> 12.0
The tf.select() op works element-wise on all of its arguments, which allows you to combine values from the two input tensors. See its documentation for more details. The drawback of tf.select() is that it evaluates both val_if_true and val_if_false before computing the result, which might be expensive if they are complicated expressions.
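To make the element-wise behavior concrete, here is a minimal plain-Python model of the selection rule. It is not TensorFlow code: select_elementwise is a hypothetical helper, and Python lists stand in for tensors of equal shape.

```python
def select_elementwise(pred, if_true, if_false):
    """Model of an element-wise select on equal-length lists (illustrative only)."""
    if not (len(pred) == len(if_true) == len(if_false)):
        raise ValueError("all three inputs must have the same length")
    # Both input lists are already fully materialized at this point; the
    # predicate only decides which element is passed through at each position.
    return [t if p else f for p, t, f in zip(pred, if_true, if_false)]


mixed = select_elementwise([True, False, True], [1.0, 2.0, 3.0], [10.0, 20.0, 30.0])
print(mixed)  # [1.0, 20.0, 3.0]
```

Note that the result mixes elements from both inputs, which is why both have to be computed in full before the selection happens.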
The second way uses the tf.cond() op, which conditionally evaluates one of two expressions. This is particularly useful if the expressions are expensive, and it is essential if they have side effects. The basic pattern is to specify two Python functions (or lambda expressions) that build subgraphs that will execute on the true or false branches:
# Define some large matrices
a = ...
b = ...
c = ...
pred = tf.placeholder(tf.bool)
def if_true():
    return tf.matmul(a, b)
def if_false():
    return tf.matmul(b, c)
result = tf.cond(pred, if_true, if_false)
sess = tf.Session()
sess.run(result, feed_dict={pred: True}) # ==> executes only (a x b)
sess.run(result, feed_dict={pred: False}) # ==> executes only (b x c)
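The difference between the two approaches comes down to which branches actually get computed. The following plain-Python sketch models that difference. It is not TensorFlow code, and select_model, cond_model, and expensive are hypothetical names used only for illustration.

```python
calls = []  # records which branch computations actually ran


def expensive(name, value):
    """Stand-in for a costly computation; logs that it was evaluated."""
    calls.append(name)
    return value


def select_model(pred, val_if_true, val_if_false):
    # Takes already-computed values, so the caller has evaluated both branches.
    return val_if_true if pred else val_if_false


def cond_model(pred, fn_if_true, fn_if_false):
    # Takes callables and invokes only the one chosen by the predicate.
    return fn_if_true() if pred else fn_if_false()


eager = select_model(True, expensive("true", 28.0), expensive("false", 12.0))
eager_calls = list(calls)

calls.clear()
lazy = cond_model(True, lambda: expensive("true", 28.0), lambda: expensive("false", 12.0))
lazy_calls = list(calls)

print(eager, eager_calls)  # 28.0 ['true', 'false']
print(lazy, lazy_calls)    # 28.0 ['true']
```

Both calls return the same value, but only the function-based version skips the unused branch, which is the reason for passing functions rather than tensors when the branches are expensive or have side effects.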