Skip to content
91 changes: 91 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
name: Flutter CI

# This workflow is triggered on pushes to the repository.

on:
push:
branches:
- master

# on: push # Default will running for every branch.

jobs:
tests:
name: Testing
# This job will run on ubuntu virtual machine
runs-on: ubuntu-latest
steps:

# Setup Java environment in order to build the Android app.
- uses: actions/checkout@v1
- uses: actions/setup-java@v1
with:
java-version: '12.x'

# Setup the flutter environment.
- uses: subosito/flutter-action@v1
with:
channel: 'dev' # 'dev', 'alpha', default to: 'stable'
# flutter-version: '1.12.x' # you can also specify exact version of flutter

# Get flutter dependencies.
- run: flutter pub get

# Check for any formatting issues in the code.
- run: flutter format --set-exit-if-changed .

# Statically analyze the Dart code for any errors.
- run: flutter analyze .

# Run widget tests for our flutter project.
- run: flutter test

build_android:
name: Build Flutter(Android)
# This job will run on ubuntu virtual machine
runs-on: ubuntu-latest
steps:

# Setup Java environment in order to build the Android app.
- uses: actions/checkout@v1
- uses: actions/setup-java@v1
with:
java-version: '12.x'

# Setup the flutter environment.
- uses: subosito/flutter-action@v1
with:
channel: 'dev' # 'dev', 'alpha', default to: 'stable'
# flutter-version: '1.12.x' # you can also specify exact version of flutter

# Get flutter dependencies.
- run: flutter pub get

# Build apk.
- run: flutter build apk
working-directory: ./example

build_ios:
name: Build Flutter(iOS)
# This job will run on macOS virtual machine
runs-on: macOS-latest
steps:

# Setup Java environment in order to build the iOS app.
- uses: actions/checkout@v1
- uses: actions/setup-java@v1
with:
java-version: '12.x'

# Setup the flutter environment.
- uses: subosito/flutter-action@v1
with:
channel: 'dev' # 'dev', 'alpha', default to: 'stable'
# flutter-version: '1.12.x' # you can also specify exact version of flutter

# Get flutter dependencies.
- run: flutter pub get

# Build apk.
- run: flutter build ios --release --no-codesign
working-directory: ./example
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@
## 0.2.0

* Null safety, bugfixes and PyTorch version update

## 0.2.1

* Custom mean and std
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,15 @@ String prediction = await _imageModel
.getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std);
```

### Detectron2 model [only detection]
```dart
List<List>? prediction = await _d2model
.detectron2(image, 320, 320, "assets/labels/d2go.csv", minScore: 0.4);

