@@ -32,7 +32,10 @@ namespace Microsoft.ML.AutoMLFeaturizers
3232 public static class DateTimeTransformerExtensionClass
3333 {
/// <summary>
/// Catalog extension that builds a <see cref="DateTimeTransformerEstimator"/> from the catalog's environment.
/// </summary>
/// <param name="catalog">Transforms catalog whose environment is passed to the estimator.</param>
/// <param name="inputColumnName">Forwarded to the estimator as the source column name.</param>
/// <param name="columnPrefix">Forwarded to the estimator as the column prefix (presumably prepended to the produced column names; confirm).</param>
/// <param name="columnsToDrop">Produced columns to drop after the DateTime expansion; may be omitted because it is a params array.</param>
// NOTE(review): this span is a diff hunk. The `-` line is the old body (factory call) and the `+` line is its replacement (direct constructor call).
3434 public static DateTimeTransformerEstimator DateTimeTransformer ( this TransformsCatalog catalog , string inputColumnName , string columnPrefix , params DateTimeTransformerEstimator . ColumnsProduced [ ] columnsToDrop )
35- => DateTimeTransformerEstimator . Create ( CatalogUtils . GetEnvironment ( catalog ) , inputColumnName , columnPrefix , columnsToDrop ) ;
35+ => new DateTimeTransformerEstimator ( CatalogUtils . GetEnvironment ( catalog ) , inputColumnName , columnPrefix , columnsToDrop ) ;
36+
/// <summary>
/// Catalog extension that builds a <see cref="DateTimeTransformerEstimator"/> and additionally lets the caller pick a holiday country.
/// </summary>
/// <param name="catalog">Transforms catalog whose environment is passed to the estimator.</param>
/// <param name="inputColumnName">Forwarded to the estimator as the source column name.</param>
/// <param name="columnPrefix">Forwarded to the estimator as the column prefix.</param>
/// <param name="columnsToDrop">Produced columns to drop; null is accepted and the estimator constructor substitutes an empty array.</param>
/// <param name="country">Country to get holidays for; defaults to <c>Countries.None</c>.</param>
// NOTE(review): a call that passes only inputColumnName and columnPrefix is applicable to both overloads; it should bind to this one
// (normal form is preferred over an expanded params form) - verify that this is the intended resolution.
37+ public static DateTimeTransformerEstimator DateTimeTransformer ( this TransformsCatalog catalog , string inputColumnName , string columnPrefix , DateTimeTransformerEstimator . ColumnsProduced [ ] columnsToDrop = null , DateTimeTransformerEstimator . Countries country = DateTimeTransformerEstimator . Countries . None )
38+ => new DateTimeTransformerEstimator ( CatalogUtils . GetEnvironment ( catalog ) , inputColumnName , columnPrefix , columnsToDrop , country ) ;
3639
3740 #region ColumnsProduced static extensions
3841
@@ -77,21 +80,19 @@ internal sealed class Options: TransformInputBase
7780
7881 [ Argument ( ArgumentType . MultipleUnique , HelpText = "Columns to drop after the DateTime Expansion" , Name = "ColumnsToDrop" , ShortName = "drop" , SortOrder = 3 ) ]
7982 public ColumnsProduced [ ] ColumnsToDrop ;
83+
84+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Country to get holidays for. Defaults to none if not passed" , Name = "Country" , ShortName = "ctry" , SortOrder = 4 ) ]
85+ public Countries Country = Countries . None ;
8086 }
8187
8288 #endregion
8389
// Factory wrapper that only forwards its arguments to the public constructor.
// NOTE(review): every line of this method is marked `-`, i.e. the diff removes it; the visible call sites now use `new DateTimeTransformerEstimator ( ... )` directly.
84- internal static DateTimeTransformerEstimator Create ( IHostEnvironment env , string inputColumnName , string columnPrefix , ColumnsProduced [ ] columnsToDrop )
85- {
86- return new DateTimeTransformerEstimator ( env , inputColumnName , columnPrefix , columnsToDrop ) ;
87- }
88-
8990 // Used only to confirm that the "Featurizers" DLL exists. If it does, the call will just return false since no parameters are being passed.
9091 // Once we have a binary dependency on the DLL we can remove this code.
// The import binds to the DLL's "GetErrorInfoString" entry point; the managed name CheckIfDllExists reflects how it is used here, not what the export does.
9192 [ DllImport ( "Featurizers" , EntryPoint = "GetErrorInfoString" ) , SuppressUnmanagedCodeSecurity ]
9293 private static extern bool CheckIfDllExists ( IntPtr error , out IntPtr errorHandleString , out IntPtr errorHandleStringSize ) ;
9394
94- public DateTimeTransformerEstimator ( IHostEnvironment env , string inputColumnName , string columnPrefix , ColumnsProduced [ ] columnsToDrop )
95+ public DateTimeTransformerEstimator ( IHostEnvironment env , string inputColumnName , string columnPrefix , ColumnsProduced [ ] columnsToDrop , Countries country = Countries . None )
9596 {
9697 try
9798 {
@@ -110,7 +111,8 @@ public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName
110111 {
111112 Source = inputColumnName ,
112113 Prefix = columnPrefix ,
113- ColumnsToDrop = columnsToDrop
114+ ColumnsToDrop = columnsToDrop == null ? Array . Empty < ColumnsProduced > ( ) : columnsToDrop ,
115+ Country = country
114116 } ;
115117 }
116118
@@ -129,6 +131,7 @@ internal DateTimeTransformerEstimator(IHostEnvironment env, Options options)
129131 _host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( "DateTimeTransformerEstimator" ) ;
130132
131133 _options = options ;
134+ _options . ColumnsToDrop = _options . ColumnsToDrop == null ? Array . Empty < ColumnsProduced > ( ) : _options . ColumnsToDrop ;
132135 }
133136
134137 public DateTimeTransformer Fit ( IDataView input )
@@ -151,14 +154,23 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
151154 return new SchemaShape ( columns . Values ) ;
152155 }
153156
154- #region Column Enums
157+ #region Enums
// Identifies each column the DateTime expansion can produce; used for the ColumnsToDrop option.
// The diff changes the first value from 0 (the `-` line) to 1 (the `+` line); later members auto-increment from there.
// Because values are 1-based, code that needs a zero-based column position must subtract 1 (MakeGetter does this).
155158 public enum ColumnsProduced : byte
156159 {
157- Year = 0 , Month , Day , Hour , Minute , Second , AmPm , Hour12 , DayOfWeek , DayOfQuarter , DayOfYear ,
160+ Year = 1 , Month , Day , Hour , Minute , Second , AmPm , Hour12 , DayOfWeek , DayOfQuarter , DayOfYear ,
158161 WeekOfMonth , QuarterOfYear , HalfOfYear , WeekIso , YearIso , MonthLabel , AmPmLabel , DayOfWeekLabel ,
159162 HolidayName , IsPaidTimeOff
160163 } ;
161164
// Country to get holidays for (see the "Country" option). None is the default used when no country is passed.
// Values are byte-sized and start at 1 (None = 1); the named countries auto-increment from 2 in declaration order.
// NOTE(review): the numeric values presumably have to match what the native Featurizers library expects - confirm before reordering or inserting members.
165+ public enum Countries : byte
166+ {
167+ None = 1 ,
168+ Argentina , Australia , Austria , Belarus , Belgium , Brazil , Canada , Colombia , Croatia , Czech , Denmark ,
169+ England , Finland , France , Germany , Hungary , India , Ireland , IsleofMan , Italy , Japan , Mexico , Netherlands ,
170+ NewZealand , NorthernIreland , Norway , Poland , Portugal , Scotland , Slovenia , SouthAfrica , Spain , Sweden , Switzerland ,
171+ Ukraine , UnitedKingdom , UnitedStates , Wales
172+ }
173+
162174 #endregion
163175 }
164176
@@ -473,14 +485,17 @@ internal unsafe TimePoint(byte* rawData)
473485 private static unsafe string GetStringFromPointer( ref byte * rawData , int intPtrSize )
474486 {
475487 byte [ ] buffer ;
476- byte * temp = rawData + intPtrSize ;
477- long tempSize = * ( long * ) ( temp ) ;
478- int itempSize = * ( int * ) ( temp ) ;
479488 if ( intPtrSize = = 4 ) // 32 bit machine
480489 buffer = new byte [ * ( uint * ) ( rawData + intPtrSize ) ] ;
481490 else // 64 bit machine
482491 buffer = new byte [ * ( ulong * ) ( rawData + intPtrSize ) ] ;
483492
493+ if ( buffer . Length == 0 )
494+ {
495+ rawData += intPtrSize * 2 ;
496+ return string . Empty ;
497+ }
498+
484499 Marshal . Copy ( new IntPtr ( * ( int * * ) rawData ) , buffer, 0 , buffer. Length ) ;
485500 rawData += intPtrSize * 2 ;
486501
@@ -772,7 +787,8 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
772787
773788 var outputColumn = ( int ) _parent . _activeColumnMapping [ iinfo ] ;
774789
775- return Utils. MarshalInvoke ( MakeGetter < int > , ( ( DateTimeTransformerEstimator . ColumnsProduced ) outputColumn ) . GetRawColumnType ( ) , input , outputColumn ) ;
790+ // Have to subtract 1 from the output column since the enum starts at 1 and not 0.
791+ return Utils. MarshalInvoke ( MakeGetter < int > , ( ( DateTimeTransformerEstimator . ColumnsProduced ) outputColumn ) . GetRawColumnType ( ) , input , outputColumn - 1 ) ;
776792 }
777793
778794 private protected override Func< int , bool > GetDependenciesCore ( Func < int , bool > activeOutput )
@@ -802,7 +818,7 @@ internal static class DateTimeTransformerEntrypoint
802818 public static CommonOutputs . TransformOutput DateTimeSplit ( IHostEnvironment env , DateTimeTransformerEstimator . Options input )
803819 {
804820 var h = EntryPointUtils . CheckArgsAndCreateHost ( env , DateTimeTransformer . ShortName , input ) ;
805- var xf = DateTimeTransformerEstimator . Create ( h , input . Source , input . Prefix , input . ColumnsToDrop ) . Fit ( input . Data ) . Transform ( input . Data ) ;
821+ var xf = new DateTimeTransformerEstimator ( h , input ) . Fit ( input . Data ) . Transform ( input . Data ) ;
806822 return new CommonOutputs . TransformOutput ( )
807823 {
808824 Model = new TransformModelImpl ( h , xf , input . Data ) ,
0 commit comments