Skip to content

Commit 8c87d5d

Browse files
committed
Transformer/estimator are done. Working on Transform
DataView coded, waiting on C wrappers for testing
1 parent 13cd2ab commit 8c87d5d

File tree

11 files changed

+784
-1328
lines changed

11 files changed

+784
-1328
lines changed

src/Microsoft.ML.AutoMLFeaturizers/CategoryImputer.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -378,15 +378,6 @@ protected override bool ReleaseHandle()
378378

379379
#endregion
380380

381-
#region FitResult
382-
383-
internal enum FitResult : byte
384-
{
385-
Complete = 0, Continue, ResetAndContinue
386-
}
387-
388-
#endregion
389-
390381
#region ColumnInfo
391382

392383
// REVIEW: Since we can't do overloading on the native side due to the C style exports,

src/Microsoft.ML.AutoMLFeaturizers/Common.cs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Security;
35
using System.Text;
6+
using Microsoft.Win32.SafeHandles;
47

58
namespace Microsoft.ML.AutoMLFeaturizers
69
{
10+
#region Native Function Declarations
11+
12+
#endregion
13+
714
internal enum FitResult : byte
815
{
9-
Complete = 0, Continue, ResetAndContinue
16+
Complete = 1, Continue, ResetAndContinue
1017
}
1118

1219
internal enum TypeId : byte
@@ -15,9 +22,65 @@ internal enum TypeId : byte
1522
Vector, Array, Tabular
1623
};
1724

