Skip to content

Commit 9e1ef50

Browse files
Add status code to endpoint response (#172)
* add status code * filter out extraneous fields explicitly * black
1 parent 4f468aa commit 9e1ef50

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

launch/model_endpoint.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
TASK_SUCCESS_STATE = "SUCCESS"
2121
TASK_FAILURE_STATE = "FAILURE"
2222

23+
# Echoes fields in EndpointResponse class
24+
ALLOWED_ENDPOINT_RESPONSE_FIELDS = {"status", "result_url", "result", "traceback", "status_code"}
25+
2326

2427
@dataclass_json(undefined=Undefined.EXCLUDE)
2528
@dataclass
@@ -189,6 +192,7 @@ def __init__(
189192
result_url: Optional[str] = None,
190193
result: Optional[str] = None,
191194
traceback: Optional[str] = None,
195+
status_code: Optional[int] = None,
192196
):
193197
"""
194198
Parameters:
@@ -210,12 +214,15 @@ def __init__(
210214
211215
traceback: The stack trace if the inference endpoint raised an error. Can be used for debugging
212216
217+
status_code: The underlying status code of the response, given from the inference endpoint itself.
218+
213219
"""
214220
self.client = client
215221
self.status = status
216222
self.result_url = result_url
217223
self.result = result
218224
self.traceback = traceback
225+
self.status_code = status_code
219226

220227
def __str__(self) -> str:
221228
return (
@@ -271,6 +278,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
271278
result_url=async_response.get("result", {}).get("result_url", None),
272279
result=async_response.get("result", {}).get("result", None),
273280
traceback=None,
281+
status_code=async_response.get("status_code", None),
274282
)
275283
elif status == "FAILURE":
276284
return EndpointResponse(
@@ -279,6 +287,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
279287
result_url=None,
280288
result=None,
281289
traceback=async_response.get("traceback", None),
290+
status_code=async_response.get("status_code", None),
282291
)
283292
else:
284293
raise ValueError(f"Unrecognized status: {async_response['status']}")
@@ -312,6 +321,7 @@ def __next__(self):
312321
result_url=result.get("result_url", None),
313322
result=result.get("result", None),
314323
traceback=data.get("traceback"),
324+
status_code=data.get("status_code", None),
315325
)
316326

317327

@@ -397,7 +407,10 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
397407
args=request.args,
398408
return_pickled=request.return_pickled,
399409
)
400-
raw_response = {k: v for k, v in raw_response.items() if v is not None}
410+
411+
raw_response = {
412+
k: v for k, v in raw_response.items() if v is not None and k in ALLOWED_ENDPOINT_RESPONSE_FIELDS
413+
}
401414
return EndpointResponse(client=self.client, **raw_response)
402415

403416

@@ -632,6 +645,7 @@ def single_request(inner_url, inner_task_id):
632645
result_url=raw_response.get("result_url", None),
633646
result=raw_response.get("result", None),
634647
traceback=raw_response.get("traceback", None),
648+
status_code=raw_response.get("status_code", None),
635649
)
636650
self.responses[url] = response_object
637651

0 commit comments

Comments
 (0)