Skip to content

Commit 4fb3763

Browse files
committed
ch9-changes
1 parent 9f16681 commit 4fb3763

File tree

4 files changed

+395
-0
lines changed

4 files changed

+395
-0
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# IntelliJ IDEA project metadata
.idea
cookbook-app.iml
cookbook-app.iws
cookbook-app.ipr
cookbookapp.iml

# Maven build output
target
*/target/**
# Written by maven-shade-plugin during packaging
dependency-reduced-pom.xml

# Files produced at runtime
# (cliffwalk_pixel.policy is saved by MalmoExample; the other two are
# presumably outputs of other examples - confirm before removing)
model.zip
LocalExecuteExample.csv
cliffwalk_pixel.policy
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
<project
        xmlns="http://maven.apache.org/POM/4.0.0"
        xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
        xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.javadeeplearningcookbook.app</groupId>
    <artifactId>cookbookapp</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Make the build independent of the platform default charset. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Single source of truth for versions that must move together. -->
        <dl4j.version>1.0.0-beta3</dl4j.version>
        <javacv.version>1.4.4</javacv.version>
        <slf4j.version>1.8.0-beta4</slf4j.version>
        <!--
            Entry point written into the jar manifests.
            NOTE(review): MalmoExample is declared without a package, so its
            fully qualified name is just "MalmoExample". The value used before,
            com.javadeeplearningcookbook.examples.CustomerRetentionPredictionExample,
            does not appear in this module - confirm no other entry point is intended.
        -->
        <main.class>MalmoExample</main.class>
    </properties>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <!--<plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>${main.class}</mainClass>
                        </manifest>
                    </archive>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
            </plugin>-->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <!-- Pinned so the build does not depend on the Maven installation's default. -->
                <version>3.1.0</version>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>${main.class}</mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.2.0</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer" />
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>${main.class}</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-malmo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>com.microsoft.msr.malmo</groupId>
            <artifactId>MalmoJavaJar</artifactId>
            <version>0.30.0</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <!-- You need the below dependency to use CodecRecordReader -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-codec</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <!-- <dependency>
            <groupId>org.bytedeco.javacpp-presets</groupId>
            <artifactId>${moduleName}-platform</artifactId>
            <version>${moduleVersion}-1.4.4</version>
        </dependency>-->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
            <version>${javacv.version}</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacpp</artifactId>
            <version>${javacv.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <!-- You need the below dependency to use LocalTransformExecutor -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    </dependencies>

    <!-- Uncomment to use snapshot version (the daily updatePolicy is optional) -->
    <!--<repositories>
        <repository>
            <id>snapshots-repo</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
                <updatePolicy>daily</updatePolicy>
            </snapshots>
        </repository>
    </repositories>-->
</project>
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import com.microsoft.msr.malmo.MissionSpec;
import org.deeplearning4j.malmo.*;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteConv;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.util.DataManager;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
13+
14+
public class MalmoExample {
15+
public static QLearning.QLConfiguration MALMO_QL = new QLearning.QLConfiguration(123, //Random seed
16+
200, //Max step By epoch
17+
100000, //Max step
18+
50000, //Max size of experience replay
19+
32, //size of batches
20+
500, //target update (hard)
21+
10, //num step noop warmup
22+
0.01, //reward scaling
23+
0.99, //gamma
24+
1.0, //td-error clipping
25+
0.1f, //min epsilon
26+
10000, //num step for eps greedy anneal
27+
true //double DQN
28+
);
29+
30+
public static DQNFactoryStdConv.Configuration MALMO_NET = new DQNFactoryStdConv.Configuration(
31+
0.01, //learning rate
32+
0.00, //l2 regularization
33+
null, // updater
34+
null // Listeners
35+
);
36+
37+
/*
38+
* The pixel input is 320x240, but using the history processor we scale that to 160x120
39+
* and then crop out a 160x80 segment to remove pixels that aren't needed
40+
*/
41+
public static HistoryProcessor.Configuration MALMO_HPROC = new HistoryProcessor.Configuration(1, // Number of frames
42+
160, // Scaled width
43+
120, // Scaled height
44+
160, // Cropped width
45+
80, // Cropped height
46+
0, // X offset
47+
30, // Y offset
48+
1 // Number of frames to skip
49+
);
50+
51+
public static void main(String[] args) throws IOException {
52+
try {
53+
malmoCliffWalk();
54+
loadMalmoCliffWalk();
55+
} catch (MalmoConnectionError e) {
56+
System.out.println(
57+
"To run this example, download and start Project Malmo found at https://github.com/Microsoft/malmo.");
58+
}
59+
}
60+
61+
private static MalmoEnv createMDP() {
62+
return createMDP(0);
63+
}
64+
65+
private static MalmoEnv createMDP(final int initialCount) {
66+
MalmoActionSpaceDiscrete actionSpace =
67+
new MalmoActionSpaceDiscrete("movenorth 1", "movesouth 1", "movewest 1", "moveeast 1");
68+
actionSpace.setRandomSeed(123);
69+
MalmoObservationSpace observationSpace = new MalmoObservationSpacePixels(320, 240);
70+
MalmoDescretePositionPolicy obsPolicy = new MalmoDescretePositionPolicy();
71+
72+
MalmoEnv mdp = new MalmoEnv("C:\\Users\\Admin\\Java-Deep-Learning-Cookbook\\09_Using RL4J for Reinforcement learning\\sourceCode\\cookbookapp\\target\\classes\\cliff_walking_rl4j.xml", actionSpace, observationSpace, obsPolicy);
73+
74+
final Random r = new Random(12345);
75+
76+
mdp.setResetHandler(new MalmoResetHandler() {
77+
int count = initialCount;
78+
79+
@Override
80+
public void onReset(MalmoEnv malmoEnv) {
81+
count++;
82+
83+
if (count > 1000) {
84+
MissionSpec mission = MalmoEnv.loadMissionXML("C:\\Users\\Admin\\Java-Deep-Learning-Cookbook\\09_Using RL4J for Reinforcement learning\\sourceCode\\cookbookapp\\target\\classes\\cliff_walking_rl4j.xml");
85+
86+
for (int x = 1; x < 4; ++x)
87+
for (int z = 1; z < 13; ++z)
88+
if (r.nextFloat() < 0.1)
89+
mission.drawBlock(x, 45, z, "lava");
90+
91+
malmoEnv.setMission(mission);
92+
}
93+
}
94+
});
95+
96+
return mdp;
97+
}
98+
99+
public static void malmoCliffWalk() throws MalmoConnectionError, IOException {
100+
//record the training data in rl4j-data in a new folder (save)
101+
DataManager manager = new DataManager(false);
102+
103+
MalmoEnv mdp = createMDP();
104+
105+
//define the training
106+
QLearningDiscreteConv<MalmoBox> dql =
107+
new QLearningDiscreteConv<MalmoBox>(mdp, MALMO_NET, MALMO_HPROC, MALMO_QL, manager);
108+
109+
//train
110+
dql.train();
111+
112+
//get the final policy
113+
DQNPolicy<MalmoBox> pol = dql.getPolicy();
114+
115+
//serialize and save (serialization showcase, but not required)
116+
pol.save("cliffwalk_pixel.policy");
117+
118+
//close the mdp
119+
mdp.close();
120+
}
121+
122+
//showcase serialization by using the trained agent on a new similar mdp
123+
public static void loadMalmoCliffWalk() throws MalmoConnectionError, IOException {
124+
MalmoEnv mdp = createMDP(10000);
125+
126+
//load the previous agent
127+
DQNPolicy<MalmoBox> pol = DQNPolicy.load("cliffwalk_pixel.policy");
128+
129+
//evaluate the agent
130+
double rewards = 0;
131+
for (int i = 0; i < 10; i++) {
132+
double reward = pol.play(mdp, new HistoryProcessor(MALMO_HPROC));
133+
rewards += reward;
134+
Logger.getAnonymousLogger().info("Reward: " + reward);
135+
}
136+
137+
// Clean up
138+
mdp.close();
139+
140+
Logger.getAnonymousLogger().info("average: " + rewards / 10);
141+
}
142+
143+
144+
145+
146+
}

0 commit comments

Comments
 (0)