Skip to content

Commit 4bac7ea

Browse files
committed
Feat: add an api to return status for the specified model
1 parent 8dacb21 commit 4bac7ea

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ public abstract String infer(String aggregateExpression, String groupByColumn,
5353

5454
public abstract String listHyperparameters(String className, String uri) throws Exception;
5555

56+
public boolean checkAvailable(String modelName) throws Exception {
57+
return true;
58+
}
59+
5660
public Path getModelPath() {
5761
return Paths.get(TrainDBConfiguration.getTrainDBPrefixPath(), "models",
5862
modeltypeName, modelName);

traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ public String infer(String aggregateExpression, String groupByColumn, String whe
204204
return outputPath;
205205
}
206206

207+
private String unescapeString(String s) {
208+
// remove beginning/ending double quotes and unescape
209+
return StringEscapeUtils.unescapeJava(s.replaceAll("^\"|\"$", ""));
210+
}
211+
207212
@Override
208213
public String listHyperparameters(String className, String uri) throws Exception {
209214
URL url = new URL(checkTrailingSlash(uri) + "modeltype/" + className + "/hyperparams");
@@ -222,8 +227,32 @@ public String listHyperparameters(String className, String uri) throws Exception
222227
response.append(line);
223228
}
224229

225-
// remove beginning/ending double quotes and unescape
226-
return StringEscapeUtils.unescapeJava(response.toString().replaceAll("^\"|\"$", ""));
230+
return unescapeString(response.toString());
231+
}
232+
233+
@Override
234+
public boolean checkAvailable(String modelName) throws Exception {
235+
MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype();
236+
URL url = new URL(checkTrailingSlash(mModeltype.getUri()) + "model/" + modelName + "/status");
237+
HttpURLConnection httpConn = (HttpURLConnection) url.openConnection();
238+
httpConn.setRequestMethod("GET");
239+
240+
if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) {
241+
throw new TrainDBException("failed to get model status");
242+
}
243+
244+
StringBuilder response = new StringBuilder();
245+
BufferedReader reader = new BufferedReader(
246+
new InputStreamReader(httpConn.getInputStream(), StandardCharsets.UTF_8));
247+
String line;
248+
while ((line = reader.readLine()) != null) {
249+
response.append(line);
250+
}
251+
String res = unescapeString(response.toString());
252+
if (res.equalsIgnoreCase("FINISHED")) {
253+
return true;
254+
}
255+
return false;
227256
}
228257

229258
}

traindb-model

0 commit comments

Comments
 (0)