ML Classifier is a machine learning engine for quickly training image classification models in your browser. Models can be saved with a single command, and the resulting models reused to make image classification predictions.
This package is intended as a companion for ml-classifier-ui, which provides a web frontend in React for uploading data and seeing results.
An interactive demo can be found here.
ml-classifier can be installed via yarn or npm:

    yarn add ml-classifier

or

    npm install ml-classifier

Start by instantiating a new MLClassifier.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

Then, train the model:

    await mlClassifier.train(imageData, {
      callbacks: {
        onTrainBegin: () => {
          console.log('training begins');
        },
        onBatchEnd: (batch: any, logs: any) => {
          console.log('Loss is: ' + logs.loss.toFixed(5));
        },
      },
    });

And get predictions:

    const prediction = await mlClassifier.predict(data);

When you have a trained model you're happy with, save it with:

    mlClassifier.save();

Start by creating a new instance of MLClassifier with:

    const mlClassifier = new MLClassifier();

This will begin loading the pretrained model and give you an object that you can add data to and train.
MLClassifier accepts a number of callbacks when initialized:
- onLoadStart (Function) Optional - A callback for when load (loading the pre-trained model) is first called.
- onLoadComplete (Function) Optional - A callback for when load (loading the pre-trained model) is complete.
- onAddDataStart (Function) Optional - A callback for when addData is first called.
- onAddDataComplete (Function) Optional - A callback for when addData is complete.
- onClearDataStart (Function) Optional - A callback for when clearData is first called.
- onClearDataComplete (Function) Optional - A callback for when clearData is complete.
- onTrainStart (Function) Optional - A callback for when train is first called.
- onTrainComplete (Function) Optional - A callback for when train is complete.
- onEvaluateStart (Function) Optional - A callback for when evaluate is first called.
- onEvaluateComplete (Function) Optional - A callback for when evaluate is complete.
- onPredictStart (Function) Optional - A callback for when predict is first called.
- onPredictComplete (Function) Optional - A callback for when predict is complete.
- onSaveStart (Function) Optional - A callback for when save is first called.
- onSaveComplete (Function) Optional - A callback for when save is complete.
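Because the callbacks above come in Start/Complete pairs, one per method, the options object can also be built in a loop. The following is a minimal sketch and not part of ml-classifier: the helper makeLoggingCallbacks, the Callbacks type, and the events array are hypothetical names, and only the callback names themselves are taken from the list above.

```typescript
// Method names taken from the callback list above.
const METHODS = ['Load', 'AddData', 'ClearData', 'Train', 'Evaluate', 'Predict', 'Save'];

// Hypothetical type for this sketch: a plain map from callback name to handler.
type Callbacks = Record<string, () => void>;

// Hypothetical helper: builds one Start and one Complete callback per method.
// Each callback reports its own name to the supplied sink.
function makeLoggingCallbacks(sink: (name: string) => void): Callbacks {
  const callbacks: Callbacks = {};
  for (const method of METHODS) {
    for (const phase of ['Start', 'Complete']) {
      const name = `on${method}${phase}`;
      callbacks[name] = () => sink(name);
    }
  }
  return callbacks;
}

// Collect event names in memory instead of printing them.
const events: string[] = [];
const loggingCallbacks = makeLoggingCallbacks((name) => events.push(name));
```

Assuming the constructor accepts a plain object of these callbacks, as the list describes, the result could be passed to new MLClassifier(...) in place of a hand-written object.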

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier({
      onLoadStart: () => console.log('onLoadStart'),
      onLoadComplete: () => console.log('onLoadComplete'),
      onAddDataStart: () => console.log('onAddDataStart'),
      onAddDataComplete: () => console.log('onAddDataComplete'),
      onClearDataStart: () => console.log('onClearDataStart'),
      onClearDataComplete: () => console.log('onClearDataComplete'),
      onTrainStart: () => console.log('onTrainStart'),
      onTrainComplete: () => console.log('onTrainComplete'),
      onEvaluateStart: () => console.log('onEvaluateStart'),
      onEvaluateComplete: () => console.log('onEvaluateComplete'),
      onPredictStart: () => console.log('onPredictStart'),
      onPredictComplete: () => console.log('onPredictComplete'),
      onSaveStart: () => console.log('onSaveStart'),
      onSaveComplete: () => console.log('onSaveComplete'),
    });

addData takes an array of incoming images, an optional array of labels, and an optional dataType.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, 'train');

- images (Tensor3D[]) - an array of 3D tensors. Images can be any size, but will be cropped and scaled down to match the pretrained model.
- labels (string[]) - an array of strings, matching the images passed above.
- dataType (string) Optional - an enum specifying which data type the images match. Data types can be train, for data used in model.train(), and eval, for data used in model.evaluate(). If no argument is supplied, dataType defaults to train.

addData returns nothing.
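Since labels are matched to images by position, it can help to check the two arrays before handing them to addData. The sketch below is illustrative only: prepareBatch, DataKind, and LabeledBatch are hypothetical names, the placeholder strings stand in for Tensor3D images, and whether ml-classifier performs a similar check internally is not stated here. The sketch also assumes labels are always supplied.

```typescript
// The two data types described above.
type DataKind = 'train' | 'eval';

interface LabeledBatch<T> {
  images: T[];
  labels: string[];
  dataType: DataKind;
}

// Hypothetical helper: verifies that every image has a label and applies the
// documented default of 'train' when no data type is given.
function prepareBatch<T>(
  images: T[],
  labels: string[],
  dataType: DataKind = 'train',
): LabeledBatch<T> {
  if (images.length !== labels.length) {
    throw new Error(`Got ${images.length} images but ${labels.length} labels`);
  }
  return { images, labels, dataType };
}

// Placeholder strings stand in for Tensor3D images in this sketch.
const batch = prepareBatch(['img-a', 'img-b'], ['cat', 'dog']);
```

The fields of the returned object line up with the three arguments of addData, so a call such as mlClassifier.addData(batch.images, batch.labels, batch.dataType) would follow.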
train begins training on the given dataset.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    await mlClassifier.train({
      callbacks: {
        onTrainBegin: () => {
          console.log('training begins');
        },
      },
    });

- params (Object) Optional - a set of parameters that will be passed directly to model.fit. View the TensorFlow.js docs for an up-to-date list of arguments.
train returns the resolved promise from fit, an object containing loss and accuracy.
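As an illustration of what a params object might contain, the sketch below sets epochs and batchSize and records the loss after each batch. These option and callback names are ones TensorFlow.js model.fit accepts. The FitParams and Logs interfaces are local stand-ins declared only so the sketch compiles without TensorFlow.js installed, and the values 20 and 32 are arbitrary.

```typescript
// Local stand-ins for the shapes TensorFlow.js uses; declared here so the
// sketch compiles on its own. They are not imported from any library.
interface Logs {
  [key: string]: number;
}

interface FitParams {
  epochs?: number;
  batchSize?: number;
  callbacks?: {
    onTrainBegin?: () => void;
    onBatchEnd?: (batch: number, logs?: Logs) => void;
    onEpochEnd?: (epoch: number, logs?: Logs) => void;
  };
}

// Loss values collected during training, one entry per batch.
const lossHistory: number[] = [];

const trainParams: FitParams = {
  epochs: 20, // arbitrary example value
  batchSize: 32, // arbitrary example value
  callbacks: {
    onBatchEnd: (_batch, logs) => {
      // Only record the loss when the logs actually contain one.
      if (logs && typeof logs.loss === 'number') {
        lossHistory.push(logs.loss);
      }
    },
  },
};

// With the library installed, this would be passed as:
//   await mlClassifier.train(trainParams);
```

Keeping the history in a plain array makes it easy to plot or inspect after training finishes.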
evaluate is used to evaluate a model's performance.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    // Wait for training to finish before evaluating.
    await mlClassifier.train();

    mlClassifier.addData(evaluationImages, labels, DataType.EVALUATE);
    mlClassifier.evaluate();

- params (Object) Optional - a set of parameters that will be passed directly to model.evaluate. View the TensorFlow.js docs for an up-to-date list of arguments.
evaluate returns a tf.Scalar representing the result of evaluate.
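Evaluation is most informative on images the model did not see during training. One common way to arrange that is a holdout split, sketched below. The helper splitTrainEval is hypothetical and not part of ml-classifier, and the 0.2 default is only a conventional choice.

```typescript
// Hypothetical helper: holds out the last `evalFraction` of the items for
// evaluation. It does not shuffle, so shuffle beforehand if the input is
// ordered by class.
function splitTrainEval<T>(
  items: T[],
  evalFraction = 0.2,
): { train: T[]; evaluation: T[] } {
  if (evalFraction < 0 || evalFraction > 1) {
    throw new RangeError('evalFraction must be between 0 and 1');
  }
  const evalCount = Math.round(items.length * evalFraction);
  const cut = items.length - evalCount;
  return { train: items.slice(0, cut), evaluation: items.slice(cut) };
}

// Placeholder strings stand in for { image, label } pairs.
const samples = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'];
const { train, evaluation } = splitTrainEval(samples, 0.2);
// train holds the first 8 items, evaluation the last 2.
```

In practice each item would carry both an image and its label, so the two stay paired when the halves are passed to addData with the train and eval data types.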
predict is used to make a specific prediction using a saved model.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    // Wait for training to finish before predicting.
    await mlClassifier.train();

    await mlClassifier.predict(imageToPredict);

- image (tf.Tensor3D) - a single image encoded as a tf.Tensor3D. The image can be any size, but will be cropped and scaled down to match the pretrained model.
predict will return a string matching the prediction.
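The exact cropping strategy is not described here. A common approach for square-input models is to take the largest centered square and then resize it to the model's input size, and the sketch below shows only the crop arithmetic under that assumption. centerCropBox and CropBox are hypothetical names, not part of ml-classifier.

```typescript
interface CropBox {
  top: number;
  left: number;
  size: number;
}

// Computes the largest centered square that fits inside a height x width image.
function centerCropBox(height: number, width: number): CropBox {
  const size = Math.min(height, width);
  return {
    top: Math.floor((height - size) / 2),
    left: Math.floor((width - size) / 2),
    size,
  };
}

// A 480 x 640 (height x width) image keeps its full height and loses
// 80 pixels from each side.
const box = centerCropBox(480, 640);
// box is { top: 0, left: 80, size: 480 }
```

Under this assumption, content near the left and right edges of a wide image would not reach the model, which is worth keeping in mind when framing input images.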
save is a proxy to tf.model.save, and will initiate a download from the browser, or save to local storage.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    // Wait for training to finish before saving.
    await mlClassifier.train();

    mlClassifier.save('path-to-save');

- handlerOrUrl (io.IOHandler | string) Optional - an argument to be passed to model.save. If omitted, the model's unique labels will be concatenated together in the form of class1-class2-class3.
- params (Object) Optional - a set of parameters that will be passed directly to model.save. View the TensorFlow.js docs for an up-to-date list of arguments.
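The default name described for handlerOrUrl can be illustrated with a few lines of code. This is a sketch of the described behavior rather than the library's implementation: defaultModelName is a hypothetical helper, and it assumes labels are kept in order of first appearance, which is not specified here.

```typescript
// Hypothetical helper: de-duplicates labels (keeping first-seen order, an
// assumption) and joins them with hyphens.
function defaultModelName(labels: string[]): string {
  const unique = Array.from(new Set(labels));
  return unique.join('-');
}

const modelName = defaultModelName(['cat', 'dog', 'cat', 'bird']);
// modelName is 'cat-dog-bird'
```

Passing an explicit handlerOrUrl avoids depending on this generated name, which changes whenever the set of labels changes.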
getModel will return the trained Tensorflow.js model. Calling this method prior to calling mlClassifier.train will return null.
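
    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    // Wait for training to finish; before that, getModel returns null.
    await mlClassifier.train();

    mlClassifier.getModel();

getModel takes no parameters. It returns the trained TensorFlow.js model.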
clearData will clear out saved data.

    import MLClassifier from 'ml-classifier';

    const mlClassifier = new MLClassifier();

    mlClassifier.addData(images, labels, DataType.TRAIN);
    mlClassifier.clearData(DataType.TRAIN);

- dataType (DataType) Optional - specifies which data to clear. If no argument is provided, all data will be cleared.

clearData returns nothing.
Contributions are welcome!
You can start up a local copy of ml-classifier with:

    yarn watch

ml-classifier is written in TypeScript.
Tests are a work in progress. Currently, the test suite only consists of unit tests. Pull requests for additional tests are welcome!
Run tests with:

    yarn test

This project is licensed under the MIT License - see the LICENSE file for details.
