import com.microsoft.msr.malmo.MissionSpec;
import org.deeplearning4j.malmo.*;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteConv;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.util.DataManager;

import java.io.IOException;
import java.util.Random;
import java.util.logging.Logger;
public class MalmoExample {

    public static QLearning.QLConfiguration MALMO_QL = new QLearning.QLConfiguration(123, // random seed
                    200, // max steps per epoch
                    100000, // max total steps
                    50000, // max size of experience replay
                    32, // batch size
                    500, // target network update interval (hard update)
                    10, // number of no-op warm-up steps
                    0.01, // reward scaling
                    0.99, // gamma (discount factor)
                    1.0, // td-error clipping
                    0.1f, // minimum epsilon
                    10000, // number of steps for epsilon-greedy annealing
                    true // double DQN
    );
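    // Note: the reward scaling factor of 0.01 above shrinks the raw Malmo rewards seen by the
    // learner, which keeps the Q-targets in a small numeric range.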

    public static DQNFactoryStdConv.Configuration MALMO_NET = new DQNFactoryStdConv.Configuration(
                    0.01, // learning rate
                    0.00, // l2 regularization
                    null, // updater
                    null // listeners
    );
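    // Passing null for the updater and the listeners means neither is configured explicitly here;
    // the network factory presumably falls back to its built-in defaults (assumed sufficient for
    // this example).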

    /*
     * The pixel input is 320x240, but using the history processor we scale that to 160x120
     * and then crop out a 160x80 segment to remove pixels that aren't needed.
     */
    public static HistoryProcessor.Configuration MALMO_HPROC = new HistoryProcessor.Configuration(1, // number of frames
                    160, // scaled width
                    120, // scaled height
                    160, // cropped width
                    80, // cropped height
                    0, // x offset
                    30, // y offset
                    1 // number of frames to skip
    );
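    // With a history length of 1 and a frame skip of 1, each observation the network sees is just
    // the current scaled-and-cropped frame; a larger frame count would add temporal context.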

    public static void main(String[] args) throws IOException {
        try {
            malmoCliffWalk();
            loadMalmoCliffWalk();
        } catch (MalmoConnectionError e) {
            System.out.println(
                            "To run this example, download and start Project Malmo found at https://github.com/Microsoft/malmo.");
        }
    }

    private static MalmoEnv createMDP() {
        return createMDP(0);
    }

    private static MalmoEnv createMDP(final int initialCount) {
        MalmoActionSpaceDiscrete actionSpace =
                        new MalmoActionSpaceDiscrete("movenorth 1", "movesouth 1", "movewest 1", "moveeast 1");
        actionSpace.setRandomSeed(123);
        MalmoObservationSpace observationSpace = new MalmoObservationSpacePixels(320, 240);
        MalmoDescretePositionPolicy obsPolicy = new MalmoDescretePositionPolicy();

        MalmoEnv mdp = new MalmoEnv("C:\\Users\\Admin\\Java-Deep-Learning-Cookbook\\09_Using RL4J for Reinforcement learning\\sourceCode\\cookbookapp\\target\\classes\\cliff_walking_rl4j.xml", actionSpace, observationSpace, obsPolicy);

        final Random r = new Random(12345);

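        // Reset handler: once the environment has been reset more than 1000 times, the mission is
        // reloaded and lava blocks are scattered at random, so the map varies from episode to episode.
        // loadMalmoCliffWalk() starts the counter at 10000, so evaluation always uses a randomized map.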
        mdp.setResetHandler(new MalmoResetHandler() {
            int count = initialCount;

            @Override
            public void onReset(MalmoEnv malmoEnv) {
                count++;

                if (count > 1000) {
                    MissionSpec mission = MalmoEnv.loadMissionXML("C:\\Users\\Admin\\Java-Deep-Learning-Cookbook\\09_Using RL4J for Reinforcement learning\\sourceCode\\cookbookapp\\target\\classes\\cliff_walking_rl4j.xml");

                    for (int x = 1; x < 4; ++x)
                        for (int z = 1; z < 13; ++z)
                            if (r.nextFloat() < 0.1)
                                mission.drawBlock(x, 45, z, "lava");

                    malmoEnv.setMission(mission);
                }
            }
        });

        return mdp;
    }

    public static void malmoCliffWalk() throws MalmoConnectionError, IOException {
        // DataManager(false): don't persist training data; pass true to record it in a new folder
        // under rl4j-data
        DataManager manager = new DataManager(false);

        MalmoEnv mdp = createMDP();

        // define the training
        QLearningDiscreteConv<MalmoBox> dql =
                        new QLearningDiscreteConv<MalmoBox>(mdp, MALMO_NET, MALMO_HPROC, MALMO_QL, manager);

        // train
        dql.train();

        // get the final policy
        DQNPolicy<MalmoBox> pol = dql.getPolicy();

        // serialize and save (serialization showcase, but not required)
        pol.save("cliffwalk_pixel.policy");

        // close the mdp
        mdp.close();
    }

    // showcase serialization by using the trained agent on a new, similar MDP
    public static void loadMalmoCliffWalk() throws MalmoConnectionError, IOException {
        MalmoEnv mdp = createMDP(10000);

        // load the previously trained agent
        DQNPolicy<MalmoBox> pol = DQNPolicy.load("cliffwalk_pixel.policy");

        // evaluate the agent over 10 episodes
        double rewards = 0;
        for (int i = 0; i < 10; i++) {
            double reward = pol.play(mdp, new HistoryProcessor(MALMO_HPROC));
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        // clean up
        mdp.close();

        Logger.getAnonymousLogger().info("average: " + rewards / 10);
    }
}