Skip to content
58 changes: 57 additions & 1 deletion src/Microsoft.ML.Api/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ public VectorTypeAttribute(params int[] dims)
/// column encapsulates.
/// </summary>
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
public sealed class ColumnAttribute : Attribute
public class ColumnAttribute : Attribute
{

public ColumnAttribute(string ordinal, string name = null)
{
Name = name;
Expand All @@ -93,6 +94,61 @@ public ColumnAttribute(string ordinal, string name = null)
public string Ordinal { get; }
}

/// <summary>
/// A <see cref="ColumnAttribute"/> whose column name is fixed to "Label";
/// the caller supplies only the ordinal (indices) of the column.
/// </summary>
public sealed class LabelColumnAttribute : ColumnAttribute
{
    /// <param name="ordinal">Ordinal passed through unchanged to <see cref="ColumnAttribute"/>.</param>
    public LabelColumnAttribute(string ordinal)
        : base(ordinal, name: "Label")
    {
    }
}

/// <summary>
/// A <see cref="ColumnAttribute"/> whose column name is fixed to "Features";
/// the caller supplies only the ordinal (indices) of the column.
/// </summary>
public sealed class FeaturesColumnAttribute : ColumnAttribute
{
    /// <param name="ordinal">Ordinal passed through unchanged to <see cref="ColumnAttribute"/>.</param>
    public FeaturesColumnAttribute(string ordinal)
        : base(ordinal, name: "Features")
    {
    }
}

/// <summary>
/// A <see cref="ColumnAttribute"/> whose column name is fixed to "Weight";
/// the caller supplies only the ordinal (indices) of the column.
/// </summary>
public sealed class WeightColumnAttribute : ColumnAttribute
{
    /// <param name="ordinal">Ordinal passed through unchanged to <see cref="ColumnAttribute"/>.</param>
    public WeightColumnAttribute(string ordinal)
        : base(ordinal, name: "Weight")
    {
    }
}

/// <summary>
/// A <see cref="ColumnAttribute"/> whose column name is fixed to "GroupId";
/// the caller supplies only the ordinal (indices) of the column.
/// </summary>
public sealed class GroupColumnAttribute : ColumnAttribute
{
    /// <param name="ordinal">Ordinal passed through unchanged to <see cref="ColumnAttribute"/>.</param>
    public GroupColumnAttribute(string ordinal)
        : base(ordinal, name: "GroupId")
    {
    }
}

/// <summary>
/// A <see cref="ColumnAttribute"/> whose column name is fixed to "Name";
/// the caller supplies only the ordinal (indices) of the column.
/// </summary>
public sealed class NameColumnAttribute : ColumnAttribute
{
    /// <param name="ordinal">Ordinal passed through unchanged to <see cref="ColumnAttribute"/>.</param>
    public NameColumnAttribute(string ordinal)
        : base(ordinal, name: "Name")
    {
    }
}

/// <summary>
/// Allows a member to specify its column name directly, as opposed to the default
/// behavior of using the member name as the column name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class HousePriceData
[Column(ordinal: "1")]
public string Date;

[Column(ordinal: "2", name: "Label")]
[LabelColumn(ordinal: "2")]
public float Price;

[Column(ordinal: "3")]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Xunit;

namespace Microsoft.ML.Scenarios
{
public partial class ScenariosTests
{
    /// <summary>
    /// Trains a classifier on iris.data through a schema that maps source columns
    /// "0-3" to a single float vector field and column "4" to the label, then checks
    /// three individual predictions and the evaluation metrics on the same data set.
    /// </summary>
    [Fact]
    public void TrainAndPredictIrisModelWithFeatureVectorTest()
    {
        string dataPath = GetDataPath("iris.data");

        var pipeline = new LearningPipeline();

        pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisDataWithFeatureVector>(useHeader: false, separator: ','));

        // "IrisPlantType" is addressed as "Label" because its LabelColumn attribute names the column "Label".
        pipeline.Add(new Dictionarizer("Label"));

