@@ -14,32 +14,42 @@ public static class RegressionExtensions
1414 {
1515 public static RegressionResult AutoFit ( this RegressionContext context ,
1616 IDataView trainData ,
17- string label ,
18- IDataView validationData ,
19- AutoFitSettings settings = null ,
20- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
17+ string label = DefaultColumnNames . Label ,
18+ IDataView validationData = null ,
19+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
20+ IEstimator < ITransformer > preFeaturizers = null ,
21+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
2122 CancellationToken cancellationToken = default ,
2223 IProgress < RegressionIterationResult > iterationCallback = null )
2324 {
25+ var settings = new AutoFitSettings ( ) ;
26+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
27+
2428 return AutoFit ( context , trainData , label , validationData , settings ,
25- purposeOverrides , cancellationToken , iterationCallback , null ) ;
29+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
2630 }
2731
2832 internal static RegressionResult AutoFit ( this RegressionContext context ,
2933 IDataView trainData ,
30- string label ,
31- IDataView validationData ,
34+ string label = DefaultColumnNames . Label ,
35+ IDataView validationData = null ,
3236 AutoFitSettings settings = null ,
33- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
37+ IEstimator < ITransformer > preFeaturizers = null ,
38+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
3439 CancellationToken cancellationToken = default ,
3540 IProgress < RegressionIterationResult > iterationCallback = null ,
3641 IDebugLogger debugLogger = null )
3742 {
38- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
43+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
44+
45+ if ( validationData == null )
46+ {
47+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
48+ }
3949
4050 // run autofit & get all pipelines run in that process
4151 var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
42- settings , TaskKind . Regression , OptimizingMetric . RSquared , purposeOverrides , debugLogger ) ;
52+ settings , preFeaturizers , TaskKind . Regression , OptimizingMetric . RSquared , columnPurposes , debugLogger ) ;
4353
4454 var results = new RegressionIterationResult [ allPipelines . Length ] ;
4555 for ( var i = 0 ; i < results . Length ; i ++ )
@@ -57,33 +67,43 @@ public static class BinaryClassificationExtensions
5767 {
5868 public static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
5969 IDataView trainData ,
60- string label ,
61- IDataView validationData ,
62- AutoFitSettings settings = null ,
63- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
70+ string label = DefaultColumnNames . Label ,
71+ IDataView validationData = null ,
72+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
73+ IEstimator < ITransformer > preFeaturizers = null ,
74+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
6475 CancellationToken cancellationToken = default ,
6576 IProgress < BinaryClassificationItertionResult > iterationCallback = null )
6677 {
78+ var settings = new AutoFitSettings ( ) ;
79+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
80+
6781 return AutoFit ( context , trainData , label , validationData , settings ,
68- purposeOverrides , cancellationToken , iterationCallback , null ) ;
82+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
6983 }
7084
7185 internal static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
7286 IDataView trainData ,
73- string label ,
74- IDataView validationData ,
87+ string label = DefaultColumnNames . Label ,
88+ IDataView validationData = null ,
7589 AutoFitSettings settings = null ,
76- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
90+ IEstimator < ITransformer > preFeaturizers = null ,
91+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
7792 CancellationToken cancellationToken = default ,
7893 IProgress < BinaryClassificationItertionResult > iterationCallback = null ,
7994 IDebugLogger debugLogger = null )
8095 {
81- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
96+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
97+
98+ if ( validationData == null )
99+ {
100+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
101+ }
82102
83103 // run autofit & get all pipelines run in that process
84104 var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
85- settings , TaskKind . BinaryClassification , OptimizingMetric . Accuracy ,
86- purposeOverrides , debugLogger ) ;
105+ settings , preFeaturizers , TaskKind . BinaryClassification , OptimizingMetric . Accuracy ,
106+ columnPurposes , debugLogger ) ;
87107
88108 var results = new BinaryClassificationItertionResult [ allPipelines . Length ] ;
89109 for ( var i = 0 ; i < results . Length ; i ++ )
@@ -101,32 +121,42 @@ public static class MulticlassExtensions
101121 {
102122 public static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
103123 IDataView trainData ,
104- string label ,
105- IDataView validationData ,
106- AutoFitSettings settings = null ,
107- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
124+ string label = DefaultColumnNames . Label ,
125+ IDataView validationData = null ,
126+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
127+ IEstimator < ITransformer > preFeaturizers = null ,
128+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
108129 CancellationToken cancellationToken = default ,
109130 IProgress < MulticlassClassificationIterationResult > iterationCallback = null )
110131 {
132+ var settings = new AutoFitSettings ( ) ;
133+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
134+
111135 return AutoFit ( context , trainData , label , validationData , settings ,
112- purposeOverrides , cancellationToken , iterationCallback , null ) ;
136+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
113137 }
114138
115139 internal static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
116140 IDataView trainData ,
117- string label ,
118- IDataView validationData ,
141+ string label = DefaultColumnNames . Label ,
142+ IDataView validationData = null ,
119143 AutoFitSettings settings = null ,
120- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
144+ IEstimator < ITransformer > preFeaturizers = null ,
145+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
121146 CancellationToken cancellationToken = default ,
122147 IProgress < MulticlassClassificationIterationResult > iterationCallback = null , IDebugLogger debugLogger = null )
123148 {
124- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
149+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
150+
151+ if ( validationData == null )
152+ {
153+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
154+ }
125155
126156 // run autofit & get all pipelines run in that process
127157 var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
128- settings , TaskKind . MulticlassClassification , OptimizingMetric . Accuracy ,
129- purposeOverrides , debugLogger ) ;
158+ settings , preFeaturizers , TaskKind . MulticlassClassification , OptimizingMetric . Accuracy ,
159+ columnPurposes , debugLogger ) ;
130160
131161 var results = new MulticlassClassificationIterationResult [ allPipelines . Length ] ;
132162 for ( var i = 0 ; i < results . Length ; i ++ )
@@ -142,39 +172,39 @@ internal static MulticlassClassificationResult AutoFit(this MulticlassClassifica
142172
143173 public class BinaryClassificationResult
144174 {
145- public readonly BinaryClassificationItertionResult BestPipeline ;
175+ public readonly BinaryClassificationItertionResult BestIteration ;
146176 public readonly BinaryClassificationItertionResult [ ] IterationResults ;
147177
148178 public BinaryClassificationResult ( BinaryClassificationItertionResult bestPipeline ,
149179 BinaryClassificationItertionResult [ ] iterationResults )
150180 {
151- BestPipeline = bestPipeline ;
181+ BestIteration = bestPipeline ;
152182 IterationResults = iterationResults ;
153183 }
154184 }
155185
156186 public class MulticlassClassificationResult
157187 {
158- public readonly MulticlassClassificationIterationResult BestPipeline ;
188+ public readonly MulticlassClassificationIterationResult BestIteration ;
159189 public readonly MulticlassClassificationIterationResult [ ] IterationResults ;
160190
161191 public MulticlassClassificationResult ( MulticlassClassificationIterationResult bestPipeline ,
162192 MulticlassClassificationIterationResult [ ] iterationResults )
163193 {
164- BestPipeline = bestPipeline ;
194+ BestIteration = bestPipeline ;
165195 IterationResults = iterationResults ;
166196 }
167197 }
168198
169199 public class RegressionResult
170200 {
171- public readonly RegressionIterationResult BestPipeline ;
201+ public readonly RegressionIterationResult BestIteration ;
172202 public readonly RegressionIterationResult [ ] IterationResults ;
173203
174204 public RegressionResult ( RegressionIterationResult bestPipeline ,
175205 RegressionIterationResult [ ] iterationResults )
176206 {
177- BestPipeline = bestPipeline ;
207+ BestIteration = bestPipeline ;
178208 IterationResults = iterationResults ;
179209 }
180210 }
0 commit comments