25+
/// <summary>
/// Raw pointer plus pointer-sized length, grouped into a single value.
/// </summary>
/// <remarks>
/// NOTE(review): presumably mirrors a (data pointer, size) pair exchanged with the native
/// Featurizers library, so the field order should not be changed — confirm against the C wrappers.
/// The layout is stated explicitly; sequential is already the C# default for structs.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct ByteArray
{
    /// <summary>Unmanaged pointer to the bytes.</summary>
    public byte* Data;

    /// <summary>Pointer-sized value that accompanies <see cref="Data"/> (presumably the byte count — TODO confirm).</summary>
    public IntPtr DataSize;
}
30+
31+
/// <summary>
/// Safe handle that owns a native error-info handle and frees it through the
/// <c>DestroyErrorInfo</c> export of the <c>Featurizers</c> library.
/// </summary>
/// <remarks>
/// Zero and minus one are treated as invalid handles by the base class, so no native call is
/// made for them on disposal.
/// </remarks>
internal class ErrorInfoSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    // Native export that destroys the error-info object; its result becomes ReleaseHandle's result.
    // NOTE(review): if the export returns a 1-byte C++ bool, [return: MarshalAs(UnmanagedType.I1)]
    // may be required — confirm against the native header.
    [DllImport("Featurizers", EntryPoint = "DestroyErrorInfo", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyErrorInfo(IntPtr error);

    /// <summary>
    /// Takes ownership of <paramref name="handle"/>.
    /// </summary>
    /// <param name="handle">Native error-info handle to wrap.</param>
    public ErrorInfoSafeHandle(IntPtr handle)
        : base(ownsHandle: true)
    {
        SetHandle(handle);
    }

    /// <summary>Frees the wrapped native error info.</summary>
    /// <returns>The value returned by the native <c>DestroyErrorInfo</c> call.</returns>
    protected override bool ReleaseHandle() => DestroyErrorInfo(handle);
}
46+
47+
/// <summary>
/// Safe handle that owns a native error-message buffer and frees it through the
/// <c>DestroyErrorInfoString</c> export of the <c>Featurizers</c> library.
/// </summary>
/// <remarks>
/// The native free call takes the buffer pointer together with the length it was handed out
/// with, so the length is captured at construction and kept immutable.
/// </remarks>
internal class ErrorInfoStringSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    // Native export that frees the message buffer; it needs both the pointer and its size.
    [DllImport("Featurizers", EntryPoint = "DestroyErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
    private static extern bool DestroyErrorInfoString(IntPtr errorString, IntPtr errorStringSize);

    // Size paired with the buffer. Readonly so the value given to the native free call is always
    // the one supplied when the handle was created.
    private readonly IntPtr _length;

    /// <summary>
    /// Takes ownership of the native string buffer.
    /// </summary>
    /// <param name="handle">Pointer to the native string buffer.</param>
    /// <param name="length">Size value that was returned together with <paramref name="handle"/>.</param>
    public ErrorInfoStringSafeHandle(IntPtr handle, IntPtr length) : base(true)
    {
        SetHandle(handle);
        _length = length;
    }

    /// <summary>Frees the wrapped native string buffer.</summary>
    /// <returns>The value returned by the native <c>DestroyErrorInfoString</c> call.</returns>
    protected override bool ReleaseHandle()
    {
        return DestroyErrorInfoString(handle, _length);
    }
}
64+
1865
internal static class CommonExtensions
1966
{
67+
// Native export that yields the message for an error-info handle. The buffer pointer and its size
// are returned through the two out parameters; the bool result is presumably false on failure.
// NOTE(review): if the export returns a 1-byte C++ bool, [return: MarshalAs(UnmanagedType.I1)]
// may be required — confirm against the native header.
[DllImport("Featurizers", EntryPoint = "GetErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
private static extern bool GetErrorInfoString(IntPtr error, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);

/// <summary>
/// Reads the message attached to a native error-info handle, then releases both the native
/// message buffer and the error-info handle itself.
/// </summary>
/// <param name="errorHandle">Native error-info handle. This method takes ownership and frees it.</param>
/// <returns>
/// The message decoded as UTF-8; an empty string when the reported size is not positive; or a
/// fixed fallback message when the native call reports failure or returns no buffer.
/// </returns>
internal static string GetErrorDetailsAndFreeNativeMemory(IntPtr errorHandle)
{
    using (var error = new ErrorInfoSafeHandle(errorHandle))
    {
        bool succeeded = GetErrorInfoString(errorHandle, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);

        // Wrap whatever came back first so a non-null buffer is always released, even on the
        // early-return paths below.
        using (var errorString = new ErrorInfoStringSafeHandle(errorHandleString, errorHandleStringSize))
        {
            // Marshal.Copy throws on a null source pointer; report the lookup failure instead of
            // letting that exception mask the original native error.
            if (!succeeded || errorString.IsInvalid)
            {
                return "Unable to retrieve error details from the native Featurizers library.";
            }

            int length = errorHandleStringSize.ToInt32();
            if (length <= 0)
            {
                return string.Empty;
            }

            byte[] buffer = new byte[length];
            Marshal.Copy(errorHandleString, buffer, 0, buffer.Length);

            return Encoding.UTF8.GetString(buffer);
        }
    }
}
2184
internal static TypeId GetNativeTypeIdFromType(this Type type)
2285
{
2386
if (type == typeof(byte))

src/Microsoft.ML.AutoMLFeaturizers/DateTimeTransformer.cs

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ namespace Microsoft.ML.AutoMLFeaturizers
3232
public static class DateTimeTransformerExtensionClass
3333
{
3434
public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop)
35-
=> DateTimeTransformerEstimator.Create(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop);
35+
=> new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop);
36+
37+
public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None)
38+
=> new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country);
3639

3740
#region ColumnsProduced static extensions
3841

@@ -77,21 +80,19 @@ internal sealed class Options: TransformInputBase
7780

7881
[Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)]
7982
public ColumnsProduced[] ColumnsToDrop;
83+
84+
[Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)]
85+
public Countries Country = Countries.None;
8086
}
8187

8288
#endregion
8389

