Skip to content

Commit f609f5a

Browse files
authored
Towards 1529: replacing the predicates with an IEnumerable on IRowToRowMapper.GetDependencies (#2504)
* towards 1529: replacing the predicates with an IEnumerable on IRowToRowMapper.GetDependencies * ISchemaBoundRowMapper does not inherit from IRowToRowMapper, but from IRowToRowMapperBase. Renaming ISchemaBoundRowMapper.GetDependencies to GetDependenciesForNewColumns
1 parent ba59364 commit f609f5a

File tree

24 files changed

+274
-193
lines changed

24 files changed

+274
-193
lines changed

src/Microsoft.ML.Core/Data/IRowToRowMapper.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using Microsoft.Data.DataView;
78

89
namespace Microsoft.ML.Data
@@ -28,11 +29,9 @@ public interface IRowToRowMapper
2829
DataViewSchema OutputSchema { get; }
2930

3031
/// <summary>
31-
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
32-
/// needed. The domain of the function is defined over the indices of the columns of <see cref="DataViewSchema.Count"/>
33-
/// for <see cref="InputSchema"/>.
32+
/// Given a set of columns, return the input columns that are needed to generate those output columns.
3433
/// </summary>
35-
Func<int, bool> GetDependencies(Func<int, bool> predicate);
34+
IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns);
3635

3736
/// <summary>
3837
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.

src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using Microsoft.Data.DataView;
78

@@ -55,15 +56,36 @@ internal interface ISchemaBoundMapper
5556
}
5657

5758
/// <summary>
58-
/// This interface combines <see cref="ISchemaBoundMapper"/> with <see cref="IRowToRowMapper"/>.
59+
/// This interface extends <see cref="ISchemaBoundMapper"/>.
5960
/// </summary>
6061
[BestFriend]
61-
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
62+
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
6263
{
6364
/// <summary>
64-
/// There are two schemas from <see cref="ISchemaBoundMapper"/> and <see cref="IRowToRowMapper"/>.
65-
/// Since the two parent schema's are identical in all derived classes, we merge them into <see cref="OutputSchema"/>.
65+
/// Input schema accepted.
6666
/// </summary>
67-
new DataViewSchema OutputSchema { get; }
67+
DataViewSchema InputSchema { get; }
68+
69+
/// <summary>
70+
/// Given a set of columns, from the newly generated ones, return the input columns that are needed to generate those output columns.
71+
/// </summary>
72+
IEnumerable<DataViewSchema.Column> GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns);
73+
74+
/// <summary>
75+
/// Get an <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
76+
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
77+
/// columns of the returned row will throw. Null predicates are disallowed.
78+
///
79+
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
80+
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
81+
/// the returned value must have the same schema as <see cref="ISchemaBoundMapper.OutputSchema"/>.
82+
///
83+
/// This method creates a live connection between the input <see cref="DataViewRow"/> and the output <see
84+
/// cref="DataViewRow"/>. In particular, when the getters of the output <see cref="DataViewRow"/> are invoked, they invoke the
85+
/// getters of the input row and base the output values on the current values of the input <see cref="DataViewRow"/>.
86+
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
87+
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
88+
/// </summary>
89+
DataViewRow GetRow(DataViewRow input, Func<int, bool> active);
6890
}
6991
}

src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
68
using Microsoft.Data.DataView;
79
using Microsoft.ML.Internal.Utilities;
810

@@ -36,12 +38,15 @@ public CompositeRowToRowMapper(DataViewSchema inputSchema, IRowToRowMapper[] map
3638
OutputSchema = Utils.Size(mappers) > 0 ? mappers[mappers.Length - 1].OutputSchema : inputSchema;
3739
}
3840

39-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
41+
/// <summary>
42+
/// Given a set of columns, return the input columns that are needed to generate those output columns.
43+
/// </summary>
44+
IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> columnsNeeded)
4045
{
41-
Func<int, bool> toReturn = predicate;
4246
for (int i = InnerMappers.Length - 1; i >= 0; --i)
43-
toReturn = InnerMappers[i].GetDependencies(toReturn);
44-
return toReturn;
47+
columnsNeeded = InnerMappers[i].GetDependencies(columnsNeeded);
48+
49+
return columnsNeeded;
4550
}
4651

4752
public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)
@@ -71,7 +76,11 @@ public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)
7176
var deps = new Func<int, bool>[InnerMappers.Length];
7277
deps[deps.Length - 1] = active;
7378
for (int i = deps.Length - 1; i >= 1; --i)
74-
deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]);
79+
{
80+
var outputColumns = InnerMappers[i].OutputSchema.Where(c => deps[i](c.Index));
81+
var cols = InnerMappers[i].GetDependencies(outputColumns).ToArray();
82+
deps[i - 1] = c => cols.Length > 0 ? cols.Any(col => col.Index == c) : false;
83+
}
7584

7685
DataViewRow result = input;
7786
for (int i = 0; i < InnerMappers.Length; ++i)

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper
101101
_mapperFactory = mapperFactory;
102102
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
103103
}
104-
105104
public static DataViewSchema GetOutputSchema(DataViewSchema inputSchema, IRowMapper mapper)
106105
{
107106
Contracts.CheckValue(inputSchema, nameof(inputSchema));
@@ -143,10 +142,9 @@ private protected override void SaveModel(ModelSaveContext ctx)
143142

144143
/// <summary>
145144
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
146-
/// a predicate for the needed active input columns, and a predicate for the needed active
147-
/// output columns.
145+
/// and the needed active input columns, given a predicate for the needed active output columns.
148146
/// </summary>
149-
private bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicateInput)
147+
private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<DataViewSchema.Column> inputColumns)
150148
{
151149
int n = _bindings.Schema.Count;
152150
var active = Utils.BuildArray(n, predicate);
@@ -162,8 +160,7 @@ private bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicat
162160
var predicateIn = _mapper.GetDependencies(predicateOut);
163161

164162
// Combine the two sets of input columns.
165-
predicateInput =
166-
col => 0 <= col && col < activeInput.Length && (activeInput[col] || predicateIn(col));
163+
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index));
167164

