Skip to content

Commit e3bd183

Browse files
committed
update ch 12 source + change to beta3 version + add arbiter UI sample
1 parent 2cd5ac6 commit e3bd183

File tree

3 files changed

+213
-8
lines changed

3 files changed

+213
-8
lines changed

12_Benchmarking and Neural Network Optimization/sourceCode/cookbookapp/pom.xml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,33 +78,33 @@
7878
<dependency>
7979
<groupId>org.deeplearning4j</groupId>
8080
<artifactId>deeplearning4j-core</artifactId>
81-
<version>1.0.0-beta4</version>
81+
<version>1.0.0-beta3</version>
8282
</dependency>
8383
<dependency>
8484
<groupId>org.nd4j</groupId>
8585
<artifactId>nd4j-native-platform</artifactId>
86-
<version>1.0.0-beta4</version>
86+
<version>1.0.0-beta3</version>
8787
</dependency>
8888
<dependency>
8989
<groupId>org.datavec</groupId>
9090
<artifactId>datavec-api</artifactId>
91-
<version>1.0.0-beta4</version>
91+
<version>1.0.0-beta3</version>
9292
</dependency>
9393
<!-- You need the below dependency to use CodecRecordReader-->
9494
<dependency>
9595
<groupId>org.datavec</groupId>
9696
<artifactId>datavec-data-codec</artifactId>
97-
<version>1.0.0-beta4</version>
97+
<version>1.0.0-beta3</version>
9898
</dependency>
9999
<dependency>
100100
<groupId>org.deeplearning4j</groupId>
101101
<artifactId>arbiter-deeplearning4j</artifactId>
102-
<version>1.0.0-beta4</version>
102+
<version>1.0.0-beta3</version>
103103
</dependency>
104104
<dependency>
105105
<groupId>org.deeplearning4j</groupId>
106106
<artifactId>arbiter-ui_2.11</artifactId>
107-
<version>1.0.0-beta4</version>
107+
<version>1.0.0-beta3</version>
108108
</dependency>
109109
<!-- <dependency>
110110
<groupId>org.bytedeco.javacpp-presets</groupId>
@@ -135,7 +135,7 @@
135135
<dependency>
136136
<groupId>org.datavec</groupId>
137137
<artifactId>datavec-local</artifactId>
138-
<version>1.0.0-beta4</version>
138+
<version>1.0.0-beta3</version>
139139
</dependency>
140140
</dependencies>
141141
<!-- Uncomment to use snapshot version -->