// prediction[0] => [left, top, right, bottom, score, label]
```

#### Detectron2 model is generated with [d2go](https://github.com/facebookresearch/d2go), using [script](create_d2go.py)

## Contact
fynnmaarten.business@gmail.com
250 changes: 173 additions & 77 deletions android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.facebook.soloader.nativeloader.SystemDelegate;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** TorchMobilePlugin */
public class PyTorchMobilePlugin implements FlutterPlugin, MethodCallHandler {
Expand Down Expand Up @@ -53,102 +55,196 @@ public static void registerWith(Registrar registrar) {
public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) {
switch (call.method){
case "loadModel":
try {
String absPath = call.argument("absPath");
modules.add(Module.load(absPath));
result.success(modules.size() - 1);
} catch (Exception e) {
String assetPath = call.argument("assetPath");
Log.e("PyTorchMobile", assetPath + " is not a proper model", e);
}
loadModel(call, result);
break;

case "predict":
Module module = null;
Integer[] shape = null;
Double[] data = null;
DType dtype = null;
predict(call, result);
break;

try{
int index = call.argument("index");
module = modules.get(index);
case "predictImage":
predictImage(call, result);
break;

// Detectron2
case "detectron2":
detectron2(call, result);
break;

dtype = DType.valueOf(call.argument("dtype").toString().toUpperCase());
default:
result.notImplemented();
break;
}
}

ArrayList<Integer> shapeList = call.argument("shape");
shape = shapeList.toArray(new Integer[shapeList.size()]);
// Functions
private void loadModel(@NonNull MethodCall call, @NonNull Result result)
{
try {
String absPath = call.argument("absPath");
modules.add(Module.load(absPath));
result.success(modules.size() - 1);
} catch (Exception e) {
String assetPath = call.argument("assetPath");
Log.e("PyTorchMobile", assetPath + " is not a proper model", e);
}
}

ArrayList<Double> dataList = call.argument("data");
data = dataList.toArray(new Double[dataList.size()]);
private void predict(@NonNull MethodCall call, @NonNull Result result)
{
Module module = null;
Integer[] shape = null;
Double[] data = null;
DType dtype = null;

}catch(Exception e){
Log.e("PyTorchMobile", "error parsing arguments", e);
}
try{
int index = call.argument("index");
module = modules.get(index);

//prepare input tensor
final Tensor inputTensor = getInputTensor(dtype, data, shape);
dtype = DType.valueOf(call.argument("dtype").toString().toUpperCase());

//run model
Tensor outputTensor = null;
try {
outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
}catch(RuntimeException e){
Log.e("PyTorchMobile", "Your input type " + dtype.toString().toLowerCase() + " (" + Convert.dtypeAsPrimitive(dtype.toString()) +") " + "does not match with model input type",e);
result.success(null);
}
ArrayList<Integer> shapeList = call.argument("shape");
shape = shapeList.toArray(new Integer[shapeList.size()]);

successResult(result, dtype, outputTensor);
ArrayList<Double> dataList = call.argument("data");
data = dataList.toArray(new Double[dataList.size()]);

break;
case "predictImage":
Module imageModule = null;
Bitmap bitmap = null;
float [] mean = null;
float [] std = null;
try {
int index = call.argument("index");
byte[] imageData = call.argument("image");
int width = call.argument("width");
int height = call.argument("height");
// Custom mean
ArrayList<Double> _mean = call.argument("mean");
mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0]));

// Custom std
ArrayList<Double> _std = call.argument("std");
std = Convert.toFloatPrimitives(_std.toArray(new Double[0]));



imageModule = modules.get(index);

bitmap = BitmapFactory.decodeByteArray(imageData,0,imageData.length);

bitmap = Bitmap.createScaledBitmap(bitmap, width, height, false);

}catch (Exception e){
Log.e("PyTorchMobile", "error reading image", e);
}
}catch(Exception e){
Log.e("PyTorchMobile", "error parsing arguments", e);
}

final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
mean, std);
//prepare input tensor
final Tensor inputTensor = getInputTensor(dtype, data, shape);

final Tensor imageOutputTensor = imageModule.forward(IValue.from(imageInputTensor)).toTensor();
//run model
Tensor outputTensor = null;
try {
outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
}catch(RuntimeException e){
Log.e("PyTorchMobile", "Your input type " + dtype.toString().toLowerCase() + " (" + Convert.dtypeAsPrimitive(dtype.toString()) +") " + "does not match with model input type",e);
result.success(null);
}

float[] scores = imageOutputTensor.getDataAsFloatArray();
successResult(result, dtype, outputTensor);

ArrayList<Float> out = new ArrayList<>();
for(float f : scores){
out.add(f);
}
}

result.success(out);
private void predictImage(@NonNull MethodCall call, @NonNull Result result)
{
Module imageModule = null;
Bitmap bitmap = null;
float [] mean = null;
float [] std = null;
try {
int index = call.argument("index");
byte[] imageData = call.argument("image");
int width = call.argument("width");
int height = call.argument("height");
// Custom mean
ArrayList<Double> _mean = call.argument("mean");
mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0]));

// Custom std
ArrayList<Double> _std = call.argument("std");
std = Convert.toFloatPrimitives(_std.toArray(new Double[0]));



imageModule = modules.get(index);

bitmap = BitmapFactory.decodeByteArray(imageData,0,imageData.length);

bitmap = Bitmap.createScaledBitmap(bitmap, width, height, false);

}catch (Exception e){
Log.e("PyTorchMobile", "error reading image", e);
}

break;
default:
result.notImplemented();
break;
final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
mean, std);

final Tensor imageOutputTensor = imageModule.forward(IValue.from(imageInputTensor)).toTensor();

float[] scores = imageOutputTensor.getDataAsFloatArray();

ArrayList<Float> out = new ArrayList<>();
for(float f : scores){
out.add(f);
}

result.success(out);

}
private void detectron2(@NonNull MethodCall call, @NonNull Result result)
{
Module imageModule = null;
Bitmap bitmap = null;
float [] mean = null;
float [] std = null;
double minScore = 0.0;
int width = 640;
int height = 640;
try {
int index = call.argument("index");
byte[] imageData = call.argument("image");
width = call.argument("width");
height = call.argument("height");
// Custom mean
ArrayList<Double> _mean = call.argument("mean");
mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0]));

// Custom std
ArrayList<Double> _std = call.argument("std");
std = Convert.toFloatPrimitives(_std.toArray(new Double[0]));

minScore = call.argument("minScore");

imageModule = modules.get(index);

bitmap = BitmapFactory.decodeByteArray(imageData,0,imageData.length);

bitmap = Bitmap.createScaledBitmap(bitmap, width, height, false);

}catch (Exception e){
Log.e("PyTorchMobile", "error reading image", e);
}

final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
mean, std);
final Map<String, IValue> map = imageModule.forward(IValue.from(imageInputTensor)).toDictStringKey();

// Return list
List< List<Float> > out = new ArrayList<>();

final Tensor boxesTensor = map.get("boxes").toTensor();
final Tensor scoresTensor = map.get("scores").toTensor();
final Tensor labelsTensor = map.get("labels").toTensor();
float[] boxesData = boxesTensor.getDataAsFloatArray();
float[] scoresData = scoresTensor.getDataAsFloatArray();
long[] labelsData = labelsTensor.getDataAsLongArray();

final int n = scoresData.length;
for (int i = 0; i < n; i++) {
if (scoresData[i] < minScore)
continue;

List<Float> detection = new ArrayList<>(6);

detection.add(boxesData[4 * i + 0]); // left
detection.add(boxesData[4 * i + 1]); // top
detection.add(boxesData[4 * i + 2]); // right
detection.add(boxesData[4 * i + 3]); // bottom
detection.add(scoresData[i]); // score
detection.add((float)(labelsData[i] - 1)); // label

out.add(detection);
}

result.success(out);

}
// [END] Functions


//returns input tensor depending on dtype
private Tensor getInputTensor(DType dtype, Double[] data, Integer[] shape){
Expand Down
Loading