- Notifications
You must be signed in to change notification settings - Fork 1.9k
Onnx Export for ValueMapping estimator #5577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits Select commit Hold shift + click to select a range
a5a7a2a onnx export for valuemapping estimator
0eb8e7a reformatting
3d3f34c resolving comments and adding scalar testing
7d1b86d adding key type support
ba2b727 testing mac
2416863 testing mac
0be5f42 testing mac
af7f609 testing mac
e4ea4fe testing mac
13fd3c6 restoring files
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -12,6 +12,7 @@ | |
| using Microsoft.ML.Data; | ||
| using Microsoft.ML.Data.IO; | ||
| using Microsoft.ML.Internal.Utilities; | ||
| using Microsoft.ML.Model.OnnxConverter; | ||
| using Microsoft.ML.Runtime; | ||
| using Microsoft.ML.Transforms; | ||
| | ||
| | @@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column | |
| /// <summary> | ||
| /// Produces the getter delegate for the column at <paramref name="index"/> of <paramref name="input"/>. | ||
| /// NOTE(review): presumably the delegate yields the mapped value for that column; confirm against the implementations. | ||
| /// </summary> | ||
| public abstract Delegate GetGetter(DataViewRow input, int index); | ||
| | ||
| /// <summary> | ||
| /// Returns an <see cref="IDataView"/> for this map. | ||
| /// NOTE(review): presumably the key/value pairs exposed as a data view; confirm against the implementations. | ||
| /// </summary> | ||
| public abstract IDataView GetDataView(IHostEnvironment env); | ||
| /// <summary> | ||
| /// Returns the keys of the map as an array of <typeparamref name="TKey"/>. | ||
| /// </summary> | ||
| public abstract TKey[] GetKeys<TKey>(); | ||
| /// <summary> | ||
| /// Returns the values of the map as an array of <typeparamref name="TValue"/>. | ||
| /// </summary> | ||
| public abstract TValue[] GetValues<TValue>(); | ||
| } | ||
| | ||
| /// <summary> | ||
| | @@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value) | |
| } | ||
| | ||
| // Identity conversion: hands the stored value back unchanged. | ||
| private static TValue GetValue<T>(TValue value) | ||
| { | ||
| return value; | ||
| } | ||
| | ||
| // Snapshot of the mapping's keys, each element cast to T. | ||
| public override T[] GetKeys<T>() => _mapping.Keys.Cast<T>().ToArray(); | ||
| // Snapshot of the mapping's values, each element cast to T. | ||
| public override T[] GetValues<T>() => _mapping.Values.Cast<T>().ToArray(); | ||
| | ||
| } | ||
| | ||
| /// <summary> | ||
| | @@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) | |
| return new Mapper(this, schema, _valueMap, ColumnPairs); | ||
| } | ||
| | ||
| private sealed class Mapper : OneToOneMapperBase | ||
| private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx | ||
| { | ||
| private readonly DataViewSchema _inputSchema; | ||
| private readonly ValueMap _valueMap; | ||
| private readonly (string outputColumnName, string inputColumnName)[] _columns; | ||
| private readonly ValueMappingTransformer _parent; | ||
| // Always reports that this mapper can be exported; unsupported type combinations are rejected later, during SaveAsOnnx. | ||
| public bool CanSaveOnnx(OnnxContext ctx) | ||
| { | ||
| return true; | ||
| } | ||
| | ||
| internal Mapper(ValueMappingTransformer transform, | ||
| DataViewSchema inputSchema, | ||
| | @@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b | |
| return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]); | ||
| } | ||
| | ||
| /// <summary> | ||
| /// Exports each (input, output) column pair of the parent transformer to the ONNX graph held by <paramref name="ctx"/>. | ||
| /// </summary> | ||
| /// <param name="ctx">ONNX context that receives the nodes; must not be null.</param> | ||
| public void SaveAsOnnx(OnnxContext ctx) | ||
| { | ||
| // Validate ctx before it is dereferenced by the op-set check below. | ||
| Host.CheckValue(ctx, nameof(ctx)); | ||
| const int minimumOpSetVersion = 9; | ||
| ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); | ||
| | ||
| for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo) | ||
| { | ||
| string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName; | ||
| string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName; | ||
| | ||
| if (!_inputSchema.TryGetColumnIndex(inputColumnName, out int colSrc)) | ||
| throw Host.ExceptSchemaMismatch(nameof(_inputSchema), "input", inputColumnName); | ||
| | ||
| // Skip columns the context does not carry before registering anything for them; | ||
| // otherwise an output variable would be declared that no node ever produces. | ||
| if (!ctx.ContainsColumn(inputColumnName)) | ||
| continue; | ||
| | ||
| // The output keeps the input's vector dimensions and takes the map's value type as its item type. | ||
| var type = _inputSchema[colSrc].Type; | ||
| DataViewType colType; | ||
| if (type is VectorDataViewType vectorType) | ||
| colType = new VectorDataViewType((PrimitiveDataViewType)_parent.ValueColumnType, vectorType.Dimensions); | ||
| else | ||
| colType = _parent.ValueColumnType; | ||
| string dstVariableName = ctx.AddIntermediateVariable(colType, outputColumnName); | ||
| | ||
| // NOTE(review): on failure the input column is removed while the output variable stays registered; confirm this is intended rather than removing outputColumnName. | ||
| if (!SaveAsOnnxCore(ctx, ctx.GetVariableName(inputColumnName), dstVariableName)) | ||
| ctx.RemoveColumn(inputColumnName, true); | ||
| } | ||
| } | ||
| | ||
| /// <summary> | ||
| /// Inserts a Cast node that converts the source variable to <paramref name="itemType"/>, then creates the | ||
| /// <paramref name="opType"/> node on the cast result and attaches the map's keys converted to the same type. | ||
| /// </summary> | ||
| /// <typeparam name="T">Type the keys are read as from the value map.</typeparam> | ||
| /// <param name="ctx">ONNX context the nodes and the intermediate variable are added to.</param> | ||
| /// <param name="node">Receives the created <paramref name="opType"/> node so the caller can add the value attributes.</param> | ||
| /// <param name="srcVariableName">ONNX variable holding the input column.</param> | ||
| /// <param name="opType">Operator type of the node created after the cast.</param> | ||
| /// <param name="labelEncoderOutput">ONNX variable the created node writes to.</param> | ||
| /// <param name="itemType">Type the input is cast to; only text, single and int64 attach a keys attribute.</param> | ||
| private void CastInputTo<T>(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType) | ||
| { | ||
| // NOTE(review): RetrieveShapeOrNull may return null, in which case the index below throws; confirm callers always pass a variable with a known shape. | ||
| var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); | ||
| var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput"); | ||
| var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", itemType.RawType); | ||
| node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); | ||
| if (itemType == TextDataViewType.Instance) | ||
| { | ||
| // Format keys with the invariant culture: the exported table is compared against text produced by the | ||
| // ONNX Cast at inference time, so it must not depend on the culture of the machine doing the export. | ||
| node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToString(item, System.Globalization.CultureInfo.InvariantCulture))); | ||
| } | ||
| else if (itemType == NumberDataViewType.Single) | ||
| { | ||
| node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToSingle(item))); | ||
| } | ||
| else if (itemType == NumberDataViewType.Int64) | ||
| { | ||
| node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToInt64(item))); | ||
| } | ||
| } | ||
| | ||
| /// <summary> | ||
| /// Emits the ONNX nodes for one column: an optional Cast on the input, a LabelEncoder node carrying the | ||
| /// keys and values of the value map, and an optional Cast from the encoder output to the destination variable. | ||
| /// </summary> | ||
| /// <param name="ctx">ONNX context the nodes and intermediate variables are added to.</param> | ||
| /// <param name="srcVariableName">ONNX variable holding the input column.</param> | ||
| /// <param name="dstVariableName">ONNX variable that receives the mapped output.</param> | ||
| /// <returns>true when the mapping was exported; false when the key/value type combination is not handled.</returns> | ||
| private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) | ||
| { | ||
| const int minimumOpSetVersion = 9; | ||
| ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); | ||
| OnnxNode node; | ||
| string opType = "LabelEncoder"; | ||
| var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); | ||
| var typeValue = _valueMap.ValueColumn.Type; | ||
| var typeKey = _valueMap.KeyColumn.Type; | ||
| var kind = _valueMap.ValueColumn.Type.GetRawKind(); | ||
| | ||
| // Reject unsupported key and value types before anything is added to ctx, so that a failed export | ||
| // does not leave nodes or intermediate variables behind. The lists mirror the branches below. | ||
| bool keyIsExportable = typeKey == NumberDataViewType.Int64 || typeKey == NumberDataViewType.Int32 || typeKey == NumberDataViewType.Int16 || | ||
| typeKey == NumberDataViewType.UInt64 || typeKey == NumberDataViewType.UInt32 || typeKey == NumberDataViewType.UInt16 || | ||
| typeKey == NumberDataViewType.Single || typeKey == NumberDataViewType.Double || | ||
| typeKey == TextDataViewType.Instance || typeKey == BooleanDataViewType.Instance; | ||
| bool valueIsExportable = typeValue == NumberDataViewType.Int64 || typeValue == NumberDataViewType.Int32 || typeValue == NumberDataViewType.Int16 || | ||
| typeValue == NumberDataViewType.UInt64 || typeValue == NumberDataViewType.UInt32 || typeValue == NumberDataViewType.UInt16 || | ||
| kind == InternalDataKind.U8 || kind == InternalDataKind.U4 || | ||
| typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || | ||
| typeValue == TextDataViewType.Instance || typeValue == BooleanDataViewType.Instance; | ||
| if (!keyIsExportable || !valueIsExportable) | ||
| return false; | ||
| | ||
| // Single, text and int64 values are written by the encoder straight into the destination; every other | ||
| // value type goes through a float or int64 intermediate that is cast to the destination afterwards. | ||
| var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName : | ||
| (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") : | ||
| ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int)srcShape[1]), "LabelEncoderOutput"); | ||
| | ||
| // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings. | ||
| // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView | ||
| // to strings and values of NumberDataViewType to int64s. | ||
| // String -> String mappings can't be supported. | ||
| if (typeKey == NumberDataViewType.Int64) | ||
| Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm left wondering if it's possible re refactor this big if-else statement into something that's reusable, but as discussed offline it might not be possible as each case handles things in a particular way and you might end up re-writing the if-else blocks anyway, and not saving much lines of code. | ||
| { | ||
| // To avoid a int64 -> int64 mapping, we cast keys to strings | ||
| if (typeValue is NumberDataViewType) | ||
| { | ||
| CastInputTo<Int64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else | ||
| { | ||
| node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); | ||
| node.AddAttribute("keys_int64s", _valueMap.GetKeys<Int64>()); | ||
| } | ||
| } | ||
| else if (typeKey == NumberDataViewType.Int32) | ||
| { | ||
| // To avoid a string -> string mapping, we cast keys to int64s | ||
| if (typeValue is TextDataViewType) | ||
| CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); | ||
| else | ||
| CastInputTo<Int32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else if (typeKey == NumberDataViewType.Int16) | ||
| { | ||
| if (typeValue is TextDataViewType) | ||
| CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); | ||
| else | ||
| CastInputTo<Int16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else if (typeKey == NumberDataViewType.UInt64) | ||
| { | ||
| if (typeValue is TextDataViewType) | ||
| CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); | ||
| else | ||
| CastInputTo<UInt64>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else if (typeKey == NumberDataViewType.UInt32) | ||
| { | ||
| if (typeValue is TextDataViewType) | ||
| CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); | ||
| else | ||
| CastInputTo<UInt32>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else if (typeKey == NumberDataViewType.UInt16) | ||
| { | ||
| if (typeValue is TextDataViewType) | ||
| CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); | ||
| else | ||
| CastInputTo<UInt16>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else if (typeKey == NumberDataViewType.Single) | ||
| { | ||
| // Values that are emitted as floats would give a float -> float mapping, so the keys go through strings instead. | ||
| if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) | ||
| { | ||
| CastInputTo<float>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| } | ||
| else | ||
| { | ||
| node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); | ||
| node.AddAttribute("keys_floats", _valueMap.GetKeys<float>()); | ||
| } | ||
| } | ||
| else if (typeKey == NumberDataViewType.Double) | ||
| { | ||
| if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) | ||
| CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); | ||
| else | ||
| CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); | ||
| } | ||
| else if (typeKey == TextDataViewType.Instance) | ||
| { | ||
| if (typeValue == TextDataViewType.Instance) | ||
| return false; | ||
| node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); | ||
| node.AddAttribute("keys_strings", _valueMap.GetKeys<ReadOnlyMemory<char>>()); | ||
| } | ||
| else if (typeKey == BooleanDataViewType.Instance) | ||
| { | ||
| if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) | ||
| { | ||
| var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput"); | ||
| var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); | ||
| var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); | ||
| castNode.AddAttribute("to", t); | ||
| node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); | ||
| // Boolean keys are written as the strings "0" and "1". | ||
| var values = Array.ConvertAll(_valueMap.GetKeys<bool>(), item => Convert.ToString(Convert.ToByte(item))); | ||
| node.AddAttribute("keys_strings", values); | ||
| } | ||
| else | ||
| CastInputTo<bool>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); | ||
| } | ||
| else | ||
| return false; | ||
| | ||
| // Attach the values. Types the encoder cannot emit directly are stored as int64s or floats and cast to the real value type afterwards. | ||
| if (typeValue == NumberDataViewType.Int64) | ||
| { | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<long>()); | ||
| } | ||
| else if (typeValue == NumberDataViewType.Int32) | ||
| { | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<int>().Select(item => Convert.ToInt64(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == NumberDataViewType.Int16) | ||
| { | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<short>().Select(item => Convert.ToInt64(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8) | ||
| { | ||
| // NOTE(review): Convert.ToInt64 throws OverflowException for values above long.MaxValue; confirm such values cannot occur here. | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<ulong>().Select(item => Convert.ToInt64(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4) | ||
| { | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<uint>().Select(item => Convert.ToInt64(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == NumberDataViewType.UInt16) | ||
| { | ||
| node.AddAttribute("values_int64s", _valueMap.GetValues<ushort>().Select(item => Convert.ToInt64(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == NumberDataViewType.Single) | ||
| { | ||
| node.AddAttribute("values_floats", _valueMap.GetValues<float>()); | ||
| } | ||
| else if (typeValue == NumberDataViewType.Double) | ||
| { | ||
| node.AddAttribute("values_floats", _valueMap.GetValues<double>().Select(item => Convert.ToSingle(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else if (typeValue == TextDataViewType.Instance) | ||
| { | ||
| node.AddAttribute("values_strings", _valueMap.GetValues<ReadOnlyMemory<char>>()); | ||
| } | ||
| else if (typeValue == BooleanDataViewType.Instance) | ||
| { | ||
| node.AddAttribute("values_floats", _valueMap.GetValues<bool>().Select(item => Convert.ToSingle(item))); | ||
| var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); | ||
| castNode.AddAttribute("to", typeValue.RawType); | ||
| } | ||
| else | ||
| return false; | ||
| | ||
| //Unknown keys should map to 0 | ||
| node.AddAttribute("default_int64", 0); | ||
| node.AddAttribute("default_string", ""); | ||
| node.AddAttribute("default_float", 0f); | ||
| return true; | ||
| } | ||
| | ||
| protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | ||
| { | ||
| var result = new DataViewSchema.DetachedColumn[_columns.Length]; | ||
| | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do these need to be public? Seems they're only used by Onnx conversion, so might prefer to make them private.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ValueMap class itself is only used by the ValueMapping estimator, but since the map is not populated until runtime, I keep the methods abstract. I keep the methods public so they can be used by the Mapper class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, ok. Well, as long as the GetKeys new method isn't accessible to end users, it's all right.