        pipeline.Add(new StochasticDualCoordinateAscentClassifier());

        PredictionModel<IrisDataWithFeatureVector, IrisPrediction> model = pipeline.Train<IrisDataWithFeatureVector, IrisPrediction>();

        // First sample: expected to score ~1 for class 0.
        IrisPrediction prediction = model.Predict(new IrisDataWithFeatureVector()
        {
            Feat = new float[] { 5.1f, 3.3f, 1.6f, 0.2f }
        });

        Assert.Equal(1, prediction.PredictedLabels[0], 2);
        Assert.Equal(0, prediction.PredictedLabels[1], 2);
        Assert.Equal(0, prediction.PredictedLabels[2], 2);

        // Second sample: expected to score ~1 for class 2.
        prediction = model.Predict(new IrisDataWithFeatureVector()
        {
            Feat = new float[] { 6.4f, 3.1f, 5.5f, 2.2f }
        });

        Assert.Equal(0, prediction.PredictedLabels[0], 2);
        Assert.Equal(0, prediction.PredictedLabels[1], 2);
        Assert.Equal(1, prediction.PredictedLabels[2], 2);

        // Third sample: score is split between classes 0 and 1, so only one decimal is compared.
        prediction = model.Predict(new IrisDataWithFeatureVector()
        {
            Feat = new float[] { 4.4f, 3.1f, 2.5f, 1.2f }
        });

        Assert.Equal(.2, prediction.PredictedLabels[0], 1);
        Assert.Equal(.8, prediction.PredictedLabels[1], 1);
        Assert.Equal(0, prediction.PredictedLabels[2], 2);

        // Note: Testing against the same data set as a simple way to test evaluation.
        // This isn't appropriate in real-world scenarios.
        string testDataPath = GetDataPath("iris.data");
        var testData = new TextLoader(testDataPath).CreateFrom<IrisDataWithFeatureVector>(useHeader: false, separator: ',');

        var evaluator = new ClassificationEvaluator();
        evaluator.OutputTopKAcc = 3;
        ClassificationMetrics metrics = evaluator.Evaluate(model, testData);

        // Compare the macro accuracy with a precision, like the other computed metrics;
        // exact equality on a computed floating-point value is fragile.
        Assert.Equal(.98, metrics.AccuracyMacro, 2);
        Assert.Equal(.98, metrics.AccuracyMicro, 2);
        Assert.Equal(.06, metrics.LogLoss, 2);
        Assert.InRange(metrics.LogLossReduction, 94, 96);
        Assert.Equal(1, metrics.TopKAccuracy);

        Assert.Equal(3, metrics.PerClassLogLoss.Length);
        Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
        Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
        Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);

        ConfusionMatrix matrix = metrics.ConfusionMatrix;
        Assert.Equal(3, matrix.Order);
        Assert.Equal(3, matrix.ClassNames.Count);
        Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
        Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
        Assert.Equal("Iris-virginica", matrix.ClassNames[2]);

        // Each cell is checked through both the index-based and the name-based indexer.
        Assert.Equal(50, matrix[0, 0]);
        Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
        Assert.Equal(0, matrix[0, 1]);
        Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
        Assert.Equal(0, matrix[0, 2]);
        Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);

        Assert.Equal(0, matrix[1, 0]);
        Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
        Assert.Equal(48, matrix[1, 1]);
        Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
        Assert.Equal(2, matrix[1, 2]);
        Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);

        Assert.Equal(0, matrix[2, 0]);
        Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
        Assert.Equal(1, matrix[2, 1]);
        Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
        Assert.Equal(49, matrix[2, 2]);
        Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
    }

    /// <summary>
    /// Input schema for iris.data in which source columns "0-3" are loaded into one
    /// four-element float vector and column "4" is the label.
    /// </summary>
    public class IrisDataWithFeatureVector
    {
        // Columns 0 through 3 are read into a single vector field of size 4.
        [FeaturesColumn("0-3")]
        [VectorType(4)]
        public float[] Feat;

        // Column 4 holds the plant type.
        [LabelColumn("4")]
        public string IrisPlantType;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public class IrisDataWithStringLabel
[Column("3")]
public float PetalLength;

[Column("4", name: "Label")]
[LabelColumn("4")]
public string IrisPlantType;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void TrainAndPredictSentimentModelTest()

public class SentimentData
{
[Column(ordinal: "0", name: "Label")]
[LabelColumn(ordinal: "0")]
public float Sentiment;
[Column(ordinal: "1")]
public string SentimentText;
Expand Down
91 changes: 91 additions & 0 deletions test/Microsoft.ML.Tests/TextLoaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,79 @@ public void ThrowsExceptionWithPropertyName()
Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message);
}


