Skip to content

Commit b2f0745

Browse files
sierralee51 authored and codemzs committed
Full Model Retrain Sample (#4127)
* Full Model Retrain sample * removed extra lines * fixed some hardcoded code * minor changes * edited some comments * Simplified the way the ml directory path is obtained
1 parent 940598a commit b2f0745

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
using System;
2+
using System.IO;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
6+
7+
namespace Samples.Dynamic
8+
{
9+
public static class MNISTFullModelRetrain
10+
{
11+
// Example full model retrain using the MNIST model in a ML.NET pipeline.

// Working directory captured when the class is first used; the data files
// and the model folder are copied into it.
private static readonly string sourceDir = Directory.GetCurrentDirectory();

// Represents the (relative) path to the machinelearning directory.
// Built with Path.Combine so the directory separators are correct on every
// platform; a literal such as @"..\..\..\..\" is only a valid path on Windows.
private static readonly string mlDir = Path.Combine("..", "..", "..", "..");
19+
20+
/// <summary>
/// Runs the full-model-retrain sample: stages the data and the MNIST model
/// in the working directory, fits a pipeline that retrains the model and
/// then trains a LightGBM multiclass trainer on its output, prints the
/// evaluation metrics, scores one hard-coded sample and finally tidies up
/// the model folder.
/// </summary>
public static void Example()
{
    // Folder name used both as the model path and as the clean-up target.
    const string modelFolder = "mnist_conv_model";

    var mlContext = new MLContext(seed: 1);

    // Copy the training and testing data into the current directory and
    // load each into an IDataView.
    IDataView trainData = DataDownload("Train-Tiny-28x28.txt", mlContext);
    IDataView testData = DataDownload("MNIST.Test.tiny.txt", mlContext);

    // Copy the MNIST model and its variables into the current directory.
    ModelDownload();

    // Stage 1: expose the raw "Placeholder" column as "Features".
    var copyToFeatures = mlContext.Transforms.CopyColumns("Features",
        "Placeholder");

    // Stage 2: full model retrain.
    var retrainModel = mlContext.Model.RetrainDnnModel(
        inputColumnNames: new[] { "Features" },
        outputColumnNames: new[] { "Prediction" },
        labelColumnName: "TfLabel",
        dnnLabel: "Label",
        optimizationOperation: "MomentumOp",
        lossOperation: "Loss",
        modelPath: modelFolder,
        metricOperation: "Accuracy",
        epoch: 10,
        learningRateOperation: "learning_rate",
        learningRate: 0.01f,
        batchSize: 20);

    // Stage 3: the model's "Prediction" output becomes the new "Features".
    var predictionAsFeatures = mlContext.Transforms.Concatenate("Features",
        "Prediction");

    // Stage 4: LightGBM multiclass trainer on top of the retrained model.
    var lightGbmOptions =
        new Microsoft.ML.Trainers.LightGbm.LightGbmMulticlassTrainer.Options()
        {
            LabelColumnName = "Label",
            FeatureColumnName = "Features",
            Seed = 1,
            NumberOfThreads = 1,
            NumberOfIterations = 1
        };
    var lightGbmTrainer = mlContext.MulticlassClassification.Trainers
        .LightGbm(lightGbmOptions);

    // Assemble the stages in order, caching right before the trainer.
    var pipe = copyToFeatures
        .Append(retrainModel)
        .Append(predictionAsFeatures)
        .AppendCacheCheckpoint(mlContext)
        .Append(lightGbmTrainer);

    var trainedModel = pipe.Fit(trainData);
    var predicted = trainedModel.Transform(testData);
    var metrics = mlContext.MulticlassClassification.Evaluate(predicted);

    // Print out metrics
    Console.WriteLine();
    Console.WriteLine(
        $"Micro-accuracy: {metrics.MicroAccuracy}, macro-accuracy = {metrics.MacroAccuracy}");

    // Score a single sample with the fully retrained model.
    MNISTData sample = GetOneMNISTExample();
    var predictionEngine = mlContext.Model
        .CreatePredictionEngine<MNISTData, MNISTPrediction>(trainedModel);
    MNISTPrediction prediction = predictionEngine.Predict(sample);

    // Print the values of the prediction's "Score" column, space separated.
    Console.WriteLine("Predicted Labels: ");
    foreach (float score in prediction.PredictedLabels)
    {
        Console.Write($"{score} ");
    }

    // Clean up folder by handling the extra files made during retrain.
    CleanUp(modelFolder);
}
90+
91+
/// <summary>
/// Copies a data file from the repository's test/data folder into the
/// sample's working directory (unless it is already there) and loads it
/// into an IDataView using a TextLoader.
/// </summary>
/// <param name="fileName">Name of the file under test/data.</param>
/// <param name="mlContext">Context used to create the TextLoader.</param>
/// <returns>The loaded data with the columns "Label" (key, 10 values),
/// "TfLabel" (Int64, same source column 0) and "Placeholder" (Single,
/// source columns 1 through 784).</returns>
private static IDataView DataDownload(string fileName, MLContext mlContext)
{
    string dataPath = Path.Combine(mlDir, "test", "data", fileName);

    // One destination path is used for the existence check, the copy and
    // the load, so all three always refer to the same file even if the
    // process working directory differs from sourceDir.
    string localPath = Path.Combine(sourceDir, fileName);
    if (!File.Exists(localPath))
    {
        File.Copy(dataPath, localPath);
    }

    var columns = new[]
    {
        new TextLoader.Column("Label", DataKind.UInt32,
            new[] { new TextLoader.Range(0) }, new KeyCount(10)),
        new TextLoader.Column("TfLabel", DataKind.Int64, 0),
        new TextLoader.Column("Placeholder", DataKind.Single,
            new[] { new TextLoader.Range(1, 784) })
    };

    return mlContext.Data.CreateTextLoader(
            columns,
            allowSparse: true
        ).Load(localPath);
}
113+
114+
/// <summary>
/// Copies the MNIST model folder (saved_model.pb plus the files in its
/// "variables" sub-folder) from the packages directory into the sample's
/// working directory. Does nothing when a "mnist_conv_model" folder already
/// exists there.
/// </summary>
/// <exception cref="FileNotFoundException">The packaged saved_model.pb is
/// missing.</exception>
/// <exception cref="DirectoryNotFoundException">The packaged "variables"
/// folder is missing.</exception>
private static void ModelDownload()
{
    string targetModelDir = Path.Combine(sourceDir, "mnist_conv_model");
    if (Directory.Exists(targetModelDir))
    {
        return;
    }

    // The original path to the MNIST model
    string packagedModelDir = Path.Combine(new[] { mlDir, "packages",
        "microsoft.ml.tensorflow.testmodels", "0.0.11-test",
        "contentfiles", "any", "any", "mnist_conv_model" });
    string packagedModelFile = Path.Combine(packagedModelDir,
        "saved_model.pb");

    // Inspect the source folder without creating it: DirectoryInfo only
    // describes the path, whereas Directory.CreateDirectory would silently
    // create an empty folder inside the packages tree when it is missing.
    var packagedVariables = new DirectoryInfo(Path.Combine(packagedModelDir,
        "variables"));

    // Validate the source before touching the target, so a missing package
    // never leaves behind an empty target folder that would make the next
    // run skip the copy.
    if (!File.Exists(packagedModelFile))
    {
        throw new FileNotFoundException(
            "The packaged MNIST model file was not found.",
            packagedModelFile);
    }
    if (!packagedVariables.Exists)
    {
        throw new DirectoryNotFoundException(
            "The packaged MNIST model variables folder was not found: " +
            packagedVariables.FullName);
    }

    // Create a new folder in the current directory for the MNIST model and
    // copy the model into it.
    Directory.CreateDirectory(targetModelDir);
    File.Copy(packagedModelFile,
        Path.Combine(targetModelDir, "saved_model.pb"));

    // Create a new folder in the new mnist_conv_model directory to store
    // the model variables, then copy each variable file across.
    DirectoryInfo targetVariables = Directory.CreateDirectory(
        Path.Combine(targetModelDir, "variables"));
    foreach (FileInfo variableFile in packagedVariables.GetFiles())
    {
        File.Copy(variableFile.FullName,
            Path.Combine(targetVariables.FullName, variableFile.Name));
    }
}
157+
/// <summary>
/// Input row type handed to the prediction engine in this sample.
/// </summary>
public class MNISTData
{
    // Label value of the sample.
    public long Label;

    // Fixed-size feature vector of 28 * 28 = 784 values.
    [VectorType(28 * 28)]
    public float[] Placeholder;
}
164+
165+
/// <summary>
/// Output row type produced by the prediction engine in this sample.
/// </summary>
public class MNISTPrediction
{
    // Bound to the pipeline's "Score" output column.
    // NOTE(review): despite the field name these are presumably per-class
    // scores rather than label values - confirm against the trainer output.
    [ColumnName("Score")] public float[] PredictedLabels;
}
170+
171+
/// <summary>
/// Builds the single hard-coded sample that the example scores with the
/// retrained model.
/// </summary>
/// <returns>An MNISTData whose Placeholder holds the literal values below;
/// Label is left at its default.</returns>
private static MNISTData GetOneMNISTExample()
{
    // Hard-coded input values for one sample.
    float[] pixels =
    {
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126,
        136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253, 225, 172,
        253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238,
        253, 253, 253, 253, 253, 253, 253, 253, 251, 93, 82, 82, 56,
        39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253,
        253, 253, 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0, 43,
        154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11,
        190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81,
        240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130,
        183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221,
        253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219,
        253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133,
        11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136,
        253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0
    };

    var sample = new MNISTData();
    sample.Placeholder = pixels;
    return sample;
}
221+
222+
/// <summary>
/// Tidies the model folder after a retrain. Looks for sub-folders of
/// <paramref name="model_location"/> whose name matches "variables-*"; when
/// at least one exists, the current "variables" folder (if any) is deleted
/// recursively and the first match is moved into its place. Any further
/// matches are left untouched, and nothing happens when there is no match.
/// </summary>
/// <param name="model_location">Path of the model folder to tidy.</param>
private static void CleanUp(string model_location)
{
    string[] retrainFolders = Directory.GetDirectories(model_location,
        "variables-*");
    if (retrainFolders == null || retrainFolders.Length == 0)
    {
        return;
    }

    string variablesPath = Path.Combine(model_location, "variables");
    if (Directory.Exists(variablesPath))
    {
        Directory.Delete(variablesPath, true);
    }

    Directory.Move(retrainFolders[0], variablesPath);
}
235+
}
236+
}

0 commit comments

Comments
 (0)