Skip to content

Commit 6759f87

Browse files
Upload the dataset before training the auto-sklearn model, so that the run fails early if there are issues with the dataset
1 parent 8844e1e commit 6759f87

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

examples/tabular-classification/sklearn/churn-classifier/churn-classifier-sklearn.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@
430430
"name": "python",
431431
"nbconvert_exporter": "python",
432432
"pygments_lexer": "ipython3",
433-
"version": "3.7.13"
433+
"version": "3.8.13"
434434
}
435435
},
436436
"nbformat": 4,

unboxapi/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,19 @@ def add_baseline(
13321332
col_names = qb.column_names
13331333
categorical_feature_names = qb.get_categorical_feature_names(train_features_df)
13341334

1335+
# Upload the validation set -- if there are issues, it's better to fail prior to model training
1336+
if val_df is not None:
1337+
self.add_dataframe(
1338+
df=val_df,
1339+
task_type=task_type,
1340+
project_id=project_id,
1341+
class_names=class_names,
1342+
label_column_name=label_column_name,
1343+
commit_message=commit_message,
1344+
feature_names=col_names,
1345+
categorical_feature_names=categorical_feature_names,
1346+
)
1347+
13351348
# Train model
13361349
print(
13371350
f"Training model for approximately {round(0.0166 * timeout, 2)} minute(s)."
@@ -1347,22 +1360,10 @@ def add_baseline(
13471360

13481361
# Create requirements file
13491362
filename = "auto-requirements.txt"
1350-
with open("auto-requirements.txt", "w") as f:
1363+
with open(filename, "w") as f:
13511364
f.write("Automunge==8.30\n")
13521365
f.write("scikit-learn== 0.24.1")
13531366

1354-
if val_df is not None:
1355-
self.add_dataframe(
1356-
df=val_df,
1357-
task_type=task_type,
1358-
project_id=project_id,
1359-
class_names=class_names,
1360-
label_column_name=label_column_name,
1361-
commit_message=commit_message,
1362-
feature_names=col_names,
1363-
categorical_feature_names=categorical_feature_names,
1364-
)
1365-
13661367
# Upload model
13671368
model_info = self.add_model(
13681369
function=predict_proba,

0 commit comments

Comments
 (0)