168165
return active;
169166
}
@@ -192,10 +189,7 @@ private Func<int, bool> GetActiveOutputColumns(bool[] active)
192189
protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
193190
{
194191
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
195-
196-
Func<int, bool> predicateInput;
197-
var active = GetActive(predicate, out predicateInput);
198-
var inputCols = Source.Schema.Where(x => predicateInput(x.Index));
192+
var active = GetActive(predicate, out IEnumerable<DataViewSchema.Column> inputCols);
199193

200194
return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
201195
}
@@ -205,11 +199,8 @@ public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.C
205199
Host.CheckValueOrNull(rand);
206200

207201
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
202+
var active = GetActive(predicate, out IEnumerable<DataViewSchema.Column> inputCols);
208203

209-
Func<int, bool> predicateInput;
210-
var active = GetActive(predicate, out predicateInput);
211-
212-
var inputCols = Source.Schema.Where(x => predicateInput(x.Index));
213204
var inputs = Source.GetRowCursorSet(inputCols, n, rand);
214205
Host.AssertNonEmpty(inputs);
215206

@@ -243,11 +234,14 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
243234
}
244235
}
245236

246-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
237+
/// <summary>
238+
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
239+
/// </summary>
240+
IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
247241
{
248-
Func<int, bool> predicateInput;
249-
GetActive(predicate, out predicateInput);
250-
return predicateInput;
242+
var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
243+
GetActive(predicate, out var inputColumns);
244+
return inputColumns;
251245
}
252246

253247
public DataViewSchema InputSchema => Source.Schema;

src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using Microsoft.Data.DataView;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.Data.IO;
@@ -222,20 +223,23 @@ public static bool IsCompositeRowToRowMapper(IDataView chain)
222223
return true;
223224
}
224225

225-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
226+
/// <summary>
227+
/// Given a set of columns, return the input columns that are needed to generate those output columns.
228+
/// </summary>
229+
IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
226230
{
227231
_ectx.Assert(IsCompositeRowToRowMapper(_chain));
228232

229233
var transform = _chain as IDataTransform;
230-
var pred = predicate;
234+
var cols = dependingColumns;
231235
while (transform != null)
232236
{
233237
var mapper = transform as IRowToRowMapper;
234238
_ectx.AssertValue(mapper);
235-
pred = mapper.GetDependencies(pred);
239+
cols = mapper.GetDependencies(cols);
236240
transform = transform.Source as IDataTransform;
237241
}
238-
return pred;
242+
return cols;
239243
}
240244

241245
public DataViewSchema InputSchema => _rootSchema;
@@ -258,7 +262,8 @@ public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)
258262
_ectx.AssertValue(mapper);
259263
mappers.Add(mapper);
260264
actives.Add(activeCur);
261-
activeCur = mapper.GetDependencies(activeCur);
265+
var activeCurCol = mapper.GetDependencies(mapper.OutputSchema.Where(col => activeCur(col.Index)));
266+
activeCur = RowCursorUtils.FromColumnsToPredicate(activeCurCol, mapper.InputSchema);
262267
transform = transform.Source as IDataTransform;
263268
}
264269
mappers.Reverse();

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -615,14 +615,15 @@ public Bound(IHostEnvironment env, SchemaBindableCalibratedModelParameters<TSubM
615615
OutputSchema = ScoreSchemaFactory.CreateBinaryClassificationSchema();
616616
}
617617

618-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
618+
/// <summary>
619+
/// Given a set of columns, return the input columns that are needed to generate those output columns.
620+
/// </summary>
621+
IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
619622
{
620-
for (int i = 0; i < OutputSchema.Count; i++)
621-
{
622-
if (predicate(i))
623-
return _predictor.GetDependencies(col => true);
624-
}
625-
return col => false;
623+
if (dependingColumns.Count() > 0)
624+
return _predictor.GetDependenciesForNewColumns(OutputSchema);
625+
626+
return Enumerable.Empty<DataViewSchema.Column>();
626627
}
627628

628629
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()

src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Linq;
78
using System.Reflection;
89
using System.Text;
910
using Microsoft.Data.DataView;
@@ -351,14 +352,12 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s
351352
/// <summary>
352353
/// Returns the input columns needed for the requested output columns.
353354
/// </summary>
354-
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
355+
IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
355356
{
356-
for (int i = 0; i < OutputSchema.Count; i++)
357-
{
358-
if (predicate(i))
359-
return col => col == FeatureColumn.Index;
360-
}
361-
return col => false;
357+
if (dependingColumns.Count() == 0)
358+
return Enumerable.Empty<DataViewSchema.Column>();
359+
360+
return Enumerable.Repeat(FeatureColumn, 1);
362361
}
363362

364363
public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)

src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,11 @@ private DataViewSchema DecorateOutputSchema(DataViewSchema partialSchema, int sc
327327
return builder.ToSchema();
328328
}
329329

330-
public Func<int, bool> GetDependencies(Func<int, bool> predicate) => _mapper.GetDependencies(predicate);
330+
/// <summary>
331+
/// Given a set of columns, return the input columns that are needed to generate those output columns.
332+
/// </summary>
333+
IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
334+
=> _mapper.GetDependenciesForNewColumns(dependingColumns);
331335

332336
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles() => _mapper.GetInputColumnRoles();
333337

0 commit comments

Comments
 (0)