Skip to content

Commit 38d22a5

Browse files
committed
add lenet training playground
1 parent 7043f56 commit 38d22a5

File tree

20 files changed

+1036
-8
lines changed

20 files changed

+1036
-8
lines changed

assets/data/training/mnist_data.js

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/**
2+
* The data and data script is modified from TensorFlow.js's official repo.
3+
* Checkout TensorFlow.js's official tutorial for more information: https://js.tensorflow.org/tutorials/mnist.html
4+
*/
5+
6+
const IMAGE_SIZE = 784;
7+
const NUM_CLASSES = 10;
8+
const NUM_DATASET_ELEMENTS = 65000;
9+
10+
const TRAIN_TEST_RATIO = 5 / 6;
11+
12+
const NUM_TRAIN_ELEMENTS = Math.floor( TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS );
13+
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
14+
15+
const MNIST_IMAGES_SPRITE_PATH =
16+
'../../assets/data/training/mnist_images.png';
17+
const MNIST_LABELS_PATH =
18+
'../../assets/data/training/mnist_labels_uint8';
19+
20+
class MnistData {
21+
22+
constructor() {
23+
24+
this.shuffledTrainIndex = 0;
25+
this.shuffledTestIndex = 0;
26+
27+
}
28+
29+
async load() {
30+
31+
const img = new Image();
32+
const canvas = document.createElement( 'canvas' );
33+
const ctx = canvas.getContext( '2d' );
34+
const imgRequest = new Promise( ( resolve, reject ) => {
35+
36+
img.crossOrigin = '';
37+
img.onload = () => {
38+
39+
img.width = img.naturalWidth;
40+
img.height = img.naturalHeight;
41+
42+
const datasetBytesBuffer =
43+
new ArrayBuffer( NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4 );
44+
45+
const chunkSize = 5000;
46+
canvas.width = img.width;
47+
canvas.height = chunkSize;
48+
49+
for ( let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i ++ ) {
50+
51+
const datasetBytesView = new Float32Array(
52+
53+
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
54+
IMAGE_SIZE * chunkSize
55+
56+
);
57+
58+
ctx.drawImage(
59+
60+
img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
61+
chunkSize
62+
);
63+
64+
const imageData = ctx.getImageData( 0, 0, canvas.width, canvas.height );
65+
66+
for ( let j = 0; j < imageData.data.length / 4; j ++ ) {
67+
68+
datasetBytesView[ j ] = imageData.data[ j * 4 ] / 255;
69+
70+
}
71+
72+
}
73+
74+
this.datasetImages = new Float32Array( datasetBytesBuffer );
75+
76+
resolve();
77+
78+
};
79+
80+
img.src = MNIST_IMAGES_SPRITE_PATH;
81+
82+
} );
83+
84+
const labelsRequest = fetch( MNIST_LABELS_PATH );
85+
const [ imgResponse, labelsResponse ] =
86+
await Promise.all( [ imgRequest, labelsRequest ] );
87+
88+
this.datasetLabels = new Uint8Array( await labelsResponse.arrayBuffer() );
89+
90+
this.trainIndices = tf.util.createShuffledIndices( NUM_TRAIN_ELEMENTS );
91+
this.testIndices = tf.util.createShuffledIndices( NUM_TEST_ELEMENTS );
92+
93+
this.trainImages =
94+
this.datasetImages.slice( 0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS );
95+
this.testImages = this.datasetImages.slice( IMAGE_SIZE * NUM_TRAIN_ELEMENTS );
96+
this.trainLabels =
97+
this.datasetLabels.slice( 0, NUM_CLASSES * NUM_TRAIN_ELEMENTS );
98+
this.testLabels =
99+
this.datasetLabels.slice( NUM_CLASSES * NUM_TRAIN_ELEMENTS );
100+
101+
}
102+
103+
nextTrainBatch( batchSize ) {
104+
105+
return this.nextBatch(
106+
107+
batchSize, [ this.trainImages, this.trainLabels ], () => {
108+
109+
this.shuffledTrainIndex =
110+
( this.shuffledTrainIndex + 1 ) % this.trainIndices.length;
111+
return this.trainIndices[ this.shuffledTrainIndex ];
112+
113+
}
114+
115+
);
116+
117+
}
118+
119+
nextTestBatch( batchSize ) {
120+
121+
return this.nextBatch( batchSize, [ this.testImages, this.testLabels ], () => {
122+
123+
this.shuffledTestIndex =
124+
( this.shuffledTestIndex + 1 ) % this.testIndices.length;
125+
126+
return this.testIndices[ this.shuffledTestIndex ];
127+
128+
} );
129+
130+
}
131+
132+
nextBatch( batchSize, data, index ) {
133+
134+
const batchImagesArray = new Float32Array( batchSize * IMAGE_SIZE );
135+
const batchLabelsArray = new Uint8Array( batchSize * NUM_CLASSES );
136+
137+
for ( let i = 0; i < batchSize; i ++ ) {
138+
139+
const idx = index();
140+
141+
const image =
142+
data[ 0 ].slice( idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE );
143+
batchImagesArray.set( image, i * IMAGE_SIZE );
144+
145+
const label =
146+
data[ 1 ].slice( idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES );
147+
batchLabelsArray.set( label, i * NUM_CLASSES );
148+
149+
}
150+
151+
const xs = tf.tensor2d( batchImagesArray, [ batchSize, IMAGE_SIZE ] );
152+
const labels = tf.tensor2d( batchLabelsArray, [ batchSize, NUM_CLASSES ] );
153+
154+
return { xs, labels };
155+
156+
}
157+
158+
}
10.2 MB
Loading
635 KB
Binary file not shown.

assets/img/playground/0.png

272 Bytes
Loading

assets/img/playground/1.png

168 Bytes
Loading

assets/img/playground/2.png

289 Bytes
Loading

assets/img/playground/3.png

777 Bytes
Loading

assets/img/playground/4.png

272 Bytes
Loading

assets/img/playground/5.png

800 Bytes
Loading

assets/img/playground/6.png

736 Bytes
Loading

0 commit comments

Comments
 (0)