Skip to content

Commit 3da10ec

Browse files
Training tweaks
1 parent 38db343 commit 3da10ec

File tree

1 file changed: 15 additions (+15), 11 deletions (-11)

script.js

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
// Define settings
44
const numParameters = 3;
5-
const numLayers = 6;
5+
const numLayers = 10;
66
// Size of input and output images in pixels (width and height)
7-
const imageSize = 8;
7+
const imageSize = 32;
88
// Number of images to use when training the neural network
99
const numTrainingImages = 15;
10-
const logData = false;
10+
const logData = true;
11+
const optimizer = {
12+
"generator": tf.train.adam(0.01),
13+
"discriminator": tf.train.adam(0.01)
14+
}
1115

1216
// Automatically generated settings and parameters
1317
// Volume of image data, calculated by squaring imageSize to find the area of the image (total number of pixels) and multiplying by three for each color channel (RGB)
@@ -43,13 +47,13 @@ const generator = {
4347
// Evaluate the loss function given the output of the autoencoder network and the actual image
4448
return loss(
4549
discriminator.model.predict(
46-
generator.model.predict(parameters.training)
50+
generator.model.predict(parameters.training).clipByValue(0, 255)
4751
),
4852
tf.ones([15, numParameters])
4953
);
5054
}
5155
),
52-
"optimizer": tf.train.adam(0.01)
56+
"optimizer": optimizer.generator
5357
};
5458

5559
if (logData) {
@@ -59,7 +63,7 @@ if (logData) {
5963
generator.model.add(tf.layers.dense({units: numParameters, inputShape: [numParameters]}));
6064
for (var i = 0; i < numLayers; i ++) {
6165
const layerSize = Math.round(imageVolume / (2 ** ((numLayers - 1) - i)));
62-
generator.model.add(tf.layers.dense({units: layerSize, activation: "tanh"}));
66+
generator.model.add(tf.layers.dense({units: layerSize, activation: "relu"}));
6367
if (logData) {
6468
console.log(layerSize);
6569
}
@@ -80,7 +84,7 @@ const discriminator = {
8084
);
8185
}
8286
),
83-
"optimizer": tf.train.adam(0.01)
87+
"optimizer": optimizer.discriminator
8488
};
8589

8690
if (logData) {
@@ -170,7 +174,7 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
170174
// Uncaught Error: Constructing tensor of shape (92160) should match the length of values (46095)
171175
const generated =
172176
tf.tidy(
173-
() => generator.model.predict(parameters.display).dataSync()
177+
() => generator.model.predict(parameters.display).clipByValue(0, 255).dataSync()
174178
);
175179
const generatedArray = [];
176180
generated.forEach(
@@ -252,20 +256,20 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
252256
// Decode the low-dimensional representation of the input data created by the encoder
253257
return generator.model.predict(parameters.display)
254258
// Clip pixel values to a 0 - 255 (int32) range
255-
.clipByValue(0, 1)
259+
.clipByValue(0, 255)
256260
// Reshape the output tensor into an image format (W * L * 3)
257261
.reshape(
258262
[imageSize, imageSize, 3]
259263
)
260264
}
261265
);
262-
// output.dtype = "int32";
266+
output.dtype = "int32";
263267

264268
const discriminatorOutput =
265269
tf.tidy(
266270
() => {
267271
return discriminator.model.predict(
268-
generator.model.predict(parameters.display)
272+
generator.model.predict(parameters.display).clipByValue(0, 255)
269273
);
270274
}
271275
);

0 commit comments

Comments (0)