This repository was archived by the owner on Mar 3, 2024. It is now read-only.

CyberZHG/keras-radam


Keras RAdam



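Unofficial Keras implementation of RAdam (Rectified Adam), the optimizer introduced in "On the Variance of the Adaptive Learning Rate and Beyond" (Liu et al., 2019).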

Install

pip install keras-rectified-adam


Usage

from tensorflow import keras
import numpy as np
from keras_radam import RAdam

# Build toy model with RAdam optimizer
model = keras.models.Sequential()
model.add(keras.layers.Dense(input_shape=(17,), units=3))
model.compile(RAdam(), loss='mse')

# Generate toy data
x = np.random.standard_normal((4096 * 30, 17))
w = np.random.standard_normal((17, 3))
y = np.dot(x, w)

# Fit
model.fit(x, y, epochs=5)
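For intuition about what the optimizer does on each step, here is a minimal NumPy sketch of a single RAdam update. It is written from Algorithm 2 of the RAdam paper (Liu et al., "On the Variance of the Adaptive Learning Rate and Beyond"), not from this package's source, so details such as the rectification threshold and the `eps` term (added here only for numerical safety) may differ from `keras_radam.RAdam`. The helper `radam_step` is hypothetical. While the variance of the adaptive learning rate is intractable (rho_t <= 4), it falls back to a bias-corrected momentum step; after that it scales the usual Adam step by the rectification term r_t.

```python
import numpy as np


def radam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one RAdam update at step t (1-based) and return (theta, m, v).

    Hypothetical helper written from the paper, not taken from keras_radam.
    """
    # Maximum length of the approximated simple moving average (SMA).
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    m_hat = m / (1.0 - beta1 ** t)  # bias-corrected first moment
    # Length of the approximated SMA at step t.
    rho_t = rho_inf - 2.0 * t * beta2 ** t / (1.0 - beta2 ** t)
    if rho_t > 4.0:
        # Variance is tractable: take the adaptive step, scaled by r_t.
        v_hat = np.sqrt(v / (1.0 - beta2 ** t))
        r_t = np.sqrt(
            (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
            / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
        )
        theta = theta - lr * r_t * m_hat / (v_hat + eps)
    else:
        # Early steps: plain bias-corrected momentum update (no adaptive scaling).
        theta = theta - lr * m_hat
    return theta, m, v


# Usage: minimize f(theta) = sum(theta**2), whose gradient is 2 * theta.
theta = np.array([1.0, -2.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 2001):
    theta, m, v = radam_step(theta, 2.0 * theta, m, v, t, lr=1e-2)
print(theta)
```

With beta2 = 0.999, rho_t stays at or below 4 for the first few steps, so those updates use momentum only; r_t then starts near zero and grows toward 1, which acts like a built-in learning-rate warmup.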

Use Warmup

from keras_radam import RAdam

RAdam(total_steps=10000, warmup_proportion=0.1, min_lr=1e-5)
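The parameter names suggest a warmup-then-decay learning-rate schedule spanning `total_steps`. The sketch below assumes linear warmup from 0 to the optimizer's base learning rate over the first `total_steps * warmup_proportion` steps, followed by linear decay to `min_lr` at `total_steps`. This is our reading of the parameters, so check the package source before relying on exact values. The helpers `warmup_decay_lr` and `total_steps_for`, and the base rate of `1e-3`, are hypothetical. It also shows one way to choose `total_steps`, assuming it counts optimizer updates (one per batch).

```python
import math


def warmup_decay_lr(step, base_lr=1e-3, total_steps=10000,
                    warmup_proportion=0.1, min_lr=1e-5):
    """Learning rate at optimizer step `step` under an ASSUMED schedule:
    linear warmup from 0 to base_lr, then linear decay to min_lr.
    Hypothetical helper for illustration; not part of keras_radam.
    """
    warmup_steps = total_steps * warmup_proportion
    decay_steps = max(total_steps - warmup_steps, 1)
    if warmup_steps > 0 and step <= warmup_steps:
        # Warmup phase: ramp linearly from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Decay phase: move linearly toward min_lr, then hold it after total_steps.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return base_lr + (min_lr - base_lr) * progress


def total_steps_for(num_samples, batch_size, epochs):
    """Optimizer updates in a fit() run: epochs * ceil(num_samples / batch_size)."""
    return epochs * math.ceil(num_samples / batch_size)


# Toy Usage example: 4096 * 30 samples, Keras' default batch_size of 32, 5 epochs.
steps = total_steps_for(4096 * 30, batch_size=32, epochs=5)
print("total_steps:", steps)
for s in (0, 500, 1000, 5500, 10000, 20000):
    print(s, warmup_decay_lr(s))
```

Under these assumptions, the toy Usage example would pass `total_steps=19200`, since 122,880 samples / 32 per batch gives 3,840 updates per epoch over 5 epochs.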
