Skip to content

Commit b6f7e01

Browse files
committed
enable time limit in automlx
1 parent 5eca230 commit b6f7e01

File tree

1 file changed

+14
-6
lines changed
  • ads/opctl/operator/lowcode/forecast/model

1 file changed

+14
-6
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _build_model(self) -> pd.DataFrame:
6161
model_kwargs_cleaned.get("score_metric", AUTOMLX_DEFAULT_SCORE_METRIC),
6262
)
6363
model_kwargs_cleaned.pop("task", None)
64+
time_budget = model_kwargs_cleaned.pop("time_budget", 0)
6465
model_kwargs_cleaned[
6566
"preprocessing"
6667
] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
@@ -88,7 +89,11 @@ def _build_model(self) -> pd.DataFrame:
8889
task="forecasting",
8990
**model_kwargs_cleaned,
9091
)
91-
model.fit(X=y_train.drop(target, axis=1), y=pd.DataFrame(y_train[target]))
92+
model.fit(
93+
X=y_train.drop(target, axis=1),
94+
y=pd.DataFrame(y_train[target]),
95+
time_budget=time_budget,
96+
)
9297
logger.info("Selected model: {}".format(model.selected_model_))
9398
logger.info(
9499
"Selected model params: {}".format(model.selected_model_params_)
@@ -157,7 +162,6 @@ def _build_model(self) -> pd.DataFrame:
157162
self.data = data_merged
158163
return outputs_merged
159164

160-
161165
@runtime_dependency(
162166
module="datapane",
163167
err_msg=(
@@ -234,18 +238,22 @@ def _generate_report(self):
234238

235239
local_explanation_text = dp.Text(f"## Local Explanation of Models \n ")
236240
blocks = [
237-
dp.Table(local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100, label=s_id)
241+
dp.Table(
242+
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
243+
label=s_id,
244+
)
238245
for s_id, local_ex_df in self.local_explanation.items()
239246
]
240-
local_explanation_section = dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
247+
local_explanation_section = (
248+
dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
249+
)
241250

242251
# Append the global explanation text and section to the "all_sections" list
243252
all_sections = all_sections + [
244253
global_explanation_text,
245254
global_explanation_section,
246255
local_explanation_text,
247-
local_explanation_section
248-
,
256+
local_explanation_section,
249257
]
250258

251259
model_description = dp.Text(

0 commit comments

Comments
 (0)