4 changes: 0 additions & 4 deletions README.md
@@ -56,10 +56,6 @@
```

#### Additional Details
```bash
pip3 install -e .
```

You can also install the requirements for individual workloads, e.g. via

```bash
Empty file.
Empty file.
@@ -10,13 +10,10 @@ def per_example_sigmoid_binary_cross_entropy(logits, targets):
Args:
logits: float array of shape (batch, output_shape).
targets: float array of shape (batch, output_shape).
weights: None or float array of shape (batch,).
Returns:
Sigmoid binary cross entropy computed per example, shape (batch,).
"""
log_p = jax.nn.log_sigmoid(logits)
log_not_p = jax.nn.log_sigmoid(-logits)
per_example_losses = -1.0 * (targets * log_p + (1 - targets) * log_not_p)
per_example_losses = (per_example_losses).reshape(per_example_losses.shape[0],
-1)
return jnp.sum(per_example_losses, axis=-1)
losses = -1.0 * (targets * log_p + (1 - targets) * log_not_p)
return jnp.sum(losses.reshape(losses.shape[0], -1), axis=-1)
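The rewritten loss above computes the elementwise sigmoid binary cross-entropy, -(t * log σ(z) + (1 - t) * log(1 - σ(z))), flattens any trailing output dimensions, and sums so the result has shape (batch,). A minimal standalone sketch (not part of this diff) showing the function on toy inputs:

```python
# Standalone sketch of the per-example sigmoid binary cross-entropy above;
# the toy inputs are illustrative only.
import jax
import jax.numpy as jnp


def per_example_sigmoid_binary_cross_entropy(logits, targets):
  log_p = jax.nn.log_sigmoid(logits)
  log_not_p = jax.nn.log_sigmoid(-logits)
  losses = -1.0 * (targets * log_p + (1 - targets) * log_not_p)
  # Flatten any trailing output dimensions and sum, leaving shape (batch,).
  return jnp.sum(losses.reshape(losses.shape[0], -1), axis=-1)


logits = jnp.array([[2.0, -1.0], [0.5, 0.0]])
targets = jnp.array([[1.0, 0.0], [0.0, 1.0]])
print(per_example_sigmoid_binary_cross_entropy(logits, targets))  # shape (2,)
```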
@@ -42,23 +42,22 @@ class DlrmSmall(nn.Module):
Parameters:
vocab_sizes: list of vocab sizes of embedding tables.
total_vocab_sizes: sum of embedding table sizes (for jit compilation).
num_dense_features: number of dense features as the bottom mlp input.
mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
mlp_top_dims: dimensions of dense layers of the top mlp.
num_dense_features: number of dense features as the bottom mlp input.
embed_dim: embedding dimension.
keep_diags: whether to keep the diagonal terms in x @ x.T.
"""

vocab_sizes: Sequence[int]
total_vocab_sizes: int
num_dense_features: int
mlp_bottom_dims: Sequence[int] = (512, 256, 128)
mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
embed_dim: int = 128

@nn.compact
def __call__(self, x, train):
del train
embed_dim = 128

bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -74,7 +73,8 @@ def __call__(self, x, train):
bot_mlp_input = nn.relu(bot_mlp_input)
bot_mlp_output = bot_mlp_input
batch_size = bot_mlp_output.shape[0]
feature_stack = jnp.reshape(bot_mlp_output, [batch_size, -1, embed_dim])
feature_stack = jnp.reshape(bot_mlp_output,
[batch_size, -1, self.embed_dim])

# Embedding table look-up.
vocab_sizes = jnp.asarray(self.vocab_sizes, dtype=jnp.int32)
@@ -96,11 +96,12 @@ def scaled_init(key, shape, scale, init, dtype=jnp.float_):
scaled_init, scale=scale, init=jnn.initializers.uniform(scale=1.0))
embedding_table = self.param('embedding_table',
scaled_variance_scaling_init,
[self.total_vocab_sizes, embed_dim])
[self.total_vocab_sizes, self.embed_dim])

idx_lookup = jnp.reshape(idx_lookup, [-1])
embed_features = embedding_table[idx_lookup]
embed_features = jnp.reshape(embed_features, [batch_size, -1, embed_dim])
embed_features = jnp.reshape(embed_features,
[batch_size, -1, self.embed_dim])
feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(concat_features=feature_stack)
top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
@@ -114,9 +115,8 @@ def scaled_init(key, shape, scale, init, dtype=jnp.float_):
fan_out,
kernel_init=jnn.initializers.normal(
stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
bias_init=jnn.initializers.normal(
stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
top_mlp_input)
bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
top_mlp_input)
if layer_idx < (num_layers_top - 1):
top_mlp_input = nn.relu(top_mlp_input)
logits = top_mlp_input
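The forward pass above stacks the bottom-MLP output with the embedding lookups and passes the result through `dot_interact`, which is defined elsewhere in models.py and not shown in this diff. For orientation only, a pairwise dot-product interaction of this kind could look like the sketch below; the function name, the triangular selection, and the `keep_diags` handling are assumptions, not the PR's actual code.

```python
# Hypothetical sketch of a pairwise dot-product feature interaction, assuming
# concat_features has shape (batch, num_features, embed_dim). This is NOT the
# dot_interact implementation used by DlrmSmall.
import jax.numpy as jnp


def dot_interact_sketch(concat_features, keep_diags=True):
  # All pairwise dot products between feature embeddings: (batch, f, f).
  interactions = jnp.matmul(concat_features,
                            jnp.transpose(concat_features, (0, 2, 1)))
  num_features = interactions.shape[-1]
  # Keep the lower triangle (optionally including the diagonal terms of
  # x @ x.T) to avoid duplicating symmetric entries, then flatten per example.
  rows, cols = jnp.tril_indices(num_features, k=0 if keep_diags else -1)
  return interactions[:, rows, cols]
```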
204 changes: 44 additions & 160 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -1,160 +1,22 @@
"""Criteo1TB DLRM-Small workload implemented in Jax."""
import functools
import math
from typing import Dict, Optional, Tuple

from clu import metrics as clu_metrics
import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
import numpy as np

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.criteo1tb import input_pipeline
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import \
dlrm_small_model
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import metrics
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models
from algorithmic_efficiency.workloads.criteo1tb.workload import \
BaseCriteo1TbDlrmSmallWorkload

_NUM_DENSE_FEATURES = 13
_VOCAB_SIZES = tuple([1024 * 128] * 26)


class Criteo1TbDlrmSmallWorkload(spec.Workload):
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
"""Criteo1TB DLRM-Small Jax workload."""

def __init__(self):
self._eval_iters = {}
self._param_shapes = None
self._param_types = None
self._flax_module = dlrm_small_model.DlrmSmall(
vocab_sizes=_VOCAB_SIZES,
total_vocab_sizes=sum(_VOCAB_SIZES),
num_dense_features=_NUM_DENSE_FEATURES)

def has_reached_goal(self, eval_result: float) -> bool:
return eval_result['validation/loss'] < self.target_value

@property
def target_value(self):
return 0.12

@property
def loss_type(self):
return spec.LossType.SIGMOID_CROSS_ENTROPY

@property
def num_train_examples(self):
return 4_195_197_692

@property
def num_eval_train_examples(self):
return 100_000

@property
def num_validation_examples(self):
return 131072 * 8 # TODO(znado): finalize the validation split size.

@property
def num_test_examples(self):
return None

@property
def train_mean(self):
return 0.0

@property
def train_stddev(self):
return 1.0

@property
def max_allowed_runtime_sec(self):
return 6 * 60 * 60

@property
def eval_period_time_sec(self):
return 20 * 60

def build_input_queue(self,
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,
repeat_final_dataset: bool = False):
del data_rng
ds = input_pipeline.get_criteo1tb_dataset(
split=split,
data_dir=data_dir,
is_training=(split == 'train'),
global_batch_size=global_batch_size,
num_dense_features=_NUM_DENSE_FEATURES,
vocab_sizes=_VOCAB_SIZES,
num_batches=num_batches,
repeat_final_dataset=repeat_final_dataset)
for batch in iter(ds):
batch = jax.tree_map(lambda x: x._numpy(), batch) # pylint: disable=protected-access
yield batch

@functools.partial(
jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
def eval_step_pmapped(self, params, batch):
"""Calculate evaluation metrics on a batch."""
inputs = batch['inputs']
targets = batch['targets']
weights = batch['weights']
logits = self._flax_module.apply({'params': params}, inputs, targets)
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits, targets)
return jax.lax.psum(per_example_losses), jax.lax.psum(weights)

def _eval_model_on_split(self,
split: str,
num_examples: int,
global_batch_size: int,
params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
data_dir: str,
global_step: int = 0) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
num_batches = int(math.ceil(num_examples / global_batch_size))
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self.build_input_queue(
rng,
split,
data_dir,
global_batch_size,
num_batches,
repeat_final_dataset=True)
total_loss_numerator = 0.
total_loss_denominator = 0.
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
batch_loss_numerator, batch_loss_denominator = (
self.eval_step_pmapped(params, eval_batch).unreplicate())
total_loss_numerator += batch_loss_numerator
total_loss_denominator += batch_loss_denominator
mean_loss = total_loss_numerator / total_loss_denominator
return mean_loss.numpy()

# Return whether or not a key in spec.ParameterContainer is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

@property
def param_shapes(self):
"""The shapes of the parameters in the workload model."""
if self._param_shapes is None:
raise ValueError(
'This should not happen, workload.init_model_fn() should be called '
'before workload.param_shapes!')
return self._param_shapes

@property
def model_params_types(self):
if self._param_shapes is None:
@@ -165,37 +27,39 @@ def model_params_types(self):
self._param_types = param_utils.jax_param_types(self._param_shapes)
return self._param_types

def output_activation_fn(self,
logits_batch: spec.Tensor,
loss_type: spec.LossType) -> spec.Tensor:
"""Return the final activations of the model."""
pass

def loss_fn(
self,
label_batch: spec.Tensor, # Dense (not one-hot) labels.
logits_batch: spec.Tensor,
label_smoothing: float = 0.0,
mask_batch: Optional[spec.Tensor] = None) -> spec.Tensor:
smoothed_targets = optax.smooth_labels(label_batch, label_smoothing)
mask_batch: Optional[spec.Tensor] = None,
label_smoothing: float = 0.0) -> spec.Tensor:
del label_smoothing
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits=logits_batch, targets=smoothed_targets)
logits=logits_batch, targets=label_batch)
if mask_batch is not None:
weighted_losses = per_example_losses * mask_batch
normalization = mask_batch.sum()
else:
weighted_losses = per_example_losses
normalization = label_batch.shape[0]
normalization = label_batch.shape[0]
return jnp.sum(weighted_losses, axis=-1) / normalization

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._model = models.DlrmSmall(
vocab_sizes=self.vocab_sizes,
total_vocab_sizes=sum(self.vocab_sizes),
num_dense_features=self.num_dense_features,
mlp_bottom_dims=(128, 128),
mlp_top_dims=(256, 128, 1),
embed_dim=64)

rng, init_rng = jax.random.split(rng)
init_fake_batch_size = 2
input_size = _NUM_DENSE_FEATURES + len(_VOCAB_SIZES)
input_size = self.num_dense_features + len(self.vocab_sizes)
input_shape = (init_fake_batch_size, input_size)
target_shape = (init_fake_batch_size, input_size)

initial_variables = jax.jit(self._flax_module.init)(
initial_variables = jax.jit(self._model.init)(
init_rng,
jnp.ones(input_shape, jnp.float32),
jnp.ones(target_shape, jnp.float32))
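For reference, the reworked `loss_fn` above drops label smoothing, weights the per-example losses by the optional mask, and normalizes by the number of unmasked examples (falling back to the batch size when no mask is given). A small numeric sketch of that masking logic, with made-up values:

```python
import jax.numpy as jnp

# Made-up values: the third example is masked out, so the mean is taken over
# the two unmasked examples only.
per_example_losses = jnp.array([0.2, 0.4, 0.9])
mask_batch = jnp.array([1.0, 1.0, 0.0])
weighted_losses = per_example_losses * mask_batch         # [0.2, 0.4, 0.0]
mean_loss = jnp.sum(weighted_losses) / mask_batch.sum()   # 0.6 / 2 = 0.3
```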
@@ -219,9 +83,29 @@ def model_fn(
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
targets = augmented_and_preprocessed_input_batch['targets']
logits_batch = self._flax_module.apply({'params': params}, inputs, targets)
logits_batch = self._model.apply({'params': params}, inputs, targets)
return logits_batch, None

@property
def step_hint(self):
return 64_000 # TODO(znado): finalize.
@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0),
static_broadcasted_argnums=(0,))
def _eval_batch_pmapped(self, params, batch):
logits, _ = self.model_fn(
params,
batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
update_batch_norm=False)
per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy(
logits, batch['targets'])
return jnp.sum(per_example_losses)

def _eval_batch(self, params, batch):
# We do NOT psum inside _eval_batch_pmapped, so the returned array of shape
# (local_device_count,) holds a different partial sum on each device.
batch_loss_numerator = np.sum(self._eval_batch_pmapped(params, batch))
batch_loss_denominator = np.sum(batch['weights'])
return np.asarray(batch_loss_numerator), batch_loss_denominator
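The new `_eval_batch` returns a per-batch loss numerator and denominator rather than a mean; the aggregation into a mean validation loss presumably happens in the shared `BaseCriteo1TbDlrmSmallWorkload`. A hypothetical caller-side loop (the names `workload`, `params`, and `eval_batches` are placeholders, not the base class's actual code) would accumulate those two values:

```python
# Hypothetical aggregation over evaluation batches.
total_numerator = 0.0
total_denominator = 0.0
for eval_batch in eval_batches:
  numerator, denominator = workload._eval_batch(params, eval_batch)
  total_numerator += float(numerator)
  total_denominator += float(denominator)
mean_loss = total_numerator / total_denominator
```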