@@ -51,7 +51,7 @@ public class LossesHelper {
    * @param tf the TensorFlow Ops
    * @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
    * @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
-   * </code>.
+   * </code> .
    * @param <T> the data type for the labels, predictions and result
    * @return LossTuple of <code>prediction</code>, <code>label</code>,<code>sampleWeight</code> will
    *     be null. Each of them possibly has the last dimension squeezed, <code>sampleWeight</code>
@@ -77,7 +77,7 @@ public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
    * @param tf the TensorFlow Ops
    * @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
    * @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
-   * </code>.
+   * </code> .
    * @param sampleWeights Optional sample weight(s) <code>Operand</code> whose dimensions match
    *     <code>
    *     prediction</code>.
@@ -179,7 +179,7 @@ private static <T extends TNumber> Operand<T> maybeExpandWeights(
    *
    * @param tf the TensorFlowOps
    * @param labels Label values, a <code>Tensor</code> whose dimensions match <code>predictions
-   * </code>.
+   * </code> .
    * @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
    * @param <T> the data type for the labels, predictions and result
    * @return <code>labels</code> and <code>predictions</code>, possibly with last dim squeezed.
@@ -194,7 +194,7 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
    *
    * @param tf the TensorFlowOps
    * @param labels Label values, a <code>Operand</code> whose dimensions match <code>predictions
-   * </code>.
+   * </code> .
    * @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
    * @param expectedRankDiff Expected result of <code>rank(predictions) - rank(labels)</code>.
    * @param <T> the data type for the labels, predictions and result
@@ -222,11 +222,13 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
     // Use dynamic rank.

     // TODO: hold for lazy select feature,
-    // Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
+    // Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions),
+    // tf.rank(labels));
     if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
       /*
-       * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
-       * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
+       * TODO, if we ever get a select that does lazy evaluation, but for now do the
+       * tf.squeeze predictions = tf.select(
+       * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
        * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); *
        */
       predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
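
The hunk above only rewraps comments, but it sits next to the dynamic-rank fallback: when `predictionsRank` is unknown at graph-build time and the last dimension is statically compatible with 1, the trailing axis is squeezed unconditionally rather than through the lazily evaluated `select` the TODO wishes for. A minimal sketch of that squeeze, assuming the tensorflow-java `Ops` API; the placeholder and its shape are illustrative:

```java
import java.util.Collections;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.types.TFloat32;

public class SqueezeSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // Batch dimension unknown, last dimension statically 1: this is the
      // case Shape.isCompatible(predictionsShape.size(-1), 1) accepts.
      Operand<TFloat32> predictions =
          tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE, 1)));
      // Squeeze the trailing axis so [batch, 1] collapses to [batch].
      Operand<TFloat32> squeezed =
          tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
      System.out.println(squeezed.shape()); // the size-1 trailing axis is removed
    }
  }
}
```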
@@ -282,11 +284,12 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
     if (reduction == Reduction.NONE) {
       loss = weightedLoss;
     } else {
-      loss =
-          tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
       if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
-        loss = safeMean(tf, loss, weightedLoss.shape().size());
-      }
+        loss = safeMean(tf, weightedLoss);
+      } else
+        loss =
+            tf.reduceSum(
+                weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
     }
     return loss;
   }
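
The restructured branch no longer sums first and averages afterwards: `AUTO` and `SUM_OVER_BATCH_SIZE` now call `safeMean` directly on the weighted loss, so (together with the `safeMean` change below) the element count comes from the runtime shape instead of `weightedLoss.shape().size()`, and only the remaining sum-style reduction pays for the explicit `reduceSum`. A worked eager-mode example of the three paths; the 2x2 values are illustrative:

```java
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.types.TFloat32;

public class ReductionPaths {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager mode
    Operand<TFloat32> weightedLoss = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});

    // NONE: the weighted loss is returned unchanged -> [[1, 2], [3, 4]]

    // SUM-style path: reduce over every axis -> 10
    Operand<TFloat32> sum =
        tf.reduceSum(
            weightedLoss, tf.constant(new int[] {0, 1}), ReduceSum.keepDims(Boolean.FALSE));

    // AUTO / SUM_OVER_BATCH_SIZE: mean over all elements -> 10 / 4 = 2.5
    Operand<TFloat32> mean =
        tf.math.divNoNan(
            sum, tf.dtypes.cast(tf.shape.size(tf.shape(weightedLoss)), TFloat32.class));
  }
}
```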
@@ -301,10 +304,10 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
    * @return A scalar representing the mean of <code>losses</code>. If <code>numElements</code> is
    *     zero, then zero is returned.
    */
-  public static <T extends TNumber> Operand<T> safeMean(
-      Ops tf, Operand<T> losses, long numElements) {
-    Operand<T> totalLoss = tf.reduceSum(losses, allAxes(tf, losses));
-    return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type()));
+  public static <T extends TNumber> Operand<T> safeMean(Ops tf, Operand<T> losses) {
+    Operand<T> totalLoss =
+        tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE));
+    return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type()));
   }

   /**
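
`safeMean` now derives its divisor from `tf.shape.size(tf.shape(losses))` at run time instead of a caller-supplied `numElements`, so tensors with partially unknown shapes are handled, and `divNoNan` keeps the zero-element case at 0 rather than NaN. A small eager sketch of that contract, using `tf.dtypes.cast` directly where the class uses its own `cast` helper:

```java
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class SafeMeanSketch {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager mode

    // Empty loss tensor: total = 0, count = 0, and divNoNan(0, 0) = 0, not NaN.
    Operand<TFloat32> losses = tf.constant(new float[0]);
    Operand<TFloat32> total = tf.reduceSum(losses, tf.constant(new int[] {0}));
    Operand<TFloat32> count =
        tf.dtypes.cast(tf.shape.size(tf.shape(losses)), TFloat32.class);
    Operand<TFloat32> mean = tf.math.divNoNan(total, count); // 0.0
  }
}
```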
@@ -348,7 +351,8 @@ public static <T extends TNumber> Operand<T> rangeCheck(
         tf.math.logicalAnd(
             tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims),
             tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims));
-    // Graph and Eager mode need to be handled differently, control dependencies are not allowed in
+    // Graph and Eager mode need to be handled differently, control dependencies are
+    // not allowed in
     // Eager mode
     if (tf.scope().env().isGraph()) {
       AssertThat assertThat =
@@ -398,7 +402,8 @@ public static <T extends TNumber> Operand<T> valueCheck(
       } else return values;
     } else { // use dynamic shape
       Operand<TBool> cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0));
-      // Graph and Eager mode need to be handled differently, control dependencies are not allowed
+      // Graph and Eager mode need to be handled differently, control dependencies are
+      // not allowed
       // in Eager mode
       if (tf.scope().env().isGraph()) {
         AssertThat assertThat =
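
This hunk and the one above it rewrap the same observation: control dependencies are only legal in Graph mode, so the assertion has to be attached as a dependency there, while Eager mode must evaluate the condition immediately. A hedged sketch of that pattern; `cond` and `values` mirror the surrounding code, while the method name and message string are illustrative:

```java
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TNumber;

public class ModeCheckSketch {
  static <T extends TNumber> Operand<T> checked(Ops tf, Operand<TBool> cond, Operand<T> values) {
    if (tf.scope().env().isGraph()) {
      // Graph mode: wire the assertion in as a control dependency so it is
      // evaluated before values is consumed downstream.
      AssertThat assertThat =
          tf.assertThat(cond, Collections.singletonList(tf.constant("value check failed")));
      return tf.withControlDependencies(Collections.singletonList(assertThat)).identity(values);
    } else {
      // Eager mode: no control dependencies; read the condition now and fail fast.
      if (!cond.asTensor().getBoolean()) {
        throw new IllegalArgumentException("value check failed");
      }
      return values;
    }
  }
}
```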