12_Benchmarking and Neural Network Optimization/sourceCode/cookbookapp/src/main/java/HyperParameterTuning.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ public static void main(String[] args) {
101101

102102
IOptimizationRunner runner = new LocalOptimizationRunner(optimizationConfiguration,new MultiLayerNetworkTaskCreator());
103103
//Uncomment this if you want to store the model.
104-
// StatsStorage ss = new FileStatsStorage(new File("HyperParamOptimizationStats.dl4j"));
104+
//StatsStorage ss = new FileStatsStorage(new File("HyperParamOptimizationStats.dl4j"));
105+
//runner.addListeners(new ArbiterStatusListener(ss));
106+
//UIServer.getInstance().attach(ss);
105107
runner.addListeners(new LoggingStatusListener()); //new ArbiterStatusListener(ss)
106108
runner.execute();
107109

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import org.datavec.api.records.reader.RecordReader;
2+
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
3+
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader;
4+
import org.datavec.api.split.FileSplit;
5+
import org.datavec.api.transform.TransformProcess;
6+
import org.datavec.api.transform.schema.Schema;
7+
import org.deeplearning4j.api.storage.StatsStorage;
8+
import org.deeplearning4j.arbiter.MultiLayerSpace;
9+
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
10+
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
11+
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
12+
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
13+
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
14+
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
15+
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
16+
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
17+
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
18+
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
19+
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
20+
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
21+
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
22+
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
23+
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
24+
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
25+
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
26+
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
27+
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
28+
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
29+
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
30+
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
31+
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
32+
import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;
33+
import org.deeplearning4j.ui.api.UIServer;
34+
import org.deeplearning4j.ui.storage.FileStatsStorage;
35+
import org.nd4j.linalg.activations.Activation;
36+
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
37+
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
38+
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
39+
import org.nd4j.linalg.io.ClassPathResource;
40+
import org.nd4j.linalg.lossfunctions.LossFunctions;
41+
42+
import java.io.File;
43+
import java.io.IOException;
44+
import java.util.Arrays;
45+
import java.util.HashMap;
46+
import java.util.Map;
47+
import java.util.Properties;
48+
import java.util.concurrent.TimeUnit;
49+
50+
public class HyperParameterTuningArbiterUiExample {
51+
public static final int labelIndex = 11; // consider index 0 to 11 for input
52+
public static final int numClasses = 1;
53+
public static void main(String[] args) {
54+
55+
ParameterSpace<Double> learningRateParam = new ContinuousParameterSpace(0.0001,0.01);
56+
ParameterSpace<Integer> layerSizeParam = new IntegerParameterSpace(5,11);
57+
MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder()
58+
.updater(new AdamSpace(learningRateParam))
59+
// .weightInit(WeightInit.DISTRIBUTION).dist(new LogNormalDistribution())
60+
.addLayer(new DenseLayerSpace.Builder()
61+
.activation(Activation.RELU)
62+
.nIn(11)
63+
.nOut(layerSizeParam)
64+
.build())
65+
.addLayer(new DenseLayerSpace.Builder()
66+
.activation(Activation.RELU)
67+
.nIn(layerSizeParam)
68+
.nOut(layerSizeParam)
69+
.build())
70+
.addLayer(new OutputLayerSpace.Builder()
71+
.activation(Activation.SIGMOID)
72+
.lossFunction(LossFunctions.LossFunction.XENT)
73+
.nOut(1)
74+
.build())
75+
.build();
76+
77+
Map<String,Object> dataParams = new HashMap<>();
78+
dataParams.put("batchSize",new Integer(10));
79+
80+
Map<String,Object> commands = new HashMap<>();
81+
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, HyperParameterTuningArbiterUiExample.ExampleDataSource.class.getCanonicalName());
82+
83+
CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace,dataParams);
84+
85+
Properties dataSourceProperties = new Properties();
86+
dataSourceProperties.setProperty("minibatchSize", "64");
87+
88+
ResultSaver modelSaver = new FileModelSaver("resources/");
89+
ScoreFunction scoreFunction = new EvaluationScoreFunction(org.deeplearning4j.eval.Evaluation.Metric.ACCURACY);
90+
91+
92+
TerminationCondition[] conditions = {
93+
new MaxTimeCondition(120, TimeUnit.MINUTES),
94+
new MaxCandidatesCondition(30)
95+
96+
};
97+
98+
OptimizationConfiguration optimizationConfiguration = new OptimizationConfiguration.Builder()
99+
.candidateGenerator(candidateGenerator)
100+
.dataSource(HyperParameterTuningArbiterUiExample.ExampleDataSource.class,dataSourceProperties)
101+
.modelSaver(modelSaver)
102+
.scoreFunction(scoreFunction)
103+
.terminationConditions(conditions)
104+
.build();
105+
106+
IOptimizationRunner runner = new LocalOptimizationRunner(optimizationConfiguration,new MultiLayerNetworkTaskCreator());
107+
//Uncomment this if you want to store the model.
108+
StatsStorage ss = new FileStatsStorage(new File("HyperParamOptimizationStats.dl4j"));
109+
runner.addListeners(new ArbiterStatusListener(ss));
110+
UIServer.getInstance().attach(ss);
111+
//runner.addListeners(new LoggingStatusListener()); //new ArbiterStatusListener(ss)
112+
runner.execute();
113+
114+
//Print the best hyper params
115+
116+
double bestScore = runner.bestScore();
117+
int bestCandidateIndex = runner.bestScoreCandidateIndex();
118+
int numberOfConfigsEvaluated = runner.numCandidatesCompleted();
119+
120+
String s = "Best score: " + bestScore + "\n" +
121+
"Index of model with best score: " + bestCandidateIndex + "\n" +
122+
"Number of configurations evaluated: " + numberOfConfigsEvaluated + "\n";
123+
124+
System.out.println(s);
125+
126+
}
127+
128+
129+
public static class ExampleDataSource implements DataSource {
130+
131+
private int minibatchSize;
132+
133+
public ExampleDataSource(){
134+
135+
}
136+
137+
@Override
138+
public void configure(Properties properties) {
139+
this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "16"));
140+
}
141+
142+
@Override
143+
public Object trainData() {
144+
try{
145+
DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
146+
return dataSplit(iterator).getTestIterator();
147+
}
148+
catch(Exception e){
149+
throw new RuntimeException();
150+
}
151+
}
152+
153+
@Override
154+
public Object testData() {
155+
try{
156+
DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses);
157+
return dataSplit(iterator).getTestIterator();
158+
}
159+
catch(Exception e){
160+
throw new RuntimeException();
161+
}
162+
}
163+
164+
@Override
165+
public Class<?> getDataType() {
166+
return DataSetIterator.class;
167+
}
168+
169+
public DataSetIteratorSplitter dataSplit(DataSetIterator iterator) throws IOException, InterruptedException {
170+
DataNormalization dataNormalization = new NormalizerStandardize();
171+
dataNormalization.fit(iterator);
172+
iterator.setPreProcessor(dataNormalization);
173+
DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(iterator,1000,0.8);
174+
return splitter;
175+
}
176+
177+
public RecordReader dataPreprocess() throws IOException, InterruptedException {
178+
//Schema Definitions
179+
Schema schema = new Schema.Builder()
180+
.addColumnsString("RowNumber")
181+
.addColumnInteger("CustomerId")
182+
.addColumnString("Surname")
183+
.addColumnInteger("CreditScore")
184+
.addColumnCategorical("Geography", Arrays.asList("France","Spain","Germany"))
185+
.addColumnCategorical("Gender",Arrays.asList("Male","Female"))
186+
.addColumnsInteger("Age","Tenure","Balance","NumOfProducts","HasCrCard","IsActiveMember","EstimatedSalary","Exited").build();
187+
188+
//Schema Transformation
189+
TransformProcess transformProcess = new TransformProcess.Builder(schema)
190+
.removeColumns("RowNumber","Surname","CustomerId")
191+
.categoricalToInteger("Gender")
192+
.categoricalToOneHot("Geography")
193+
.removeColumns("Geography[France]")
194+
.build();
195+
196+
//CSVReader - Reading from file and applying transformation
197+
RecordReader reader = new CSVRecordReader(1,',');
198+
reader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile()));
199+
RecordReader transformProcessRecordReader = new TransformProcessRecordReader(reader,transformProcess);
200+
return transformProcessRecordReader;
201+
}
202+
}
203+
}

0 commit comments

Comments
 (0)