Skip to content
Merged
237 changes: 236 additions & 1 deletion src/Microsoft.ML.Data/Transforms/ValueMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column
public abstract Delegate GetGetter(DataViewRow input, int index);

public abstract IDataView GetDataView(IHostEnvironment env);
public abstract TKey[] GetKeys<TKey>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be public? Seems they're only used by Onnx conversion, so might prefer to make them private.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The valuemap class itself is only used by the ValueMapping estimator, but since the map used is not populated until runtime, I keep the method abstract. I keep the methods public so they can used by the mapper class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, ok. Well, as long as the GetKeys new method isn't accessible to end users, it's all right.

public abstract TValue[] GetValues<TValue>();
}

/// <summary>
Expand Down Expand Up @@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value)
}

private static TValue GetValue<T>(TValue value) => value;

public override T[] GetKeys<T>()
{
return _mapping.Keys.Cast<T>().ToArray();
}
public override T[] GetValues<T>()
{
return _mapping.Values.Cast<T>().ToArray();
}

}

/// <summary>
Expand Down Expand Up @@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
return new Mapper(this, schema, _valueMap, ColumnPairs);
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewSchema _inputSchema;
private readonly ValueMap _valueMap;
private readonly (string outputColumnName, string inputColumnName)[] _columns;
private readonly ValueMappingTransformer _parent;
public bool CanSaveOnnx(OnnxContext ctx) => true;

internal Mapper(ValueMappingTransformer transform,
DataViewSchema inputSchema,
Expand All @@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]);
}

public void SaveAsOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
Host.CheckValue(ctx, nameof(ctx));

for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
{
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;

if (!_inputSchema.TryGetColumnIndex(inputColumnName, out int colSrc))
throw Host.ExceptSchemaMismatch(nameof(_inputSchema), "input", inputColumnName);
var type = _inputSchema[colSrc].Type;
DataViewType colType;
if (type is VectorDataViewType vectorType)
colType = new VectorDataViewType((PrimitiveDataViewType)_parent.ValueColumnType, vectorType.Dimensions);
else
colType = _parent.ValueColumnType;
string dstVariableName = ctx.AddIntermediateVariable(colType, outputColumnName);
if (!ctx.ContainsColumn(inputColumnName))
continue;

if (!SaveAsOnnxCore(ctx, ctx.GetVariableName(inputColumnName), dstVariableName))
ctx.RemoveColumn(inputColumnName, true);
}
}

private void CastInputTo<T>(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType)
{
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput");
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", itemType.RawType);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
if (itemType == TextDataViewType.Instance)
node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToString(item)));
else if (itemType == NumberDataViewType.Single)
node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToSingle(item)));
else if (itemType == NumberDataViewType.Int64)
node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToInt64(item)));

}

private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
OnnxNode node;
string opType = "LabelEncoder";
var labelEncoderInput = srcVariableName;
var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
var typeValue = _valueMap.ValueColumn.Type;
var typeKey = _valueMap.KeyColumn.Type;
var kind = _valueMap.ValueColumn.Type.GetRawKind();

var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName :
(typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") :
ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int) srcShape[1]), "LabelEncoderOutput");

// The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings.
// As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView
// to strings and values of NumberDataViewType to int64s.
// String -> String mappings can't be supported.
if (typeKey == NumberDataViewType.Int64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm left wondering if it's possible re refactor this big if-else statement into something that's reusable, but as discussed offline it might not be possible as each case handles things in a particular way and you might end up re-writing the if-else blocks anyway, and not saving much lines of code.

{
// To avoid a int64 -> int64 mapping, we cast keys to strings
if (typeValue is NumberDataViewType)
{
CastInputTo<Int64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_int64s", _valueMap.GetKeys<Int64>());
}
}
else if (typeKey == NumberDataViewType.Int32)
{
// To avoid a string -> string mapping, we cast keys to int64s
if (typeValue is TextDataViewType)
CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.Int16)
{
if (typeValue is TextDataViewType)
CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt64)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt32)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.UInt16)
{
if (typeValue is TextDataViewType)
CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64);
else
CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else if (typeKey == NumberDataViewType.Single)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
{
CastInputTo<float>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
}
else
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_floats", _valueMap.GetKeys<float>());
}
}
else if (typeKey == NumberDataViewType.Double)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
else
CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
}
else if (typeKey == TextDataViewType.Instance)
{
if (typeValue == TextDataViewType.Instance)
return false;
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
node.AddAttribute("keys_strings", _valueMap.GetKeys<ReadOnlyMemory<char>>());
}
else if (typeKey == BooleanDataViewType.Instance)
{
if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
{
var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
castNode.AddAttribute("to", t);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var values = Array.ConvertAll(_valueMap.GetKeys<bool>(), item => Convert.ToString(Convert.ToByte(item)));
node.AddAttribute("keys_strings", values);
}
else
CastInputTo<bool>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
}
else
return false;

if (typeValue == NumberDataViewType.Int64)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<long>());
}
else if (typeValue == NumberDataViewType.Int32)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<int>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.Int16)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<short>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<ulong>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<uint>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.UInt16)
{
node.AddAttribute("values_int64s", _valueMap.GetValues<ushort>().Select(item => Convert.ToInt64(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == NumberDataViewType.Single)
{
node.AddAttribute("values_floats", _valueMap.GetValues<float>());
}
else if (typeValue == NumberDataViewType.Double)
{
node.AddAttribute("values_floats", _valueMap.GetValues<double>().Select(item => Convert.ToSingle(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else if (typeValue == TextDataViewType.Instance)
{
node.AddAttribute("values_strings", _valueMap.GetValues<ReadOnlyMemory<char>>());
}
else if (typeValue == BooleanDataViewType.Instance)
{
node.AddAttribute("values_floats", _valueMap.GetValues<bool>().Select(item => Convert.ToSingle(item)));
var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
castNode.AddAttribute("to", typeValue.RawType);
}
else
return false;

//Unknown keys should map to 0
node.AddAttribute("default_int64", 0);
node.AddAttribute("default_string", "");
node.AddAttribute("default_float", 0f);
return true;
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var result = new DataViewSchema.DetachedColumn[_columns.Length];
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,7 @@ private static double Round(double value, int digitsOfPrecision)
public void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6, bool isRightColumnOnnxScalar = false)
{
var leftColumn = left.Schema[leftColumnName];
var rightColumn = right.Schema[rightColumnName];
var leftType = leftColumn.Type.GetItemType();
var rightType = rightColumn.Type.GetItemType();

if (leftType == NumberDataViewType.SByte)
CompareSelectedColumns<sbyte>(leftColumnName, rightColumnName, left, right, isRightColumnOnnxScalar: isRightColumnOnnxScalar);
Expand Down
Loading