Skip to content

Commit 429258e

Browse files
committed
add common.tf_get_first_true()
1 parent e41730c commit 429258e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ def parse_results(result, unhash_dict, topk=5):
181181
prediction_results.append(current_method_prediction_results)
182182
return prediction_results
183183

184+
@staticmethod
185+
def tf_get_first_true(bool_tensor: tf.Tensor) -> tf.Tensor:
186+
bool_tensor_as_int32 = tf.cast(bool_tensor, dtype=tf.int32)
187+
cumsum = tf.cumsum(bool_tensor_as_int32, axis=-1, exclusive=False)
188+
return tf.logical_and(tf.equal(cumsum, 1), bool_tensor)
189+
184190

185191
class PredictionResults:
186192
def __init__(self, original_name):

0 commit comments

Comments
 (0)