|  | 
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
|  | 49 | + | 
|  | 50 | +public class HyperParameterTuningArbiterUiExample { | 
|  | 51 | + public static final int labelIndex = 11; // consider index 0 to 11 for input | 
|  | 52 | + public static final int numClasses = 1; | 
|  | 53 | + public static void main(String[] args) { | 
|  | 54 | + | 
|  | 55 | + ParameterSpace<Double> learningRateParam = new ContinuousParameterSpace(0.0001,0.01); | 
|  | 56 | + ParameterSpace<Integer> layerSizeParam = new IntegerParameterSpace(5,11); | 
|  | 57 | + MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder() | 
|  | 58 | + .updater(new AdamSpace(learningRateParam)) | 
|  | 59 | + // .weightInit(WeightInit.DISTRIBUTION).dist(new LogNormalDistribution()) | 
|  | 60 | + .addLayer(new DenseLayerSpace.Builder() | 
|  | 61 | + .activation(Activation.RELU) | 
|  | 62 | + .nIn(11) | 
|  | 63 | + .nOut(layerSizeParam) | 
|  | 64 | + .build()) | 
|  | 65 | + .addLayer(new DenseLayerSpace.Builder() | 
|  | 66 | + .activation(Activation.RELU) | 
|  | 67 | + .nIn(layerSizeParam) | 
|  | 68 | + .nOut(layerSizeParam) | 
|  | 69 | + .build()) | 
|  | 70 | + .addLayer(new OutputLayerSpace.Builder() | 
|  | 71 | + .activation(Activation.SIGMOID) | 
|  | 72 | + .lossFunction(LossFunctions.LossFunction.XENT) | 
|  | 73 | + .nOut(1) | 
|  | 74 | + .build()) | 
|  | 75 | + .build(); | 
|  | 76 | + | 
|  | 77 | + Map<String,Object> dataParams = new HashMap<>(); | 
|  | 78 | + dataParams.put("batchSize",new Integer(10)); | 
|  | 79 | + | 
|  | 80 | + Map<String,Object> commands = new HashMap<>(); | 
|  | 81 | + commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, HyperParameterTuningArbiterUiExample.ExampleDataSource.class.getCanonicalName()); | 
|  | 82 | + | 
|  | 83 | + CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace,dataParams); | 
|  | 84 | + | 
|  | 85 | + Properties dataSourceProperties = new Properties(); | 
|  | 86 | + dataSourceProperties.setProperty("minibatchSize", "64"); | 
|  | 87 | + | 
|  | 88 | + ResultSaver modelSaver = new FileModelSaver("resources/"); | 
|  | 89 | + ScoreFunction scoreFunction = new EvaluationScoreFunction(org.deeplearning4j.eval.Evaluation.Metric.ACCURACY); | 
|  | 90 | + | 
|  | 91 | + | 
|  | 92 | + TerminationCondition[] conditions = { | 
|  | 93 | + new MaxTimeCondition(120, TimeUnit.MINUTES), | 
|  | 94 | + new MaxCandidatesCondition(30) | 
|  | 95 | + | 
|  | 96 | + }; | 
|  | 97 | + | 
|  | 98 | + OptimizationConfiguration optimizationConfiguration = new OptimizationConfiguration.Builder() | 
|  | 99 | + .candidateGenerator(candidateGenerator) | 
|  | 100 | + .dataSource(HyperParameterTuningArbiterUiExample.ExampleDataSource.class,dataSourceProperties) | 
|  | 101 | + .modelSaver(modelSaver) | 
|  | 102 | + .scoreFunction(scoreFunction) | 
|  | 103 | + .terminationConditions(conditions) | 
|  | 104 | + .build(); | 
|  | 105 | + | 
|  | 106 | + IOptimizationRunner runner = new LocalOptimizationRunner(optimizationConfiguration,new MultiLayerNetworkTaskCreator()); | 
|  | 107 | + //Uncomment this if you want to store the model. | 
|  | 108 | + StatsStorage ss = new FileStatsStorage(new File("HyperParamOptimizationStats.dl4j")); | 
|  | 109 | + runner.addListeners(new ArbiterStatusListener(ss)); | 
|  | 110 | + UIServer.getInstance().attach(ss); | 
|  | 111 | + //runner.addListeners(new LoggingStatusListener()); //new ArbiterStatusListener(ss) | 
|  | 112 | + runner.execute(); | 
|  | 113 | + | 
|  | 114 | + //Print the best hyper params | 
|  | 115 | + | 
|  | 116 | + double bestScore = runner.bestScore(); | 
|  | 117 | + int bestCandidateIndex = runner.bestScoreCandidateIndex(); | 
|  | 118 | + int numberOfConfigsEvaluated = runner.numCandidatesCompleted(); | 
|  | 119 | + | 
|  | 120 | + String s = "Best score: " + bestScore + "\n" + | 
|  | 121 | + "Index of model with best score: " + bestCandidateIndex + "\n" + | 
|  | 122 | + "Number of configurations evaluated: " + numberOfConfigsEvaluated + "\n"; | 
|  | 123 | + | 
|  | 124 | + System.out.println(s); | 
|  | 125 | + | 
|  | 126 | + } | 
|  | 127 | + | 
|  | 128 | + | 
|  | 129 | + public static class ExampleDataSource implements DataSource { | 
|  | 130 | + | 
|  | 131 | + private int minibatchSize; | 
|  | 132 | + | 
|  | 133 | + public ExampleDataSource(){ | 
|  | 134 | + | 
|  | 135 | + } | 
|  | 136 | + | 
|  | 137 | + @Override | 
|  | 138 | + public void configure(Properties properties) { | 
|  | 139 | + this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "16")); | 
|  | 140 | + } | 
|  | 141 | + | 
|  | 142 | + @Override | 
|  | 143 | + public Object trainData() { | 
|  | 144 | + try{ | 
|  | 145 | + DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses); | 
|  | 146 | + return dataSplit(iterator).getTestIterator(); | 
|  | 147 | + } | 
|  | 148 | + catch(Exception e){ | 
|  | 149 | + throw new RuntimeException(); | 
|  | 150 | + } | 
|  | 151 | + } | 
|  | 152 | + | 
|  | 153 | + @Override | 
|  | 154 | + public Object testData() { | 
|  | 155 | + try{ | 
|  | 156 | + DataSetIterator iterator = new RecordReaderDataSetIterator(dataPreprocess(),minibatchSize,labelIndex,numClasses); | 
|  | 157 | + return dataSplit(iterator).getTestIterator(); | 
|  | 158 | + } | 
|  | 159 | + catch(Exception e){ | 
|  | 160 | + throw new RuntimeException(); | 
|  | 161 | + } | 
|  | 162 | + } | 
|  | 163 | + | 
|  | 164 | + @Override | 
|  | 165 | + public Class<?> getDataType() { | 
|  | 166 | + return DataSetIterator.class; | 
|  | 167 | + } | 
|  | 168 | + | 
|  | 169 | + public DataSetIteratorSplitter dataSplit(DataSetIterator iterator) throws IOException, InterruptedException { | 
|  | 170 | + DataNormalization dataNormalization = new NormalizerStandardize(); | 
|  | 171 | + dataNormalization.fit(iterator); | 
|  | 172 | + iterator.setPreProcessor(dataNormalization); | 
|  | 173 | + DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(iterator,1000,0.8); | 
|  | 174 | + return splitter; | 
|  | 175 | + } | 
|  | 176 | + | 
|  | 177 | + public RecordReader dataPreprocess() throws IOException, InterruptedException { | 
|  | 178 | + //Schema Definitions | 
|  | 179 | + Schema schema = new Schema.Builder() | 
|  | 180 | + .addColumnsString("RowNumber") | 
|  | 181 | + .addColumnInteger("CustomerId") | 
|  | 182 | + .addColumnString("Surname") | 
|  | 183 | + .addColumnInteger("CreditScore") | 
|  | 184 | + .addColumnCategorical("Geography", Arrays.asList("France","Spain","Germany")) | 
|  | 185 | + .addColumnCategorical("Gender",Arrays.asList("Male","Female")) | 
|  | 186 | + .addColumnsInteger("Age","Tenure","Balance","NumOfProducts","HasCrCard","IsActiveMember","EstimatedSalary","Exited").build(); | 
|  | 187 | + | 
|  | 188 | + //Schema Transformation | 
|  | 189 | + TransformProcess transformProcess = new TransformProcess.Builder(schema) | 
|  | 190 | + .removeColumns("RowNumber","Surname","CustomerId") | 
|  | 191 | + .categoricalToInteger("Gender") | 
|  | 192 | + .categoricalToOneHot("Geography") | 
|  | 193 | + .removeColumns("Geography[France]") | 
|  | 194 | + .build(); | 
|  | 195 | + | 
|  | 196 | + //CSVReader - Reading from file and applying transformation | 
|  | 197 | + RecordReader reader = new CSVRecordReader(1,','); | 
|  | 198 | + reader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile())); | 
|  | 199 | + RecordReader transformProcessRecordReader = new TransformProcessRecordReader(reader,transformProcess); | 
|  | 200 | + return transformProcessRecordReader; | 
|  | 201 | + } | 
|  | 202 | + } | 
|  | 203 | +} | 