1+ using System ;
2+ using System . IO ;
3+ using System . Linq ;
4+ using Microsoft . ML ;
5+ using Microsoft . ML . Data ;
6+
7+ namespace Samples . Dynamic
8+ {
9+ public static class MNISTFullModelRetrain
10+ {
11+ /// <summary>
12+ /// Example full model retrain using the MNIST model in a ML.NET pipeline.
13+ /// </summary>
14+
15+ private static string sourceDir = Directory . GetCurrentDirectory ( ) ;
16+
17+ // Represents the path to the machinelearning directory
18+ private static string mlDir = @"..\..\..\..\" ;
19+
20+ public static void Example ( )
21+ {
22+ var mlContext = new MLContext ( seed : 1 ) ;
23+
24+ // Download training data into current directory and load into IDataView
25+ var trainData = DataDownload ( "Train-Tiny-28x28.txt" , mlContext ) ;
26+
27+ // Download testing data into current directory and load into IDataView
28+ var testData = DataDownload ( "MNIST.Test.tiny.txt" , mlContext ) ;
29+
30+ // Download the MNIST model and its variables into current directory
31+ ModelDownload ( ) ;
32+
33+ // Full model retrain pipeline
34+ var pipe = mlContext . Transforms . CopyColumns ( "Features" , "Placeholder" )
35+ . Append ( mlContext . Model . RetrainDnnModel (
36+ inputColumnNames : new [ ] { "Features" } ,
37+ outputColumnNames : new [ ] { "Prediction" } ,
38+ labelColumnName : "TfLabel" ,
39+ dnnLabel : "Label" ,
40+ optimizationOperation : "MomentumOp" ,
41+ lossOperation : "Loss" ,
42+ modelPath : "mnist_conv_model" ,
43+ metricOperation : "Accuracy" ,
44+ epoch : 10 ,
45+ learningRateOperation : "learning_rate" ,
46+ learningRate : 0.01f ,
47+ batchSize : 20 ) )
48+ . Append ( mlContext . Transforms . Concatenate ( "Features" ,
49+ "Prediction" ) )
50+ . AppendCacheCheckpoint ( mlContext )
51+ . Append ( mlContext . MulticlassClassification . Trainers . LightGbm (
52+ new Microsoft . ML . Trainers . LightGbm
53+ . LightGbmMulticlassTrainer . Options ( )
54+ {
55+ LabelColumnName = "Label" ,
56+ FeatureColumnName = "Features" ,
57+ Seed = 1 ,
58+ NumberOfThreads = 1 ,
59+ NumberOfIterations = 1
60+ } ) ) ;
61+
62+ var trainedModel = pipe . Fit ( trainData ) ;
63+ var predicted = trainedModel . Transform ( testData ) ;
64+ var metrics = mlContext . MulticlassClassification . Evaluate ( predicted ) ;
65+
66+ // Print out metrics
67+ Console . WriteLine ( ) ;
68+ Console . WriteLine ( $ "Micro-accuracy: { metrics . MicroAccuracy } , " +
69+ $ "macro-accuracy = { metrics . MacroAccuracy } ") ;
70+
71+ // Get one sample for the fully retrained model to predict on
72+ var sample = GetOneMNISTExample ( ) ;
73+
74+ // Create a prediction engine to predict on one sample
75+ var predictionEngine = mlContext . Model . CreatePredictionEngine <
76+ MNISTData , MNISTPrediction > ( trainedModel ) ;
77+
78+ var prediction = predictionEngine . Predict ( sample ) ;
79+
80+ // Print predicted labels
81+ Console . WriteLine ( "Predicted Labels: " ) ;
82+ foreach ( var pLabel in prediction . PredictedLabels )
83+ {
84+ Console . Write ( pLabel + " " ) ;
85+ }
86+
87+ // Clean up folder by deleting extra files made during retrain
88+ CleanUp ( "mnist_conv_model" ) ;
89+ }
90+
91+ // Copies data from another location into current directory
92+ // and loads it into IDataView using a TextLoader
93+ private static IDataView DataDownload ( string fileName , MLContext mlContext )
94+ {
95+ string dataPath = Path . Combine ( mlDir , "test" , "data" , fileName ) ;
96+ if ( ! File . Exists ( fileName ) )
97+ {
98+ System . IO . File . Copy ( dataPath , Path . Combine ( sourceDir , fileName ) ) ;
99+ }
100+
101+ return mlContext . Data . CreateTextLoader (
102+ new [ ]
103+ {
104+ new TextLoader . Column ( "Label" , DataKind . UInt32 ,
105+ new [ ] { new TextLoader . Range ( 0 ) } , new KeyCount ( 10 ) ) ,
106+ new TextLoader . Column ( "TfLabel" , DataKind . Int64 , 0 ) ,
107+ new TextLoader . Column ( "Placeholder" , DataKind . Single ,
108+ new [ ] { new TextLoader . Range ( 1 , 784 ) } )
109+ } ,
110+ allowSparse : true
111+ ) . Load ( fileName ) ;
112+ }
113+
114+ // Copies MNIST model folder from another location into current directory
115+ private static void ModelDownload ( )
116+ {
117+ if ( ! Directory . Exists ( Path . Combine ( sourceDir , "mnist_conv_model" ) ) )
118+ {
119+ // The original path to the MNIST model
120+ var oldModel = Path . Combine ( new [ ] { mlDir , "packages" ,
121+ "microsoft.ml.tensorflow.testmodels" , "0.0.11-test" ,
122+ "contentfiles" , "any" , "any" , "mnist_conv_model" } ) ;
123+
124+ // Create a new folder in the current directory for the MNIST model
125+ string newModel = Directory . CreateDirectory ( Path . Combine ( sourceDir ,
126+ "mnist_conv_model" ) ) . FullName ;
127+
128+ // Copy the model into the new mnist_conv_model folder
129+ System . IO . File . Copy ( Path . Combine ( oldModel , "saved_model.pb" ) ,
130+ Path . Combine ( newModel , "saved_model.pb" ) ) ;
131+
132+ // The original folder that the model variables are in.
133+ // Because the folder already exists, the "CreateDirectory" method
134+ // call creates a DirectoryInfo object for the existing folder
135+ // rather than making a new directory.
136+ var oldVariables = Directory . CreateDirectory ( Path . Combine ( oldModel ,
137+ "variables" ) ) ;
138+
139+ // Create a new folder in the new mnist_conv_model directory to
140+ // store the model variables
141+ var newVariables = Directory . CreateDirectory ( Path . Combine ( newModel ,
142+ "variables" ) ) ;
143+
144+ // Get the files in the original variables folder
145+ var variableNames = oldVariables . GetFiles ( ) ;
146+
147+ foreach ( var vName in variableNames )
148+ {
149+ // Copy each file from the original variables folder into the
150+ // new variables folder
151+ System . IO . File . Copy ( vName . FullName , Path . Combine (
152+ newVariables . FullName , vName . Name ) ) ;
153+ }
154+
155+ }
156+ }
157+ public class MNISTData
158+ {
159+ public long Label ;
160+
161+ [ VectorType ( 784 ) ]
162+ public float [ ] Placeholder ;
163+ }
164+
165+ public class MNISTPrediction
166+ {
167+ [ ColumnName ( "Score" ) ]
168+ public float [ ] PredictedLabels ;
169+ }
170+
171+ // Returns one sample
172+ private static MNISTData GetOneMNISTExample ( )
173+ {
174+ return new MNISTData ( )
175+ {
176+ Placeholder = new float [ ] { 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
177+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
178+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
179+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
180+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
181+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
182+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
183+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 3 , 18 , 18 , 18 , 126 ,
184+ 136 , 175 , 26 , 166 , 255 , 247 , 127 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
185+ 0 , 0 , 30 , 36 , 94 , 154 , 170 , 253 , 253 , 253 , 253 , 253 , 225 , 172 ,
186+ 253 , 242 , 195 , 64 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 49 , 238 ,
187+ 253 , 253 , 253 , 253 , 253 , 253 , 253 , 253 , 251 , 93 , 82 , 82 , 56 ,
188+ 39 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 18 , 219 , 253 , 253 , 253 ,
189+ 253 , 253 , 198 , 182 , 247 , 241 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
190+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 80 , 156 , 107 , 253 , 253 , 205 , 11 , 0 , 43 ,
191+ 154 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
192+ 14 , 1 , 154 , 253 , 90 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
193+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 139 , 253 , 190 , 2 , 0 , 0 , 0 , 0 ,
194+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 11 ,
195+ 190 , 253 , 70 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
196+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 35 , 241 , 225 , 160 , 108 , 1 , 0 , 0 , 0 ,
197+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 81 ,
198+ 240 , 253 , 253 , 119 , 25 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
199+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 45 , 186 , 253 , 253 , 150 , 27 , 0 , 0 ,
200+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
201+ 16 , 93 , 252 , 253 , 187 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
202+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 249 , 253 , 249 , 64 , 0 , 0 , 0 ,
203+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 46 , 130 ,
204+ 183 , 253 , 253 , 207 , 2 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
205+ 0 , 0 , 0 , 0 , 0 , 0 , 39 , 148 , 229 , 253 , 253 , 253 , 250 , 182 , 0 , 0 ,
206+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 24 , 114 , 221 ,
207+ 253 , 253 , 253 , 253 , 201 , 78 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
208+ 0 , 0 , 0 , 0 , 0 , 0 , 23 , 66 , 213 , 253 , 253 , 253 , 253 , 198 , 81 , 2 ,
209+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 18 , 171 , 219 ,
210+ 253 , 253 , 253 , 253 , 195 , 80 , 9 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
211+ 0 , 0 , 0 , 0 , 0 , 0 , 55 , 172 , 226 , 253 , 253 , 253 , 253 , 244 , 133 ,
212+ 11 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 136 ,
213+ 253 , 253 , 253 , 212 , 135 , 132 , 16 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
214+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
215+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
216+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
217+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
218+ 0 , 0 , 0 , 0 , 0 , 0 }
219+ } ;
220+ }
221+
222+ // Deletes extra variable folders produced during retrain
223+ private static void CleanUp ( string model_location )
224+ {
225+ var directories = Directory . GetDirectories ( model_location ,
226+ "variables-*" ) ;
227+ if ( directories != null && directories . Length > 0 )
228+ {
229+ var varDir = Path . Combine ( model_location , "variables" ) ;
230+ if ( Directory . Exists ( varDir ) )
231+ Directory . Delete ( varDir , true ) ;
232+ Directory . Move ( directories [ 0 ] , varDir ) ;
233+ }
234+ }
235+ }
236+ }
0 commit comments