Skip to content

Commit 7ecb1fe

Browse files
More training tweaks
- Activation functions - Optimizers - Log loss function - Increased maximum training speed - Discriminator output range is now 0-1
1 parent dfbb373 commit 7ecb1fe

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

script.js

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@ const numParameters = 1;
77
const imageSize = 32;
88
// Number of images to use when training the neural network
99
const numTrainingImages = 15;
10-
const logData = true;
10+
const logData = false;
1111
const optimizer = {
12-
"generator": tf.train.adam(0.01),
13-
"discriminator": tf.train.adam(0.01)
12+
"generator": tf.train.adam(0.0001),
13+
"discriminator": tf.train.sgd(0.0001)
1414
}
1515

1616
// Automatically generated settings and parameters
1717
// Volume of image data, calculated by squaring imageSize to find the area of the image (total number of pixels) and multiplying by three for each color channel (RGB)
1818
const imageVolume = (imageSize ** 2) * 1;
19-
const numLayers = 8;
19+
const numLayers = 10;
2020
// Get information for canvas
2121
const canvas = document.getElementById("canvas");
2222
// Get context for canvas
@@ -46,7 +46,7 @@ const generator = {
4646
// Do we need these tidys?
4747
() => {
4848
// Evaluate the loss function given the output of the autoencoder network and the actual image
49-
return loss(
49+
return tf.losses.logLoss(
5050
discriminator.model.predict(
5151
generator.model.predict(parameters.training).clipByValue(0, 255)
5252
),
@@ -95,7 +95,7 @@ if (logData) {
9595
discriminator.model.add(tf.layers.dense({units: imageVolume, inputShape: [imageVolume]}));
9696
for (var i = 0; i < numLayers; i ++) {
9797
const layerSize = Math.round(imageVolume / (2 ** (i + 1)));
98-
discriminator.model.add(tf.layers.dense({units: layerSize, activation: "tanh"}));
98+
discriminator.model.add(tf.layers.dense({units: layerSize, activation: "sigmoid"}));
9999
if (logData) {
100100
console.log(layerSize);
101101
}
@@ -183,7 +183,7 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
183183
);
184184
trainingData.pixels.input.push(generatedArray);
185185

186-
outputValues = new Array(numParameters).fill(-1);
186+
outputValues = new Array(numParameters).fill(0);
187187
trainingData.pixels.output.push(outputValues);
188188
}
189189

@@ -257,7 +257,6 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
257257
// Decode the low-dimensional representation of the input data created by the encoder
258258
return generator.model.predict(parameters.display)
259259
// Clip pixel values to a 0 - 255 (int32) range
260-
.clipByValue(0, 255)
261260
// Reshape the output tensor into an image format (W * L * 3)
262261
.reshape(
263262
[imageSize, imageSize, 1]
@@ -284,13 +283,13 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
284283
}
285284

286285
// Display the output tensor on the output canvas, then dispose the tensor
287-
tf.toPixels(output, canvas).then(() => output.dispose());
286+
tf.toPixels(output.clipByValue(0, 255), canvas).then(() => output.dispose());
288287
discriminatorOutput.dispose();
289288

290289
iteration ++;
291290
}
292291
// Set an interval of 100 milliseconds to repeat the train() function
293-
var interval = window.setInterval(train, 10);
292+
var interval = window.setInterval(train, 1);
294293
}
295294
// Load source paths for training data images (this must be done after the image elements are created and the onload function is defined)
296295
// Loop through each image element

0 commit comments

Comments
 (0)