105105 ),
106106]
107107_TEST_TRAFFIC_SPLIT = {_TEST_ID : 0 , _TEST_ID_2 : 100 , _TEST_ID_3 : 0 }
108- _TEST_PREDICTION = [{"label" : 1.0 }]
108+ _TEST_DICT_PREDICTION = [{"label" : 1.0 }]
109+ _TEST_LIST_PREDICTION = [[1.0 ]]
109110_TEST_EXPLANATIONS = [gca_prediction_service .explanation .Explanation (attributions = [])]
110111_TEST_ATTRIBUTIONS = [
111112 gca_prediction_service .explanation .Attribution (
@@ -218,26 +219,54 @@ def get_endpoint_with_models_with_explanation_mock():
218219
219220
220221@pytest .fixture
221- def predict_client_predict_mock ():
222+ def predict_client_predict_dict_mock ():
222223 with mock .patch .object (
223224 prediction_service_client .PredictionServiceClient , "predict"
224225 ) as predict_mock :
225226 predict_mock .return_value = gca_prediction_service .PredictResponse (
226227 deployed_model_id = _TEST_ID
227228 )
228- predict_mock .return_value .predictions .extend (_TEST_PREDICTION )
229+ predict_mock .return_value .predictions .extend (_TEST_DICT_PREDICTION )
229230 yield predict_mock
230231
231232
232233@pytest .fixture
233- def predict_client_explain_mock ():
234+ def predict_client_explain_dict_mock ():
234235 with mock .patch .object (
235236 prediction_service_client .PredictionServiceClient , "explain"
236237 ) as predict_mock :
237238 predict_mock .return_value = gca_prediction_service .ExplainResponse (
238239 deployed_model_id = _TEST_ID ,
239240 )
240- predict_mock .return_value .predictions .extend (_TEST_PREDICTION )
241+ predict_mock .return_value .predictions .extend (_TEST_DICT_PREDICTION )
242+ predict_mock .return_value .explanations .extend (_TEST_EXPLANATIONS )
243+ predict_mock .return_value .explanations [0 ].attributions .extend (
244+ _TEST_ATTRIBUTIONS
245+ )
246+ yield predict_mock
247+
248+
249+ @pytest .fixture
250+ def predict_client_predict_list_mock ():
251+ with mock .patch .object (
252+ prediction_service_client .PredictionServiceClient , "predict"
253+ ) as predict_mock :
254+ predict_mock .return_value = gca_prediction_service .PredictResponse (
255+ deployed_model_id = _TEST_ID
256+ )
257+ predict_mock .return_value .predictions .extend (_TEST_LIST_PREDICTION )
258+ yield predict_mock
259+
260+
261+ @pytest .fixture
262+ def predict_client_explain_list_mock ():
263+ with mock .patch .object (
264+ prediction_service_client .PredictionServiceClient , "explain"
265+ ) as predict_mock :
266+ predict_mock .return_value = gca_prediction_service .ExplainResponse (
267+ deployed_model_id = _TEST_ID ,
268+ )
269+ predict_mock .return_value .predictions .extend (_TEST_LIST_PREDICTION )
241270 predict_mock .return_value .explanations .extend (_TEST_EXPLANATIONS )
242271 predict_mock .return_value .explanations [0 ].attributions .extend (
243272 _TEST_ATTRIBUTIONS
@@ -312,10 +341,112 @@ def test_create_lit_model_from_tensorflow_with_xai_returns_model(
312341 assert len (item .values ()) == 2
313342
314343 @pytest .mark .usefixtures (
315- "predict_client_predict_mock" , "get_endpoint_with_models_mock"
344+ "predict_client_predict_dict_mock" , "get_endpoint_with_models_mock"
345+ )
346+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
347+ def test_create_lit_model_from_dict_endpoint_returns_model (
348+ self , feature_types , label_types , model_id
349+ ):
350+ endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
351+ lit_model = create_lit_model_from_endpoint (
352+ endpoint , feature_types , label_types , model_id
353+ )
354+ test_inputs = [
355+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
356+ ]
357+ outputs = lit_model .predict_minibatch (test_inputs )
358+
359+ assert lit_model .input_spec () == dict (feature_types )
360+ assert lit_model .output_spec () == dict (label_types )
361+ assert len (outputs ) == 1
362+ for item in outputs :
363+ assert item .keys () == {"label" }
364+ assert len (item .values ()) == 1
365+
366+ @pytest .mark .usefixtures (
367+ "predict_client_explain_dict_mock" ,
368+ "get_endpoint_with_models_with_explanation_mock" ,
369+ )
370+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
371+ def test_create_lit_model_from_dict_endpoint_with_xai_returns_model (
372+ self , feature_types , label_types , model_id
373+ ):
374+ endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
375+ lit_model = create_lit_model_from_endpoint (
376+ endpoint , feature_types , label_types , model_id
377+ )
378+ test_inputs = [
379+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
380+ ]
381+ outputs = lit_model .predict_minibatch (test_inputs )
382+
383+ assert lit_model .input_spec () == dict (feature_types )
384+ assert lit_model .output_spec () == dict (
385+ {
386+ ** label_types ,
387+ "feature_attribution" : lit_types .FeatureSalience (signed = True ),
388+ }
389+ )
390+ assert len (outputs ) == 1
391+ for item in outputs :
392+ assert item .keys () == {"label" , "feature_attribution" }
393+ assert len (item .values ()) == 2
394+
395+ @pytest .mark .usefixtures (
396+ "predict_client_predict_dict_mock" , "get_endpoint_with_models_mock"
397+ )
398+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
399+ def test_create_lit_model_from_dict_endpoint_name_returns_model (
400+ self , feature_types , label_types , model_id
401+ ):
402+ lit_model = create_lit_model_from_endpoint (
403+ _TEST_ENDPOINT_NAME , feature_types , label_types , model_id
404+ )
405+ test_inputs = [
406+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
407+ ]
408+ outputs = lit_model .predict_minibatch (test_inputs )
409+
410+ assert lit_model .input_spec () == dict (feature_types )
411+ assert lit_model .output_spec () == dict (label_types )
412+ assert len (outputs ) == 1
413+ for item in outputs :
414+ assert item .keys () == {"label" }
415+ assert len (item .values ()) == 1
416+
417+ @pytest .mark .usefixtures (
418+ "predict_client_explain_dict_mock" ,
419+ "get_endpoint_with_models_with_explanation_mock" ,
420+ )
421+ @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
422+ def test_create_lit_model_from_dict_endpoint_name_with_xai_returns_model (
423+ self , feature_types , label_types , model_id
424+ ):
425+ lit_model = create_lit_model_from_endpoint (
426+ _TEST_ENDPOINT_NAME , feature_types , label_types , model_id
427+ )
428+ test_inputs = [
429+ {"feature_1" : 1.0 , "feature_2" : 2.0 },
430+ ]
431+ outputs = lit_model .predict_minibatch (test_inputs )
432+
433+ assert lit_model .input_spec () == dict (feature_types )
434+ assert lit_model .output_spec () == dict (
435+ {
436+ ** label_types ,
437+ "feature_attribution" : lit_types .FeatureSalience (signed = True ),
438+ }
439+ )
440+ assert len (outputs ) == 1
441+ for item in outputs :
442+ assert item .keys () == {"label" , "feature_attribution" }
443+ assert len (item .values ()) == 2
444+
445+ @pytest .mark .usefixtures (
446+ "predict_client_predict_list_mock" , "get_endpoint_with_models_mock"
316447 )
317448 @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
318- def test_create_lit_model_from_endpoint_returns_model (
449+ def test_create_lit_model_from_list_endpoint_returns_model (
319450 self , feature_types , label_types , model_id
320451 ):
321452 endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
@@ -335,10 +466,11 @@ def test_create_lit_model_from_endpoint_returns_model(
335466 assert len (item .values ()) == 1
336467
337468 @pytest .mark .usefixtures (
338- "predict_client_explain_mock" , "get_endpoint_with_models_with_explanation_mock"
469+ "predict_client_explain_list_mock" ,
470+ "get_endpoint_with_models_with_explanation_mock" ,
339471 )
340472 @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
341- def test_create_lit_model_from_endpoint_with_xai_returns_model (
473+ def test_create_lit_model_from_list_endpoint_with_xai_returns_model (
342474 self , feature_types , label_types , model_id
343475 ):
344476 endpoint = aiplatform .Endpoint (_TEST_ENDPOINT_NAME )
@@ -363,10 +495,10 @@ def test_create_lit_model_from_endpoint_with_xai_returns_model(
363495 assert len (item .values ()) == 2
364496
365497 @pytest .mark .usefixtures (
366- "predict_client_predict_mock " , "get_endpoint_with_models_mock"
498+ "predict_client_predict_list_mock " , "get_endpoint_with_models_mock"
367499 )
368500 @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
369- def test_create_lit_model_from_endpoint_name_returns_model (
501+ def test_create_lit_model_from_list_endpoint_name_returns_model (
370502 self , feature_types , label_types , model_id
371503 ):
372504 lit_model = create_lit_model_from_endpoint (
@@ -385,10 +517,11 @@ def test_create_lit_model_from_endpoint_name_returns_model(
385517 assert len (item .values ()) == 1
386518
387519 @pytest .mark .usefixtures (
388- "predict_client_explain_mock" , "get_endpoint_with_models_with_explanation_mock"
520+ "predict_client_explain_list_mock" ,
521+ "get_endpoint_with_models_with_explanation_mock" ,
389522 )
390523 @pytest .mark .parametrize ("model_id" , [None , _TEST_ID ])
391- def test_create_lit_model_from_endpoint_name_with_xai_returns_model (
524+ def test_create_lit_model_from_list_endpoint_name_with_xai_returns_model (
392525 self , feature_types , label_types , model_id
393526 ):
394527 lit_model = create_lit_model_from_endpoint (
0 commit comments