 from azure.ai.generative.evaluate._utils import _is_flow, load_jsonl, _get_artifact_dir_path
 from azure.ai.generative.evaluate._mlflow_log_collector import RedirectUserOutputStreams
 
+from ._utils import _write_properties_to_run_history
 
 LOGGER = logging.getLogger(__name__)
 
+
 def _get_handler_class(
     asset,
 ):
@@ -75,55 +77,58 @@ def _log_metrics(run_id, metrics):
 
 def evaluate(
     evaluation_name=None,
-    asset=None,
-    asset_type=None,
+    target=None,
     data=None,
-    truth_data=None,
-    prediction_data=None,
     task_type=None,
-    metrics_config=None,
-    params=None,
-    metrics=None,
+    sweep_args=None,
+    metrics_list=None,
+    model_config=None,
+    data_mapping=None,
     **kwargs
 ):
     results_list = []
+    metrics_config = {}
     if "tracking_uri" in kwargs:
         mlflow.set_tracking_uri(kwargs.get("tracking_uri"))
 
-    if params:
+    if model_config:
+        metrics_config.update({"openai_params": model_config})
+
+    if data_mapping:
+        metrics_config.update(data_mapping)
+
+    if sweep_args:
         import itertools
-        keys, values = zip(*params.items())
+        keys, values = zip(*sweep_args.items())
         params_permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
 
         with mlflow.start_run(run_name=evaluation_name) as run:
-            log_param_and_tag("_azureml.evaluation_run", True)
+            log_property_and_tag("_azureml.evaluation_run", "azure-ai-generative")
             for index, params_permutations_dict in enumerate(params_permutations_dicts):
                 evaluation_name_variant = f"{evaluation_name}_{index}" if evaluation_name else f"{run.info.run_name}_{index}"
 
                 evaluation_results = _evaluate(
                     evaluation_name=evaluation_name_variant,
-                    asset=asset,
+                    target=target,
                     data=data,
-                    truth_data=truth_data,
-                    prediction_data=prediction_data,
                     task_type=task_type,
-                    metrics_config=metrics_config,
+                    model_config=model_config,
+                    data_mapping=data_mapping,
                     params_dict=params_permutations_dict,
-                    metrics=metrics,
+                    metrics=metrics_list,
                     **kwargs
                 )
                 results_list.append(evaluation_results)
             return results_list
     else:
         evaluation_result = _evaluate(
             evaluation_name=evaluation_name,
-            asset=asset,
+            target=target,
             data=data,
-            truth_data=truth_data,
-            prediction_data=prediction_data,
             task_type=task_type,
-            metrics_config=metrics_config,
-            metrics=metrics,
+            model_config=model_config,
+            data_mapping=data_mapping,
+            metrics=metrics_list,
             **kwargs
         )
 
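For context on the new `sweep_args` parameter: when it is supplied, `evaluate` expands it into one child `_evaluate` run per combination of values via the `itertools.product` pattern in the hunk above. A minimal standalone sketch of that expansion (the parameter names and values here are illustrative, not from this commit):

```python
import itertools

# Hypothetical sweep: any mapping of parameter name -> list of candidate values.
sweep_args = {"temperature": [0.0, 0.7], "max_tokens": [256, 512]}

keys, values = zip(*sweep_args.items())
params_permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

# Four permutations, i.e. four child runs named "<evaluation_name>_<index>":
# [{'temperature': 0.0, 'max_tokens': 256}, {'temperature': 0.0, 'max_tokens': 512},
#  {'temperature': 0.7, 'max_tokens': 256}, {'temperature': 0.7, 'max_tokens': 512}]
```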
@@ -132,14 +137,14 @@ def evaluate(
 
 def _evaluate(
     evaluation_name=None,
-    asset=None,
-    asset_type=None,
+    target=None,
     data=None,
     truth_data=None,
     prediction_data=None,
     task_type=None,
-    metrics_config=None,
     metrics=None,
+    data_mapping=None,
+    model_config=None,
     **kwargs
 ):
     try:
@@ -151,23 +156,36 @@ def _evaluate(
         test_data = data
         _data_is_file = False
 
-    if asset is None and prediction_data is None:
-        raise Exception("asset and prediction data cannot be null")
+    if "y_pred" in data_mapping:
+        prediction_data = data_mapping.get("y_pred")
+
+    if "y_test" in data_mapping:
+        truth_data = data_mapping.get("y_test")
+
+    if target is None and prediction_data is None:
+        raise Exception("target and prediction data cannot be null")
 
     if task_type not in [constants.Tasks.QUESTION_ANSWERING, constants.Tasks.CHAT_COMPLETION]:
         raise Exception(f"task type {task_type} is not supported")
 
-    with mlflow.start_run(nested=True if mlflow.active_run() else False, run_name=evaluation_name) as run,\
-            RedirectUserOutputStreams(logger=LOGGER) as _:
-
-        log_param_and_tag("_azureml.evaluation_run", True)
+    metrics_config = {}
+    if model_config:
+        metrics_config.update({"openai_params": model_config})
+
+    if data_mapping:
+        metrics_config.update(data_mapping)
+
+    with mlflow.start_run(nested=True if mlflow.active_run() else False, run_name=evaluation_name) as run, \
+            RedirectUserOutputStreams(logger=LOGGER) as _:
+
+        log_property_and_tag("_azureml.evaluation_run", "azure-ai-generative")
         # Log input is a preview feature behind an allowlist. Uncomment this line once the feature is broadly available.
         # log_input(data=data, data_is_file=_data_is_file)
 
-        asset_handler_class = _get_handler_class(asset)
+        asset_handler_class = _get_handler_class(target)
 
         asset_handler = asset_handler_class(
-            asset=asset,
+            asset=target,
             prediction_data=prediction_data,
             ground_truth=truth_data,
             test_data=test_data,
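Per this hunk, `data_mapping` now does double duty: the whole mapping is merged into `metrics_config`, and the reserved keys `y_pred` and `y_test`, when present, replace the old explicit `prediction_data` and `truth_data` arguments. (Note the membership tests assume a dict is passed; with the default `data_mapping=None` they would raise a `TypeError`.) A hedged sketch of the expected shape, with illustrative column names:

```python
# Column names ("answer", "truth") are illustrative, not from this commit.
data_mapping = {
    "y_pred": "answer",  # column in the test data holding model predictions
    "y_test": "truth",   # column holding ground-truth labels
}

# _evaluate then derives:
prediction_data = data_mapping.get("y_pred")  # -> "answer"
truth_data = data_mapping.get("y_test")       # -> "truth"
```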
@@ -211,7 +229,7 @@ def _get_instance_table():
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for param_name, param_value in kwargs.get("params_dict", {}).items():
-
+
                 try:
                     mlflow.log_param(param_name, param_value)
                 except MlflowException as ex:
@@ -220,8 +238,9 @@ def _get_instance_table():
                     # But since we control how params are logged, this is prob fine for now.
 
                     if ex.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE):
-                        LOGGER.warning(f"Parameter {param_name} value is too long to log. Truncating and logging it as an artifact.")
-
+                        LOGGER.warning(
+                            f"Parameter {param_name} value is too long to log. Truncating and logging it as an artifact.")
+
                         # Truncate the value to 500 bytes and log it.
                         truncated_value = param_value.encode('utf-8')[:500].decode('utf-8', 'ignore')
                         mlflow.log_param(param_name, truncated_value)
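The truncation path slices the UTF-8 byte representation at 500 bytes and decodes with `errors='ignore'`, so a multibyte character split at the boundary is silently dropped rather than raising `UnicodeDecodeError`. A quick standalone check of that behavior (the value is illustrative):

```python
value = "€" * 200  # "€" is 3 bytes in UTF-8, so 600 bytes total

truncated = value.encode("utf-8")[:500].decode("utf-8", "ignore")

# 500 bytes covers 166 whole characters (498 bytes) plus 2 dangling bytes;
# "ignore" drops the dangling bytes instead of raising.
assert len(truncated) == 166
assert len(truncated.encode("utf-8")) == 498
```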
@@ -237,20 +256,22 @@ def _get_instance_table():
             eval_artifact_df = _get_instance_table().to_json(orient="records", lines=True, force_ascii=False)
             tmp_path = os.path.join(tmpdir, "eval_results.jsonl")
 
-            with open(tmp_path, "w") as f:
+            with open(tmp_path, "w", encoding="utf-8") as f:
                 f.write(eval_artifact_df)
 
             mlflow.log_artifact(tmp_path)
-            log_param_and_tag("_azureml.evaluate_artifacts", json.dumps([{"path": "eval_results.jsonl", "type": "table"}]))
+            log_property_and_tag("_azureml.evaluate_artifacts",
+                                 json.dumps([{"path": "eval_results.jsonl", "type": "table"}]))
             mlflow.log_param("task_type", task_type)
             log_param_and_tag("_azureml.evaluate_metric_mapping", json.dumps(metrics_handler._metrics_mapping_to_log))
 
         return metrics
 
+
 def log_input(data, data_is_file):
     try:
-        # Mlflow service supports only uri_folder, hence this is need to create a dir to log input data.
-        # once support is extended, we can revisit this logic
+        # The MLflow service supports only uri_folder, hence we need to create a dir to log input data.
+        # Once support is extended, we can revisit this logic.
         with tempfile.TemporaryDirectory() as tempdir:
             if data_is_file:
                 file_name = os.path.basename(data)
@@ -271,6 +292,11 @@ def log_input(data, data_is_file):
         LOGGER.error("Error logging data as dataset, continuing without it")
         LOGGER.exception(ex, stack_info=True)
 
+
 def log_param_and_tag(key, value):
     mlflow.log_param(key, value)
     mlflow.set_tag(key, value)
+
+def log_property_and_tag(key, value, logger=LOGGER):
+    _write_properties_to_run_history({key: value}, logger)
+    mlflow.set_tag(key, value)
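The new `log_property_and_tag` helper mirrors `log_param_and_tag`, but writes the key/value pair to run history as a property (via the `_write_properties_to_run_history` utility imported at the top of this diff) instead of as an MLflow param; the tag half is unchanged. A hedged usage sketch, with illustrative values:

```python
# Param + tag: the value appears under the run's MLflow params.
log_param_and_tag("task_type", "qa")

# Run-history property + tag: written to Azure ML run history,
# outside the MLflow param store, as done for the markers in this commit.
log_property_and_tag("_azureml.evaluation_run", "azure-ai-generative")
```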