This project aims to take text-exported ML models generated by scikit-learn and make them usable in Java.
- The `tree.DecisionTreeClassifier` is supported
  - Supports `predict()`
  - Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
- The `tree.RandomForestClassifier` is supported
  - Supports `predict()`
  - Supports `predict_proba()` when `export_text()` is configured with `show_weights=True`
To use the library, add the Maven dependency to your project:

```xml
<dependency>
    <groupId>rocks.vilaverde</groupId>
    <artifactId>scikit-learn-2-java</artifactId>
    <version>1.1.0</version>
</dependency>
```
As an example, a DecisionTreeClassifier model is trained on the Iris dataset and exported using `sklearn.tree.export_text()`, as shown below:
```python
>>> import sys
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'], show_weights=True, max_depth=sys.maxsize)
>>> print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) > 0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) > 1.75
|   |   |--- class: 2
```
The exported text can then be executed in Java. Note that when calling `export_text` it is recommended that `max_depth` be set to `sys.maxsize` so that the tree isn't truncated.
In this example the Iris model exported using `export_text` is parsed, features are created using the Java `Features`/`FeatureVector` API, and the decision tree is asked to predict the class.
```java
Reader tree = getTrainedModel("iris.model");
final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree, PredictionFactory.INTEGER);

Features features = Features.of("sepal length (cm)", "sepal width (cm)",
        "petal length (cm)", "petal width (cm)");
FeatureVector fv = features.newSample();
fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);

Integer prediction = decisionTree.predict(fv);
System.out.println(prediction.toString());
```
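The `getTrainedModel` helper above is part of the example rather than the library; it only needs to return a `Reader` over the text produced by `export_text()`. A minimal sketch, assuming the exported tree was saved to a local file named `iris.model`:

```java
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: opens the exported decision tree text as a Reader.
static Reader getTrainedModel(String fileName) throws IOException {
    return Files.newBufferedReader(Path.of(fileName));
}
```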
To use a RandomForestClassifier that has been trained on the Iris dataset, each of the estimators in the classifier needs to be exported using `sklearn.tree.export_text`, as shown below:
```python
>>> from sklearn import datasets
>>> from sklearn import tree
>>> from sklearn.ensemble import RandomForestClassifier
>>>
>>> import os
>>> import sys
>>>
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> y = iris.target
>>>
>>> clf = RandomForestClassifier(n_estimators=50, n_jobs=8)
>>> model = clf.fit(X, y)
>>>
>>> for i, t in enumerate(clf.estimators_):
...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
...         file1.write(text_representation)
```
Once all the estimators are exported into `/tmp/estimators`, you can create a TAR archive, for example:
```shell
cd /tmp/estimators
tar -czvf /tmp/iris.tgz .
```
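If you prefer to build the archive from Java instead of the shell, a rough equivalent using Apache Commons Compress (which the parsing example below already relies on) might look like the following sketch; the directory and archive paths simply mirror the commands above:

```java
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

public class BuildEstimatorArchive {
    public static void main(String[] args) throws IOException {
        Path estimatorsDir = Path.of("/tmp/estimators");
        Path archive = Path.of("/tmp/iris.tgz");

        try (OutputStream out = Files.newOutputStream(archive);
             GzipCompressorOutputStream gz = new GzipCompressorOutputStream(out);
             TarArchiveOutputStream tar = new TarArchiveOutputStream(gz);
             Stream<Path> files = Files.list(estimatorsDir)) {

            for (Path file : files.toList()) {
                // One TAR entry per exported estimator text file.
                TarArchiveEntry entry = new TarArchiveEntry(file.toFile(), file.getFileName().toString());
                tar.putArchiveEntry(entry);
                Files.copy(file, tar);
                tar.closeArchiveEntry();
            }
            tar.finish();
        }
    }
}
```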
Then you can use the RandomForestClassifier class to parse the TAR archive.
```java
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
...

TarArchiveInputStream tree = getArchive("iris.tgz");
final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree, PredictionFactory.DOUBLE);
```
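As with `getTrainedModel`, `getArchive` belongs to the example rather than the library; it only has to hand the parser a `TarArchiveInputStream` over the gzipped archive. A minimal sketch, assuming the archive created above at `/tmp/iris.tgz`:

```java
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: opens the gzipped TAR of exported estimators.
static TarArchiveInputStream getArchive(String fileName) throws IOException {
    InputStream in = Files.newInputStream(Path.of("/tmp", fileName));
    return new TarArchiveInputStream(new GzipCompressorInputStream(in));
}
```

Once parsed, the returned `Classifier<Double>` is used just like the decision tree example above: build a `FeatureVector` from the four Iris feature names and call `predict(fv)`.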
Testing was done using models exported from scikit-learn version 1.1.3, but the library should work with newer versions of scikit-learn.