84-
internal static DateTimeTransformerEstimator Create(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop)
85-
{
86-
return new DateTimeTransformerEstimator(env, inputColumnName, columnPrefix, columnsToDrop);
87-
}
88-
8990
// Used only to confirm that the native Featurizers DLL can be loaded. If it can, the call simply
// returns false because no valid arguments are being passed.
// Once we have a binary dependency on the DLL we can remove this code.
// The calling convention is stated explicitly so this binding does not fall back to the
// platform default for an export that is declared as Cdecl elsewhere.
[DllImport("Featurizers", EntryPoint = "GetErrorInfoString", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
private static extern bool CheckIfDllExists(IntPtr error, out IntPtr errorHandleString, out IntPtr errorHandleStringSize);
9394

94-
public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop)
95+
public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None)
9596
{
9697
try
9798
{
@@ -110,7 +111,8 @@ public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName
110111
{
111112
Source = inputColumnName,
112113
Prefix = columnPrefix,
113-
ColumnsToDrop = columnsToDrop
114+
ColumnsToDrop = columnsToDrop == null ? Array.Empty<ColumnsProduced>() : columnsToDrop,
115+
Country = country
114116
};
115117
}
116118

@@ -129,6 +131,7 @@ internal DateTimeTransformerEstimator(IHostEnvironment env, Options options)
129131
_host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator");
130132

131133
_options = options;
134+
_options.ColumnsToDrop = _options.ColumnsToDrop == null ? Array.Empty<ColumnsProduced>() : _options.ColumnsToDrop;
132135
}
133136

134137
public DateTimeTransformer Fit(IDataView input)
@@ -151,14 +154,23 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
151154
return new SchemaShape(columns.Values);
152155
}
153156

154-
#region Column Enums
157+
#region Enums
155158
public enum ColumnsProduced : byte
156159
{
157-
Year = 0, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear,
160+
Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear,
158161
WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel,
159162
HolidayName, IsPaidTimeOff
160163
};
161164

165+
/// <summary>
/// Countries that can be chosen when requesting holiday information.
/// </summary>
/// <remarks>
/// Values are byte-sized, start at 1 with <c>None</c>, and increase by one in declaration
/// order, so inserting or reordering members changes the numeric value of those that follow.
/// </remarks>
public enum Countries : byte
{
    None = 1,
    Argentina,
    Australia,
    Austria,
    Belarus,
    Belgium,
    Brazil,
    Canada,
    Colombia,
    Croatia,
    Czech,
    Denmark,
    England,
    Finland,
    France,
    Germany,
    Hungary,
    India,
    Ireland,
    IsleofMan,
    Italy,
    Japan,
    Mexico,
    Netherlands,
    NewZealand,
    NorthernIreland,
    Norway,
    Poland,
    Portugal,
    Scotland,
    Slovenia,
    SouthAfrica,
    Spain,
    Sweden,
    Switzerland,
    Ukraine,
    UnitedKingdom,
    UnitedStates,
    Wales
}
173+
162174
#endregion
163175
}
164176

@@ -473,14 +485,17 @@ internal unsafe TimePoint(byte* rawData)
473485
private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize)
474486
{
475487
byte[] buffer;
476-
byte* temp = rawData + intPtrSize;
477-
long tempSize = *(long*)(temp);
478-
int itempSize = *(int*)(temp);
479488
if (intPtrSize == 4) // 32 bit machine
480489
buffer = new byte[*(uint*)(rawData + intPtrSize)];
481490
else // 64 bit machine
482491
buffer = new byte[*(ulong*)(rawData + intPtrSize)];
483492

493+
if (buffer.Length == 0)
494+
{
495+
rawData += intPtrSize * 2;
496+
return string.Empty;
497+
}
498+
484499
Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length);
485500
rawData += intPtrSize * 2;
486501

@@ -772,7 +787,8 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
772787

773788
var outputColumn = (int)_parent._activeColumnMapping[iinfo];
774789

775-
return Utils.MarshalInvoke(MakeGetter<int>, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn);
790+
// Have to subtract 1 from the output column since the enum starts at 1 and not 0.
791+
return Utils.MarshalInvoke(MakeGetter<int>, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1);
776792
}
777793

778794
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
@@ -802,7 +818,7 @@ internal static class DateTimeTransformerEntrypoint
802818
public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input)
803819
{
804820
var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input);
805-
var xf = DateTimeTransformerEstimator.Create(h, input.Source, input.Prefix, input.ColumnsToDrop).Fit(input.Data).Transform(input.Data);
821+
var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data);
806822
return new CommonOutputs.TransformOutput()
807823
{
808824
Model = new TransformModelImpl(h, xf, input.Data),

0 commit comments

Comments
 (0)