/**
 * The data and the data-loading script are adapted from TensorFlow.js's official repo.
 * Check out TensorFlow.js's official tutorial for more information: https://js.tensorflow.org/tutorials/mnist.html
 */
5+
// Number of float values stored per example (presumably a flattened 28x28 image — confirm against the sprite).
const IMAGE_SIZE = 784;

// Number of label bytes stored per example (presumably a one-hot vector over the ten digits — confirm).
const NUM_CLASSES = 10;

// Total number of examples contained in the sprite image and the label file.
const NUM_DATASET_ELEMENTS = 65000;

// Fraction of the dataset assigned to the training partition; the rest is used for testing.
const TRAIN_TEST_RATIO = 5 / 6;

// Partition sizes: the training count is rounded down, the test partition takes the remainder.
const NUM_TRAIN_ELEMENTS = Math.floor(NUM_DATASET_ELEMENTS * TRAIN_TEST_RATIO);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// Locations of the image sprite and the raw label bytes, relative to the page.
const MNIST_IMAGES_SPRITE_PATH = '../../assets/data/training/mnist_images.png';
const MNIST_LABELS_PATH = '../../assets/data/training/mnist_labels_uint8';
19+
/**
 * Loads the MNIST image sprite and label file, splits them into training and
 * test partitions, and hands out mini-batches as TensorFlow.js tensors.
 *
 * Depends on the global `tf` object, the browser `Image`/canvas APIs, `fetch`,
 * and the module-level dataset constants (IMAGE_SIZE, NUM_CLASSES, ...).
 */
