20
20
TASK_SUCCESS_STATE = "SUCCESS"
21
21
TASK_FAILURE_STATE = "FAILURE"
22
22
23
+ # Echoes fields in EndpointResponse class
24
+ ALLOWED_ENDPOINT_RESPONSE_FIELDS = {"status" , "result_url" , "result" , "traceback" , "status_code" }
25
+
23
26
24
27
@dataclass_json (undefined = Undefined .EXCLUDE )
25
28
@dataclass
@@ -189,6 +192,7 @@ def __init__(
189
192
result_url : Optional [str ] = None ,
190
193
result : Optional [str ] = None ,
191
194
traceback : Optional [str ] = None ,
195
+ status_code : Optional [int ] = None ,
192
196
):
193
197
"""
194
198
Parameters:
@@ -210,12 +214,15 @@ def __init__(
210
214
211
215
traceback: The stack trace if the inference endpoint raised an error. Can be used for debugging
212
216
217
+ status_code: The underlying status code of the response, given from the inference endpoint itself.
218
+
213
219
"""
214
220
self .client = client
215
221
self .status = status
216
222
self .result_url = result_url
217
223
self .result = result
218
224
self .traceback = traceback
225
+ self .status_code = status_code
219
226
220
227
def __str__ (self ) -> str :
221
228
return (
@@ -271,6 +278,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
271
278
result_url = async_response .get ("result" , {}).get ("result_url" , None ),
272
279
result = async_response .get ("result" , {}).get ("result" , None ),
273
280
traceback = None ,
281
+ status_code = async_response .get ("status_code" , None ),
274
282
)
275
283
elif status == "FAILURE" :
276
284
return EndpointResponse (
@@ -279,6 +287,7 @@ def get(self, timeout: Optional[float] = None) -> EndpointResponse:
279
287
result_url = None ,
280
288
result = None ,
281
289
traceback = async_response .get ("traceback" , None ),
290
+ status_code = async_response .get ("status_code" , None ),
282
291
)
283
292
else :
284
293
raise ValueError (f"Unrecognized status: { async_response ['status' ]} " )
@@ -312,6 +321,7 @@ def __next__(self):
312
321
result_url = result .get ("result_url" , None ),
313
322
result = result .get ("result" , None ),
314
323
traceback = data .get ("traceback" ),
324
+ status_code = data .get ("status_code" , None ),
315
325
)
316
326
317
327
@@ -397,7 +407,10 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
397
407
args = request .args ,
398
408
return_pickled = request .return_pickled ,
399
409
)
400
- raw_response = {k : v for k , v in raw_response .items () if v is not None }
410
+
411
+ raw_response = {
412
+ k : v for k , v in raw_response .items () if v is not None and k in ALLOWED_ENDPOINT_RESPONSE_FIELDS
413
+ }
401
414
return EndpointResponse (client = self .client , ** raw_response )
402
415
403
416
@@ -632,6 +645,7 @@ def single_request(inner_url, inner_task_id):
632
645
result_url = raw_response .get ("result_url" , None ),
633
646
result = raw_response .get ("result" , None ),
634
647
traceback = raw_response .get ("traceback" , None ),
648
+ status_code = raw_response .get ("status_code" , None ),
635
649
)
636
650
self .responses [url ] = response_object
637
651
0 commit comments