Fix lint issues
pomonam committed Aug 18, 2022
commit b5fb749cbda9decfe6e594b9252eb066e94ad75c
@@ -105,9 +105,8 @@ def scaled_init(key, shape, scale, init, dtype=jnp.float_):
           fan_out,
           kernel_init=jnn.initializers.normal(
               stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
-          bias_init=jnn.initializers.normal(
-              stddev=jnp.sqrt(1.0 / fan_out)))(
-                  top_mlp_input)
+          bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
+              top_mlp_input)
       if layer_idx < (num_layers_top - 1):
         top_mlp_input = nn.relu(top_mlp_input)
     logits = top_mlp_input
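This hunk is a pure re-wrap: yapf joins the bias_init argument onto one line, with no behavioral change. For readers skimming the diff, a minimal standalone sketch of the pattern being formatted, assuming a Flax nn.Dense layer (the surrounding model code and the fan_in/fan_out values are not shown in this hunk):

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import nn as jnn

# Assumed example widths; in the diff, fan_in/fan_out come from the
# top-MLP layer sizes.
fan_in, fan_out = 256, 128

# Kernel stddev sqrt(2 / (fan_in + fan_out)) is a Glorot-style scale;
# bias stddev sqrt(1 / fan_out) shrinks as the layer gets wider.
layer = nn.Dense(
    fan_out,
    kernel_init=jnn.initializers.normal(
        stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
    bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))

x = jnp.ones((4, fan_in))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # y.shape == (4, 128)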
@@ -4,11 +4,11 @@
 from flax import jax_utils
 import jax
 import jax.numpy as jnp

 from algorithmic_efficiency import param_utils
 from algorithmic_efficiency import spec
-from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import \
-    models
+from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import metrics
+from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models
 from algorithmic_efficiency.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload
@@ -81,8 +81,7 @@ def _eval_metric(self, labels, logits):
   def _eval_batch(self, params, batch, model_state, rng):
     return super()._eval_batch(params, batch, model_state, rng)

-  def loss_fn(self,
-              label_batch: spec.Tensor,
+  def loss_fn(self, label_batch: spec.Tensor,
               logits_batch: spec.Tensor) -> spec.Tensor:
     per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
         logits=logits_batch, targets=label_batch)
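The re-wrapped loss_fn delegates to metrics.per_example_sigmoid_binary_cross_entropy, whose implementation is not part of this diff. As a hypothetical sketch of what a per-example sigmoid binary cross-entropy typically computes (the standard numerically stable form, not necessarily the repository's actual code):

import jax.numpy as jnp

def per_example_sigmoid_binary_cross_entropy(logits, targets):
  # Stable sigmoid cross-entropy, one value per example:
  # max(x, 0) - x * z + log(1 + exp(-|x|)) for logits x and labels z in {0, 1}.
  return (jnp.maximum(logits, 0) - logits * targets +
          jnp.log1p(jnp.exp(-jnp.abs(logits))))

Reducing these per-example values to a scalar (e.g., a mean over the batch) would happen after this call; that part of loss_fn falls outside the hunk.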
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -1,8 +1,8 @@
 import math
 from typing import Dict, Optional

-import jax
 from absl import flags
+import jax

 from algorithmic_efficiency import random_utils as prng
 from algorithmic_efficiency import spec
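The only change in this hunk is import order. A plausible reading, assuming an isort-style configuration (the lint config itself is not shown): third-party imports sort alphabetically by top-level module regardless of the import-vs-from form, so absl must precede jax, while first-party algorithmic_efficiency imports form their own group after a blank line:

# Third-party group, sorted by module name ('absl' < 'jax'):
from absl import flags
import jax

# First-party group, separated by a blank line:
from algorithmic_efficiency import random_utils as prng
from algorithmic_efficiency import spec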
7 changes: 5 additions & 2 deletions submission_runner.py
@@ -18,15 +18,18 @@
 import time
 from typing import Optional, Tuple

+from absl import app
+from absl import flags
+from absl import logging
 import tensorflow as tf
 import torch
 import torch.distributed as dist
-from absl import app, flags, logging

 from algorithmic_efficiency import halton
 from algorithmic_efficiency import random_utils as prng
 from algorithmic_efficiency import spec
-from algorithmic_efficiency.profiler import PassThroughProfiler, Profiler
+from algorithmic_efficiency.profiler import PassThroughProfiler
+from algorithmic_efficiency.profiler import Profiler
 from algorithmic_efficiency.pytorch_utils import pytorch_setup

 # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
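Besides moving the absl imports into sorted position, this pass splits combined imports (from absl import app, flags, logging and PassThroughProfiler, Profiler) into one name per statement, presumably under a force-single-line setting. The practical benefit is diff hygiene: adding or removing one imported name later touches exactly one line. A sketch of the resulting convention:

# One imported name per statement (assumed force-single-line style):
from algorithmic_efficiency.profiler import PassThroughProfiler
from algorithmic_efficiency.profiler import Profiler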