import tensorflow as tf
import numpy as np

import tensorboardX
import buffer_queue
import collections
import py_process
import wrappers
import config
import model
import time
import gym

flags = tf.app.flags
FLAGS = tf.app.flags.FLAGS

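# Cluster layout and rollout/training hyper-parameters, exposed as command-line flags.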
flags.DEFINE_integer('num_actors', 4, 'Number of actors.')
flags.DEFINE_integer('task', -1, 'Task id. Use -1 for local training.')
flags.DEFINE_integer('batch_size', 32, 'Number of trajectories per learner training batch.')
flags.DEFINE_integer('queue_size', 128, 'Capacity of the trajectory FIFO queue.')
flags.DEFINE_integer('trajectory', 20, 'Trajectory (unroll) length.')
flags.DEFINE_integer('learning_frame', int(1e9), 'Number of frames over which the learning rate is annealed.')

flags.DEFINE_float('start_learning_rate', 0.0006, 'Initial learning rate.')
flags.DEFINE_float('end_learning_rate', 0, 'Final learning rate after annealing.')
flags.DEFINE_float('discount_factor', 0.99, 'Discount factor.')
flags.DEFINE_float('entropy_coef', 0.05, 'Entropy regularization coefficient.')
flags.DEFINE_float('baseline_loss_coef', 0.5, 'Baseline (value) loss coefficient.')
flags.DEFINE_float('gradient_clip_norm', 40.0, 'Gradient clipping norm.')

flags.DEFINE_enum('job_name', 'learner', ['learner', 'actor'], 'Job name. Ignored when task is set to -1.')
flags.DEFINE_enum('reward_clipping', 'abs_one', ['abs_one', 'soft_asymmetric'], 'Reward clipping.')

def main(_):

    local_job_device = '/job:{}/task:{}'.format(FLAGS.job_name, FLAGS.task)
    shared_job_device = '/job:learner/task:0'
    is_actor_fn = lambda i: FLAGS.job_name == 'actor' and i == FLAGS.task
    is_learner = FLAGS.job_name == 'learner'

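    # Single-machine cluster: the learner serves on localhost:8000, actor i on localhost:8001 + i.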
    cluster = tf.train.ClusterSpec({
        'actor': ['localhost:{}'.format(8001 + i) for i in range(FLAGS.num_actors)],
        'learner': ['localhost:8000']})

    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task)

    filters = [shared_job_device, local_job_device]

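    # Observation shape is 84x84 with 4 channels (presumably stacked frames from the wrapper);
    # Boxing uses the full 18-action Atari action set.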
    input_shape = [84, 84, 4]
    output_size = 18
    env_name = 'BoxingDeterministic-v4'

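    # The trajectory queue and the IMPALA graph are placed on the shared learner device,
    # so every actor reads and writes the same parameters and buffer.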
    with tf.device(shared_job_device):
        queue = buffer_queue.FIFOQueue(
            FLAGS.trajectory, input_shape, output_size,
            FLAGS.queue_size, FLAGS.batch_size, FLAGS.num_actors)
        learner = model.IMPALA(
            trajectory=FLAGS.trajectory,
            input_shape=input_shape,
            num_action=output_size,
            discount_factor=FLAGS.discount_factor,
            start_learning_rate=FLAGS.start_learning_rate,
            end_learning_rate=FLAGS.end_learning_rate,
            learning_frame=FLAGS.learning_frame,
            baseline_loss_coef=FLAGS.baseline_loss_coef,
            entropy_coef=FLAGS.entropy_coef,
            gradient_clip_norm=FLAGS.gradient_clip_norm)

    # Device filters restrict the session to the shared learner device and this process's own devices.
    sess = tf.Session(server.target, config=tf.ConfigProto(device_filters=filters))
    queue.set_session(sess)
    learner.set_session(sess)

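    # Learner process: consume batches of trajectories from the queue and update the model.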
    if is_learner:

        writer = tensorboardX.SummaryWriter('runs/learner')
        train_step = 0

        while True:
            size = queue.get_size()
            if size > 3 * FLAGS.batch_size:
                train_step += 1
                batch = queue.sample_batch()
                s = time.time()
                pi_loss, baseline_loss, entropy, learning_rate = learner.train(
                    state=np.stack(batch.state),
                    reward=np.stack(batch.reward),
                    action=np.stack(batch.action),
                    done=np.stack(batch.done),
                    behavior_policy=np.stack(batch.behavior_policy))
                writer.add_scalar('data/pi_loss', pi_loss, train_step)
                writer.add_scalar('data/baseline_loss', baseline_loss, train_step)
                writer.add_scalar('data/entropy', entropy, train_step)
                writer.add_scalar('data/learning_rate', learning_rate, train_step)
                writer.add_scalar('data/time', time.time() - s, train_step)
    else:

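        # Actor process: roll out the environment with the current policy and push fixed-length unrolls to the queue.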
        trajectory_data = collections.namedtuple(
            'trajectory_data',
            ['state', 'next_state', 'reward', 'done', 'action', 'behavior_policy'])

        env = wrappers.make_uint8_env(env_name)
        if FLAGS.task == 0:
            env = gym.wrappers.Monitor(env, 'save-mov', video_callable=lambda episode_id: episode_id % 10 == 0)
        state = env.reset()

        episode = 0
        score = 0
        episode_step = 0
        total_max_prob = 0

        writer = tensorboardX.SummaryWriter('runs/actor_{}'.format(FLAGS.task))

        while True:

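            # Collect one unroll of FLAGS.trajectory steps; the per-step done flags mark episode boundaries within it.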
            unroll_data = trajectory_data([], [], [], [], [], [])

            for _ in range(FLAGS.trajectory):

                action, behavior_policy, max_prob = learner.get_policy_and_action(state)

                episode_step += 1
                total_max_prob += max_prob

                next_state, reward, done, info = env.step(action)

                score += reward

                unroll_data.state.append(state)
                unroll_data.next_state.append(next_state)
                unroll_data.reward.append(reward)
                unroll_data.done.append(done)
                unroll_data.action.append(action)
                unroll_data.behavior_policy.append(behavior_policy)

                state = next_state

                if done:

                    print(episode, score)
                    writer.add_scalar('data/prob', total_max_prob / episode_step, episode)
                    writer.add_scalar('data/score', score, episode)
                    writer.add_scalar('data/episode_step', episode_step, episode)
                    episode += 1
                    score = 0
                    episode_step = 0
                    total_max_prob = 0
                    state = env.reset()

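            # Send the finished unroll to the learner's queue, tagged with this actor's task id.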
            queue.append_to_queue(
                task=FLAGS.task, unrolled_state=unroll_data.state,
                unrolled_next_state=unroll_data.next_state, unrolled_reward=unroll_data.reward,
                unrolled_done=unroll_data.done, unrolled_action=unroll_data.action,
                unrolled_behavior_policy=unroll_data.behavior_policy)

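# Example launch commands (illustrative; assumes this script is saved as train.py):
#   python train.py --job_name=learner --task=0
#   python train.py --job_name=actor --task=0    # repeat with --task=1 .. num_actors-1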
if __name__ == '__main__':
    tf.app.run()