/// <summary>
/// Loads SparseData.txt through <see cref="SparseInputWithNamedColumns"/> and checks that
/// the resulting schema exposes the five named columns in order, then checks the
/// values of the three rows produced by the cursor.
/// </summary>
[Fact]
public void CanSuccessfullyNamedColumns()
{
    string dataPath = GetDataPath("SparseData.txt");
    var loader = new Data.TextLoader(dataPath)
        .CreateFrom<SparseInputWithNamedColumns>(useHeader: true, allowQuotedStrings: false, supportSparse: true);

    using (var env = new TlcEnvironment())
    {
        Experiment experiment = env.CreateExperiment();
        var step = loader.ApplyStep(null, experiment) as ILearningPipelineDataStep;

        experiment.Compile();
        loader.SetInput(env, experiment);
        experiment.Run();

        var data = experiment.GetOutput(step.Data) as IDataLoader;
        Assert.NotNull(data);

        // Column names come from the *Column attributes on SparseInputWithNamedColumns.
        string[] expectedNames = { "Name", "GroupId", "Weight", "Features", "Label" };
        Assert.Equal(5, data.Schema.ColumnCount);
        for (int columnIndex = 0; columnIndex < expectedNames.Length; columnIndex++)
        {
            Assert.Equal(expectedNames[columnIndex], data.Schema.GetColumnName(columnIndex));
        }

        using (var cursor = data.GetRowCursor(anyColumn => true))
        {
            // One float getter per column, created before the cursor is advanced.
            var getters = new ValueGetter<float>[5];
            for (int g = 0; g < getters.Length; g++)
            {
                getters[g] = cursor.GetGetter<float>(g);
            }

            // Expected values of each row, in the order the cursor yields them.
            float[][] expectedRows =
            {
                new float[] { 1, 2, 3, 4, 5 },
                new float[] { 0, 0, 0, 4, 5 },
                new float[] { 0, 2, 0, 0, 0 },
            };

            foreach (float[] expectedRow in expectedRows)
            {
                Assert.True(cursor.MoveNext());

                for (int g = 0; g < getters.Length; g++)
                {
                    float actual = 0;
                    getters[g](ref actual);
                    Assert.Equal(expectedRow[g], actual);
                }
            }

            // The file must contain exactly three rows.
            Assert.False(cursor.MoveNext());
        }
    }
}

public class QuoteInput
{
[Column("0")]
Expand Down Expand Up @@ -268,5 +341,23 @@ public class ModelWithoutColumnAttribute
{
public string String1;
}

/// <summary>
/// Input schema with five float fields bound to ordinals "0" through "4", each
/// declared with one of the named column attributes instead of a plain Column attribute.
/// </summary>
public class SparseInputWithNamedColumns
{
    // Ordinal "0", bound through the Name column attribute.
    [NameColumn(ordinal: "0")]
    public float C1;

    // Ordinal "1", bound through the Group column attribute.
    [GroupColumn(ordinal: "1")]
    public float C2;

    // Ordinal "2", bound through the Weight column attribute.
    [WeightColumn(ordinal: "2")]
    public float C3;

    // Ordinal "3", bound through the Features column attribute.
    [FeaturesColumn(ordinal: "3")]
    public float C4;

    // Ordinal "4", bound through the Label column attribute.
    [LabelColumn(ordinal: "4")]
    public float C5;
}
}
}