A Flutter plugin for running PyTorch model inference, supported on both Android and iOS.
To use this plugin, add pytorch_mobile as a dependency in your pubspec.yaml file.
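Create an assets folder containing your PyTorch model and, if needed, a labels file. Then declare them under the flutter section of pubspec.yaml. Every model and labels file you load must be declared there, either file by file (for example assets/models/model.pt) or by listing its directory with a trailing slash:
flutter:
  assets:
    - assets/models/
    - assets/labels/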
Run flutter pub get to fetch the plugin.
Import the package in your Dart code:
import 'package:pytorch_mobile/pytorch_mobile.dart';
Load a model with PyTorchMobile.loadModel, passing its asset path. Either a custom model:
Model customModel = await PyTorchMobile.loadModel('assets/models/custom_model.pt');
Or an image model:
Model imageModel = await PyTorchMobile.loadModel('assets/models/resnet18.pt');
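To get a prediction from a custom model, pass the input as a flat list, followed by the tensor shape and the data type:
List prediction = await customModel.getPrediction([1, 2, 3, 4], [1, 2, 2], DType.float32);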
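In the getPrediction call above, the flat list [1, 2, 3, 4] fills a tensor of shape [1, 2, 2] because 1 * 2 * 2 = 4. The sketch below checks that relationship and shows where an element lands in the flat list. The helpers elementCount, matchesShape and flatIndex are hypothetical and not part of pytorch_mobile, and the row-major layout is an assumption based on how PyTorch stores contiguous tensors.

```dart
// Hypothetical helpers for illustration only; not part of pytorch_mobile.
// Assumes the flat input is read in row-major order (last dimension fastest).

/// Number of elements a tensor of [shape] holds (product of its dimensions).
int elementCount(List<int> shape) =>
    shape.fold(1, (product, dim) => product * dim);

/// True if [input] has exactly as many values as [shape] requires.
bool matchesShape(List<num> input, List<int> shape) =>
    input.length == elementCount(shape);

/// Position of the element at [index] in a row-major flat list.
int flatIndex(List<int> index, List<int> shape) {
  if (index.length != shape.length) {
    throw ArgumentError('index and shape must have the same rank');
  }
  var flat = 0;
  for (var d = 0; d < shape.length; d++) {
    if (index[d] < 0 || index[d] >= shape[d]) {
      throw RangeError.range(index[d], 0, shape[d] - 1, 'index[$d]');
    }
    flat = flat * shape[d] + index[d];
  }
  return flat;
}

void main() {
  final input = [1, 2, 3, 4];
  final shape = [1, 2, 2];

  print(elementCount(shape)); // 4
  print(matchesShape(input, shape)); // true

  // Element [0, 1, 0] (second row, first column) is the third value.
  print(input[flatIndex([0, 1, 0], shape)]); // 3
}
```

Checking the length before calling the model turns a shape mismatch into a clear Dart error instead of a failure inside the native inference code.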
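To classify an image with an image model, pass the image, the input width and height the model expects (224 x 224 for ResNet-18), and the path to the labels file. The result is the predicted label:
String prediction = await imageModel.getImagePrediction(image, 224, 224, "assets/labels/labels.csv");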
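getImagePrediction returns a label string rather than raw scores. Conceptually, turning scores into a label means taking the index of the highest score and reading the label at that position. The sketch below illustrates that idea under the assumption that the labels file lists class names in the same order as the model's outputs. It is not the plugin's internal code; argmax and topLabel are hypothetical names, and the scores and labels are made up.

```dart
// Hypothetical sketch of mapping model scores to a label; not the plugin's code.

/// Index of the largest value in [scores] (the first one wins on ties).
int argmax(List<double> scores) {
  if (scores.isEmpty) {
    throw ArgumentError('scores must not be empty');
  }
  var best = 0;
  for (var i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

/// Label at the position of the highest score, assuming [labels] is in the
/// same order as the model's output classes.
String topLabel(List<double> scores, List<String> labels) {
  if (scores.length != labels.length) {
    throw ArgumentError('expected one label per score');
  }
  return labels[argmax(scores)];
}

void main() {
  // Made-up scores for a hypothetical three-class model.
  final scores = [0.1, 2.3, -0.7];
  final labels = ['cat', 'dog', 'bird'];
  print(topLabel(scores, labels)); // dog
}
```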
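To normalize the image with a custom mean and standard deviation (one value per RGB channel), pass them through the optional mean and std named arguments:
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await imageModel.getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std);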
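A common convention for this kind of normalization, assumed here rather than documented for this plugin, is to scale each 0-255 channel value to the range 0-1 and then compute (x - mean) / std per channel. With mean and std both 0.5, pixel values map onto the range -1 to 1. The helper normalizePixel below is hypothetical and only illustrates that arithmetic.

```dart
// Hypothetical helper; assumes the common (x / 255 - mean) / std convention,
// which is not confirmed as the plugin's exact preprocessing.

/// Normalizes one pixel given as 0-255 channel values.
List<double> normalizePixel(
    List<int> channels, List<double> mean, List<double> std) {
  if (channels.length != mean.length || channels.length != std.length) {
    throw ArgumentError('need one mean and one std value per channel');
  }
  return List<double>.generate(
      channels.length, (c) => (channels[c] / 255.0 - mean[c]) / std[c]);
}

void main() {
  final mean = [0.5, 0.5, 0.5];
  final std = [0.5, 0.5, 0.5];
  // With these values, 0 maps to -1.0 and 255 maps to 1.0.
  print(normalizePixel([0, 255, 255], mean, std)); // [-1.0, 1.0, 1.0]
}
```

The mean and std you pass should match the values used when the model was trained; otherwise the inputs reach the model on a different scale than it expects.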