class MnistData {
  constructor() {
    // Cursors into trainIndices / testIndices. Each cursor is advanced
    // (with wrap-around) BEFORE it is read, so position 0 is visited last.
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads and decodes the whole dataset, then builds the train/test
   * partitions and their shuffled index arrays. Must finish before any
   * batch is requested.
   *
   * @returns {Promise<void>} Resolves once all dataset fields are populated.
   * @throws {Error} If the sprite image cannot be loaded or decoded, or the
   *     labels request fails or returns a non-2xx status.
   */
  async load() {
    // Both downloads are independent, so run them concurrently.
    const [images, labels] = await Promise.all([
      this.loadImages(),
      this.loadLabels(),
    ]);

    this.datasetImages = images;
    this.datasetLabels = labels;

    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // The first NUM_TRAIN_ELEMENTS examples form the training partition,
    // everything after them is the test partition.
    const imageSplit = IMAGE_SIZE * NUM_TRAIN_ELEMENTS;
    const labelSplit = NUM_CLASSES * NUM_TRAIN_ELEMENTS;
    this.trainImages = images.slice(0, imageSplit);
    this.testImages = images.slice(imageSplit);
    this.trainLabels = labels.slice(0, labelSplit);
    this.testLabels = labels.slice(labelSplit);
  }

  /**
   * Internal helper: loads the sprite image and converts it to floats.
   *
   * @returns {Promise<Float32Array>} Resolves with
   *     NUM_DATASET_ELEMENTS * IMAGE_SIZE values in [0, 1]; rejects if the
   *     image fails to load or decoding throws.
   */
  loadImages() {
    return new Promise((resolve, reject) => {
      const img = new Image();
      img.crossOrigin = '';

      // Without an error handler a failed download would leave the promise
      // (and therefore load()) pending forever.
      img.onerror = () => {
        reject(new Error(
          `Failed to load MNIST sprite image: ${MNIST_IMAGES_SPRITE_PATH}`
        ));
      };

      img.onload = () => {
        // An exception thrown inside an event handler does not reach the
        // promise on its own, so forward it explicitly.
        try {
          resolve(this.decodeSprite(img));
        } catch (err) {
          reject(err);
        }
      };

      img.src = MNIST_IMAGES_SPRITE_PATH;
    });
  }

  /**
   * Internal helper: draws the loaded sprite onto an off-screen canvas in
   * horizontal strips and copies the first (red) channel of every pixel,
   * scaled to [0, 1], into one contiguous Float32Array.
   *
   * NOTE(review): assumes the sprite is IMAGE_SIZE pixels wide with one
   * example per pixel row — confirm against the asset.
   *
   * @param {HTMLImageElement} img A fully loaded sprite image.
   * @returns {Float32Array} NUM_DATASET_ELEMENTS * IMAGE_SIZE pixel values.
   */
  decodeSprite(img) {
    img.width = img.naturalWidth;
    img.height = img.naturalHeight;

    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const images = new Float32Array(NUM_DATASET_ELEMENTS * IMAGE_SIZE);

    // Rows decoded per pass; the sprite is processed strip by strip
    // (presumably to keep the canvas and ImageData allocations small).
    const chunkSize = 5000;
    canvas.width = img.width;
    canvas.height = chunkSize;

    for (let start = 0; start < NUM_DATASET_ELEMENTS; start += chunkSize) {
      // Clamp the final strip so dataset sizes that are not a multiple of
      // chunkSize do not run past the end of the buffer.
      const rows = Math.min(chunkSize, NUM_DATASET_ELEMENTS - start);
      const strip = images.subarray(
        start * IMAGE_SIZE,
        (start + rows) * IMAGE_SIZE
      );

      ctx.drawImage(img, 0, start, img.width, rows, 0, 0, img.width, rows);
      const { data } = ctx.getImageData(0, 0, canvas.width, rows);

      // ImageData is RGBA: keep every fourth byte (the red channel).
      const pixelCount = data.length / 4;
      for (let p = 0; p < pixelCount; p++) {
        strip[p] = data[p * 4] / 255;
      }
    }

    return images;
  }

  /**
   * Internal helper: fetches the raw label bytes.
   *
   * @returns {Promise<Uint8Array>} The label file contents.
   * @throws {Error} If the server answers with a non-2xx status.
   */
  async loadLabels() {
    const response = await fetch(MNIST_LABELS_PATH);
    if (!response.ok) {
      // fetch() only rejects on network failure; an HTTP error page would
      // otherwise be silently interpreted as label data.
      throw new Error(
        `Failed to fetch MNIST labels (HTTP ${response.status}): ${MNIST_LABELS_PATH}`
      );
    }
    return new Uint8Array(await response.arrayBuffer());
  }

  /**
   * Returns the next batch drawn from the training partition.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: Object, labels: Object}} Tensors of shape
   *     [batchSize, IMAGE_SIZE] and [batchSize, NUM_CLASSES].
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
      batchSize,
      [this.trainImages, this.trainLabels],
      () => {
        this.shuffledTrainIndex =
          (this.shuffledTrainIndex + 1) % this.trainIndices.length;
        return this.trainIndices[this.shuffledTrainIndex];
      }
    );
  }

  /**
   * Returns the next batch drawn from the test partition.
   *
   * @param {number} batchSize Number of examples in the batch.
   * @returns {{xs: Object, labels: Object}} Tensors of shape
   *     [batchSize, IMAGE_SIZE] and [batchSize, NUM_CLASSES].
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(
      batchSize,
      [this.testImages, this.testLabels],
      () => {
        this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
        return this.testIndices[this.shuffledTestIndex];
      }
    );
  }

  /**
   * Assembles a batch by repeatedly asking `index` for an example number and
   * copying that example's pixels and label bytes into contiguous buffers.
   *
   * @param {number} batchSize Number of examples to gather.
   * @param {Array} data Pair `[images, labels]` of flat typed arrays holding
   *     IMAGE_SIZE floats and NUM_CLASSES bytes per example respectively.
   * @param {function(): number} index Returns the example number to use next.
   * @returns {{xs: Object, labels: Object}} `tf.tensor2d` results of shape
   *     [batchSize, IMAGE_SIZE] and [batchSize, NUM_CLASSES].
   */
  nextBatch(batchSize, data, index) {
    const [images, labels] = data;
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() creates a view rather than a copy; set() then performs
      // the single copy that is actually needed.
      batchImagesArray.set(
        images.subarray(idx * IMAGE_SIZE, (idx + 1) * IMAGE_SIZE),
        i * IMAGE_SIZE
      );
      batchLabelsArray.set(
        labels.subarray(idx * NUM_CLASSES, (idx + 1) * NUM_CLASSES),
        i * NUM_CLASSES
      );
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labelsTensor = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return { xs, labels: labelsTensor };
  }
}
0 commit comments