ML Classifier is a machine learning engine for quickly training image classification models in your browser. Models can be saved with a single command, and the resulting models reused to make image classification predictions.
This package is intended as a companion for ml-classifier-ui, which provides a web frontend in React for uploading data and seeing results.
A walkthrough of the code can be found in the article Image Classification in the Browser with Javascript.
An interactive demo can be found here.
ml-classifier can be installed via yarn or npm:
yarn add ml-classifier
or
npm install ml-classifier
Start by instantiating a new MLClassifier.
import MLClassifier from 'ml-classifier'; const mlClassifier = new MLClassifier();
Then, add your images and labels and train the model:
await mlClassifier.addData(images, labels);
await mlClassifier.train({
  callbacks: {
    onTrainBegin: () => {
      console.log('training begins');
    },
    onBatchEnd: (batch: any, logs: any) => {
      console.log('Loss is: ' + logs.loss.toFixed(5));
    },
  },
});
And get predictions:
const prediction = await mlClassifier.predict(data);
When you have a trained model you're happy with, save it with:
mlClassifier.save();
When you call save, TensorFlow.js will download a weights file and a model topology file.
You'll need to combine both into a single JSON file: open your model topology file and, at the top level of the JSON, add a weightsManifest key pointing to your weights file, like:
{ "weightsManifest": "ml-classifier-class1-class2.weights.bin", "modelTopology": { ... } }
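If you'd rather not edit the file by hand, the merge can be scripted. The sketch below uses a hypothetical helper, addWeightsManifest (not part of ml-classifier), that takes the topology file's contents as a string, puts a top-level weightsManifest key first, and keeps every other key unchanged. It assumes the topology file holds a single JSON object and leaves reading and writing the files to you.

```typescript
// Hypothetical helper, not part of ml-classifier: adds a top-level
// weightsManifest key to the contents of a model topology JSON file.
function addWeightsManifest(topologyJson: string, weightsFileName: string): string {
  const parsed: unknown = JSON.parse(topologyJson);
  if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
    throw new Error('Expected the topology file to contain a JSON object');
  }
  // Copy the existing keys and drop any previous weightsManifest entry,
  // so the new key can be placed first, as in the layout shown above.
  const rest: Record<string, unknown> = { ...(parsed as Record<string, unknown>) };
  delete rest.weightsManifest;
  return JSON.stringify({ weightsManifest: weightsFileName, ...rest }, null, 2);
}

// Example: merge a (truncated) topology with the weights file name.
const merged = addWeightsManifest(
  '{"modelTopology": {"class_name": "Sequential"}}',
  'ml-classifier-class1-class2.weights.bin',
);
console.log(merged);
```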
When using the model in your app, there are a few things to keep in mind:
- You need to transform images into the correct dimensions for the pretrained model the classifier was trained with (for MOBILENET, this is 1x224x224x3).
- You must create a pretrained model matching the dimensions used to train. An example for MOBILENET is below.
- You must first run your images through the pretrained model to get their activations, and then feed those activations to your trained model.
- After getting the final prediction, you must take the argmax.
- You'll get back a number indicating your class.
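The last two steps can also be done on a plain array of scores, such as the values from modelPrediction.as1D().dataSync() in the full example that follows. argMax and indexToLabel are hypothetical helpers, not library functions, and mapping an index back to a label name assumes you kept the class names in the same order as the model's outputs, which this README does not specify.

```typescript
// Hypothetical helper: index of the highest score (the first one wins on ties),
// equivalent to the argMax step in the full example.
function argMax(scores: ArrayLike<number>): number {
  if (scores.length === 0) {
    throw new Error('Cannot take the argmax of an empty array');
  }
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) {
      best = i;
    }
  }
  return best;
}

// Hypothetical helper: assumes `labels` lists the class names in the same
// order as the model's output units.
function indexToLabel(index: number, labels: readonly string[]): string {
  const label = labels[index];
  if (label === undefined) {
    throw new Error(`No label for class index ${index}`);
  }
  return label;
}

const classIndex = argMax(new Float32Array([0.1, 0.7, 0.2])); // 1
console.log(indexToLabel(classIndex, ['class1', 'class2', 'class3'])); // "class2"
```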
Full example for MOBILENET:
import * as tf from '@tensorflow/tfjs';

// Note: tf.loadModel and tf.fromPixels are the pre-1.0 TensorFlow.js names;
// newer versions call them tf.loadLayersModel and tf.browser.fromPixels.
const loadImage = (src) => new Promise((resolve, reject) => {
  const image = new Image();
  // Set crossOrigin and the handlers before src so the request is made with CORS.
  image.crossOrigin = 'Anonymous';
  image.onload = () => resolve(image);
  image.onerror = (err) => reject(err);
  image.src = src;
});

const pretrainedModelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

tf.loadModel(pretrainedModelURL).then(model => {
  // Truncate MobileNet at an intermediate layer; its activations are the input to the trained classifier.
  const layer = model.getLayer('conv_pw_13_relu');
  return tf.model({
    inputs: [model.inputs[0]],
    outputs: layer.output,
  });
}).then(pretrainedModel => {
  return tf.loadModel('/model.json').then(model => {
    return loadImage('/trees/tree1.png').then(loadedImage => {
      // The image must already be 224x224 pixels for this reshape to succeed.
      const image = tf.reshape(tf.fromPixels(loadedImage), [1, 224, 224, 3]);
      const pretrainedModelPrediction = pretrainedModel.predict(image);
      const modelPrediction = model.predict(pretrainedModelPrediction);
      const prediction = modelPrediction.as1D().argMax().dataSync()[0];
      console.log(prediction);
    });
  });
}).catch(err => {
  console.error('Error', err);
});
Start by instantiating a new instance of MLClassifier with:
const mlClassifier = new MLClassifier();
This will begin loading the pretrained model and provide you with an object onto which to add data and train.
MLClassifier accepts a number of callbacks for the beginning and end of its various methods.
You can provide a custom pretrained model as pretrainedModel.
You can provide a custom training model as trainingModel.
- pretrainedModel (string | tf.Model) Optional - A string denoting which pretrained model to load from an internal config. Valid strings can be found on the exported object PRETRAINED_MODELS. You can also pass a preloaded pretrained model directly.
- trainingModel (tf.Model | Function) Optional - A custom model to use during training. Can be provided as a tf.Model, or as a function that accepts {xs: [...], ys: [...]}, the number of classes, and the params provided to train.
- onLoadStart (Function) Optional - A callback for when load (loading the pretrained model) is first called.
- onLoadComplete (Function) Optional - A callback for when load (loading the pretrained model) is complete.
- onAddDataStart (Function) Optional - A callback for when addData is first called.
- onAddDataComplete (Function) Optional - A callback for when addData is complete.
- onClearDataStart (Function) Optional - A callback for when clearData is first called.
- onClearDataComplete (Function) Optional - A callback for when clearData is complete.
- onTrainStart (Function) Optional - A callback for when train is first called.
- onTrainComplete (Function) Optional - A callback for when train is complete.
- onEvaluateStart (Function) Optional - A callback for when evaluate is first called.
- onEvaluateComplete (Function) Optional - A callback for when evaluate is complete.
- onPredictStart (Function) Optional - A callback for when predict is first called.
- onPredictComplete (Function) Optional - A callback for when predict is complete.
- onSaveStart (Function) Optional - A callback for when save is first called.
- onSaveComplete (Function) Optional - A callback for when save is complete.
import MLClassifier, { PRETRAINED_MODELS } from 'ml-classifier';

const mlClassifier = new MLClassifier({
  pretrainedModel: PRETRAINED_MODELS.MOBILENET,
  onLoadStart: () => console.log('onLoadStart'),
  onLoadComplete: () => console.log('onLoadComplete'),
  onAddDataStart: () => console.log('onAddDataStart'),
  onAddDataComplete: () => console.log('onAddDataComplete'),
  onClearDataStart: () => console.log('onClearDataStart'),
  onClearDataComplete: () => console.log('onClearDataComplete'),
  onTrainStart: () => console.log('onTrainStart'),
  onTrainComplete: () => console.log('onTrainComplete'),
  onEvaluateStart: () => console.log('onEvaluateStart'),
  onEvaluateComplete: () => console.log('onEvaluateComplete'),
  onPredictStart: () => console.log('onPredictStart'),
  onPredictComplete: () => console.log('onPredictComplete'),
  onSaveStart: () => console.log('onSaveStart'),
  onSaveComplete: () => console.log('onSaveComplete'),
});
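Writing out all fourteen callbacks by hand is repetitive. As a sketch, the hypothetical makeLoggingCallbacks helper below (not part of ml-classifier) generates the same onXStart/onXComplete names from the method names listed above and passes each name to a logger. It assumes you don't need whatever arguments, if any, the library passes to these callbacks.

```typescript
// Method names taken from the callback list above.
const LIFECYCLE_METHODS = [
  'Load', 'AddData', 'ClearData', 'Train', 'Evaluate', 'Predict', 'Save',
] as const;

type Phase = 'Start' | 'Complete';
type CallbackName = `on${(typeof LIFECYCLE_METHODS)[number]}${Phase}`;

// Hypothetical helper: builds one callback per method and phase; each callback
// ignores its arguments and reports its own name to `log`.
function makeLoggingCallbacks(
  log: (name: CallbackName) => void = (name) => console.log(name),
): Record<CallbackName, () => void> {
  const callbacks = {} as Record<CallbackName, () => void>;
  for (const method of LIFECYCLE_METHODS) {
    for (const phase of ['Start', 'Complete'] as const) {
      const name = `on${method}${phase}` as CallbackName;
      callbacks[name] = () => log(name);
    }
  }
  return callbacks;
}

const logged = makeLoggingCallbacks();
logged.onLoadStart(); // logs "onLoadStart"
```

You could then spread the result into the options, for example new MLClassifier({ pretrainedModel: PRETRAINED_MODELS.MOBILENET, ...makeLoggingCallbacks() }).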
Example of specifying a preloaded pretrained model:
import * as tf from '@tensorflow/tfjs';
import MLClassifier from 'ml-classifier';

const pretrainedModel = await tf.loadModel('... some pretrained model ...');
const mlClassifier = new MLClassifier({ pretrainedModel });
addData takes an array of images, an optional array of labels, and an optional dataType.
import MLClassifier from 'ml-classifier'; const mlClassifier = new MLClassifier(); mlClassifier.addData(images, labels, 'train');
- images (Array<tf.Tensor3D | ImageData | HTMLImageElement | string>) - An array of 3D tensors, ImageData objects (for example, read from a canvas), native browser Image elements, or strings representing an image src. Images can be any size, but will be cropped and sized down to match the pretrained model.
- labels (string[]) - An array of strings, one per image passed above, in the same order.
- dataType (string) Optional - An enum specifying which data type the images belong to. Data types can be train, for data used in mlClassifier.train(), and eval, for data used in mlClassifier.evaluate(). If no argument is supplied, dataType defaults to train.
addData returns nothing.
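A mismatch between images and labels is easy to introduce when you build the arrays yourself. The sketch below is a hypothetical pre-check (checkAddDataArgs is not part of ml-classifier) that you could run before calling addData. It encodes the rules above: labels, when given, match the images one to one by position, and dataType must be train or eval, defaulting to train.

```typescript
type DataTypeName = 'train' | 'eval';

function isDataTypeName(value: string): value is DataTypeName {
  return value === 'train' || value === 'eval';
}

// Hypothetical pre-check: throws on mismatched lengths or an unknown dataType,
// and otherwise returns the data type that addData would use.
function checkAddDataArgs<T>(
  images: readonly T[],
  labels?: readonly string[],
  dataType: string = 'train',
): DataTypeName {
  if (labels !== undefined && labels.length !== images.length) {
    throw new Error(`Got ${images.length} images but ${labels.length} labels`);
  }
  if (!isDataTypeName(dataType)) {
    throw new Error(`Unknown dataType "${dataType}"; expected "train" or "eval"`);
  }
  return dataType;
}

const inferredType = checkAddDataArgs(['tree1.png', 'tree2.png'], ['tree', 'tree']);
console.log(inferredType); // "train"
```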
train begins training on the given dataset.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.train({
  callbacks: {
    onTrainBegin: () => {
      console.log('training begins');
    },
  },
});
- params (Object) Optional - a set of parameters that will be passed directly to model.fit. View the TensorFlow.js docs for an up-to-date list of arguments.
train returns a promise that resolves to the result of model.fit: a History object containing loss and accuracy values.
evaluate is used to evaluate a model's performance.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.train();
await mlClassifier.addData(evaluationImages, evaluationLabels, DataType.EVALUATE);
const result = await mlClassifier.evaluate();
- params (Object) Optional - a set of parameters that will be passed directly to model.evaluate. View the TensorFlow.js docs for an up-to-date list of arguments.
evaluate returns a tf.Scalar representing the result of model.evaluate.
predict is used to make a prediction for a single image using the trained model.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.train();
const prediction = await mlClassifier.predict(imageToPredict);
- image (tf.Tensor3D) - a single image encoded as a tf.Tensor3D. The image can be any size, but will be cropped and sized down to match the pretrained model.
predict returns a string: the label of the predicted class.
save is a proxy to the TensorFlow.js model.save method, and will initiate a download from the browser or save to local storage, depending on the handler or URL you pass.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.train();
await mlClassifier.save('path-to-save');
- handlerOrUrl (io.IOHandler | string) Optional - an argument to be passed to model.save. If omitted, the model's unique labels will be concatenated together in the form class1-class2-class3.
- params (Object) Optional - a set of parameters that will be passed directly to model.save. View the TensorFlow.js docs for an up-to-date list of arguments.
getModel returns the trained TensorFlow.js model. Calling this method before calling mlClassifier.train will return null.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.train();
const model = mlClassifier.getModel();
getModel takes no arguments and returns the trained TensorFlow.js model.
clearData clears out saved data.
import MLClassifier, { DataType } from 'ml-classifier';

const mlClassifier = new MLClassifier();
await mlClassifier.addData(images, labels, DataType.TRAIN);
await mlClassifier.clearData(DataType.TRAIN);
- dataType (DataType) Optional - specifies which data to clear. If no argument is provided, all data will be cleared.
clearData returns nothing.
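The clearing rules above can be modeled with a small in-memory store. LabeledDataStore is a hypothetical class, not ml-classifier's internal implementation, and it uses the strings train and eval for the two data types. It only illustrates that passing a type clears that set, while passing nothing clears everything.

```typescript
type DataKind = 'train' | 'eval';

// Hypothetical store mirroring the addData/clearData semantics described above.
class LabeledDataStore<T> {
  private readonly data: Record<DataKind, { images: T[]; labels: string[] }> = {
    train: { images: [], labels: [] },
    eval: { images: [], labels: [] },
  };

  add(images: T[], labels: string[], dataType: DataKind = 'train'): void {
    this.data[dataType].images.push(...images);
    this.data[dataType].labels.push(...labels);
  }

  // Clears one data set, or every data set when no type is given.
  clear(dataType?: DataKind): void {
    const kinds: DataKind[] = dataType ? [dataType] : ['train', 'eval'];
    for (const kind of kinds) {
      this.data[kind] = { images: [], labels: [] };
    }
  }

  count(dataType: DataKind): number {
    return this.data[dataType].images.length;
  }
}

const demoStore = new LabeledDataStore<string>();
demoStore.add(['tree1.png'], ['tree']);
demoStore.clear('train');
console.log(demoStore.count('train')); // 0
```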
Contributions are welcome!
You can start up a local copy of ml-classifier with:
yarn watch
ml-classifier is written in TypeScript.
Tests are a work in progress. Currently, the test suite only consists of unit tests. Pull requests for additional tests are welcome!
Run tests with:
yarn test
This project is licensed under the MIT License; see the LICENSE file for details.