Skip to content

Commit c852eee

Browse files
Added generator and discriminator network balancing (disabled for now)
And other changes
1 parent a5d47a7 commit c852eee

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

script.js

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ const numParameters = 4;
77
const imageSize = 8;
88
// Number of images to use when training the neural network
99
const numTrainingImages = 1;
10-
const logData = false;
10+
const logData = true;
1111
const optimizer = {
1212
"generator": tf.train.adam(0.001),
1313
"discriminator": tf.train.sgd(0.001)
@@ -223,8 +223,8 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
223223
generateTrainingData();
224224
}
225225

226-
generatorLoss = generator.calculateLoss();
227-
discriminatorLoss = discriminator.calculateLoss();
226+
const generatorLoss = generator.calculateLoss();
227+
const discriminatorLoss = discriminator.calculateLoss();
228228

229229
if (logData) {
230230
console.log("Iteration " + iteration);
@@ -245,16 +245,11 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
245245
}
246246
document.querySelector("#iteration").innerHTML = "Iteration • " + iteration;
247247
document.querySelector("#generator-loss").innerHTML = "Generator • " +
248-
generatorLoss.
249-
dataSync()[0].
250-
toFixed(2);
248+
generatorLoss
249+
.dataSync()[0];
251250
document.querySelector("#discriminator-loss").innerHTML = "Discriminator • " +
252-
discriminatorLoss.
253-
dataSync()[0].
254-
toFixed(2);
255-
256-
generatorLoss.dispose();
257-
discriminatorLoss.dispose();
251+
discriminatorLoss
252+
.dataSync()[0];
258253

259254
const trainableVars = [];
260255
for (var i = 0; i < generator.model.weights.length; i ++) {
@@ -263,12 +258,24 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
263258
for (var i = 0; i < generator.model.model.weights.length; i ++) {
264259
trainableVars.push(generator.model.model.weights[i].val);
265260
}
261+
// if (generatorLoss > discriminatorLoss) {
266262
generator.optimizer.minimize(
267263
generator.calculateLoss,
268264
false,
269265
trainableVars
270266
);
271267
// }
268+
// else {
269+
discriminator.optimizer.minimize(discriminator.calculateLoss);
270+
// }
271+
272+
// if (discriminatorLoss < 0.05) {
273+
// generator.optimizer.minimize(generator.calculateLoss);
274+
// }
275+
// discriminator.optimizer.minimize(discriminator.calculateLoss);
276+
277+
generatorLoss.dispose();
278+
discriminatorLoss.dispose();
272279

273280
// All this is just display code
274281
// Calculate autoencoder output from original image
@@ -280,6 +287,7 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
280287
return generator.model.predict(parameters.display)
281288
// Clip pixel values to a 0 - 255 (int32) range
282289
// Reshape the output tensor into an image format (W * L * 3)
290+
.clipByValue(0, 255)
283291
.reshape(
284292
[imageSize, imageSize, 1]
285293
)
@@ -291,21 +299,21 @@ trainingData.images[trainingData.images.length - 1].onload = function () {
291299
tf.tidy(
292300
() => {
293301
return discriminator.model.predict(
294-
generator.model.predict(parameters.display).clipByValue(0, 255)
302+
generator.model.predict(parameters.display)
295303
);
296304
}
297305
);
298306

299307
if (logData) {
300308
console.log("Generator output");
301-
output.print();
309+
// output.print();
302310

303311
console.log("Discriminator output");
304312
discriminatorOutput.print();
305313
}
306314

307315
// Display the output tensor on the output canvas, then dispose the tensor
308-
tf.toPixels(output.clipByValue(0, 255), canvas.generated).then(() => output.dispose());
316+
tf.toPixels(output, canvas.generated).then(() => output.dispose());
309317
discriminatorOutput.dispose();
310318

311319
iteration ++;

0 commit comments

Comments
 (0)