Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 21f2494

Browse files
konradczechowskilukaszkaiser
authored andcommitted
Params for mbrl with dqn. (#1592)
1 parent df4a50b commit 21f2494

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

tensor2tensor/models/research/rl.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,34 @@ def dqn_original_params():
408408
return hparams
409409

410410

411+
@registry.register_hparams
412+
def dqn_guess1_params():
413+
"""Guess 1 for DQN params."""
414+
hparams = dqn_atari_base()
415+
hparams.set_hparam("num_frames", int(1e6))
416+
hparams.set_hparam("agent_update_period", 1)
417+
hparams.set_hparam("agent_target_update_period", 400)
418+
# Small replay buffer size was set for mistake, but it seems to work
419+
hparams.set_hparam("replay_buffer_replay_capacity", 10000)
420+
return hparams
421+
422+
423+
@registry.register_hparams
424+
def dqn_2m_replay_buffer_params():
425+
"""Guess 1 for DQN params, 2 milions transitions in replay buffer"""
426+
hparams = dqn_guess1_params()
427+
hparams.set_hparam("replay_buffer_replay_capacity", int(2e6) + int(1e5))
428+
return hparams
429+
430+
431+
@registry.register_hparams
432+
def dqn_10m_replay_buffer_params():
433+
"""Guess 1 for DQN params, 10 milions transitions in replay buffer"""
434+
hparams = dqn_guess1_params()
435+
hparams.set_hparam("replay_buffer_replay_capacity", int(10e6))
436+
return hparams
437+
438+
411439
def rlmf_tiny_overrides():
412440
"""Parameters to override for tiny setting excluding agent-related hparams."""
413441
return dict(

tensor2tensor/rl/trainer_model_based_params.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,33 @@ def rlmb_dqn_base():
210210
return hparams
211211

212212

213+
@registry.register_hparams
214+
def rlmb_dqn_guess1():
215+
"""rlmb_dqn guess1 params"""
216+
hparams = rlmb_dqn_base()
217+
hparams.set_hparam("base_algo_params", "dqn_guess1_params")
218+
# At the moment no other option for evaluation, so we want long rollouts to
219+
# not bias scores.
220+
hparams.set_hparam("eval_rl_env_max_episode_steps", 5000)
221+
return hparams
222+
223+
224+
@registry.register_hparams
225+
def rlmb_dqn_guess1_2m_replay_buffer():
226+
"""rlmb_dqn guess1 params"""
227+
hparams = rlmb_dqn_guess1()
228+
hparams.set_hparam("base_algo_params", "dqn_2m_replay_buffer_params")
229+
return hparams
230+
231+
232+
@registry.register_hparams
233+
def rlmb_dqn_guess1_10m_replay_buffer():
234+
"""rlmb_dqn guess1 params"""
235+
hparams = rlmb_dqn_guess1()
236+
hparams.set_hparam("base_algo_params", "dqn_10m_replay_buffer_params")
237+
return hparams
238+
239+
213240
@registry.register_hparams
214241
def rlmb_basetest():
215242
"""Base setting but quicker with only 2 epochs."""
@@ -617,6 +644,7 @@ def rlmb_dqn_tiny():
617644
hparams = rlmb_dqn_base()
618645
hparams = hparams.override_from_dict(_rlmb_tiny_overrides())
619646
update_hparams(hparams, dict(
647+
base_algo_params="dqn_guess1_params",
620648
simulated_rollout_length=2,
621649
dqn_time_limit=2,
622650
dqn_num_frames=128,

0 commit comments

Comments
 (0)