Skip to content

Commit 501ce5d

Browse files
committed
add main.py
1 parent 946656b commit 501ce5d

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

main.py

Lines changed: 122 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,122 @@
1+
# -*- coding: utf-8 -*-
2+
# main.py
3+
# author: yangrui
4+
# description:
5+
# created: 2019-05-10T19:01:15.155Z+08:00
6+
# last-modified: 2020-03-13T10:00:30.182Z+08:00
7+
# email: yangrui19@mails.tsinghua.edu.cn
8+
9+
from global_utils import print_summary
10+
from options import parse_options
11+
from global_utils import set_global_seed, save_performance, plot_data
12+
import time
13+
from agent_env_params import design_agent_and_env
14+
from multiprocessing import Process
15+
import random
16+
17+
from environment import Environment
18+
from agent import Agent
19+
20+
21+
def run_HAC(FLAGS, env, agent, plot_figure=False, num=0):
    """Run the training (or test-only) loop for one agent and record results.

    In training mode every epoch runs ``FLAGS.num_exploration_episodes``
    episodes through ``agent.train`` and then ``FLAGS.num_test_episodes``
    evaluation episodes with ``test=True``. In test mode
    (``FLAGS.test`` truthy) a single epoch of ``FLAGS.num_test_episodes``
    episodes is run and nothing is saved by ``agent.save_model``.

    Args:
        FLAGS: Parsed options. Attributes read here: ``num_epochs``,
            ``save_freq``, ``test``, ``num_exploration_episodes``,
            ``num_test_episodes``, ``curriculum``, ``threadings`` and
            ``save_experience``.
        env: Environment handed to ``agent.train``; its ``set_goal_range``
            is called when curriculum training is active.
        agent: Object providing ``train``, ``save_model`` and
            ``save_experience``.
        plot_figure: When true, both performance lists are passed to
            ``save_plot_figure``.
        num: Forwarded to ``save_performance`` as ``thread_num``.

    Raises:
        ValueError: If curriculum training is requested and the number of
            epochs is not a multiple of ``FLAGS.curriculum``.
    """
    from global_utils import save_plot_figure  # imported here for multiprocessing

    num_epochs = FLAGS.num_epochs
    save_freq = FLAGS.save_freq
    # Print task summary
    print_summary(FLAGS, env)

    if FLAGS.test:
        num_episodes = FLAGS.num_test_episodes
        num_epochs = 1  # only test 1 epoch
    else:
        num_episodes = FLAGS.num_exploration_episodes

    # The curriculum is only applied while training, so it is only validated
    # then; validating in test mode (where num_epochs is forced to 1) would
    # reject every curriculum setting.
    use_curriculum = not FLAGS.test and FLAGS.curriculum >= 2
    if use_curriculum:
        if num_epochs % FLAGS.curriculum != 0:
            raise ValueError('NUM_EPOCH / FLAGS.curriculum should be int')
        epochs_per_stage = num_epochs // FLAGS.curriculum
        # NOTE: `env_params` is not a parameter; it is read from module scope.
        curriculum_list = env_params['curriculum_list']

    performance_list = []
    test_performance_list = []

    for epoch in range(1, num_epochs + 1):
        successful_episodes = 0
        if use_curriculum:
            stage = int((epoch - 1) // epochs_per_stage)
            env.set_goal_range(curriculum_list[stage])

        for episode in range(num_episodes):
            print("\nEpoch %d, Episode %d" % (epoch, episode))
            # Epochs are 1-based, so the first episode number passed to the
            # agent is num_episodes rather than 0.
            success = agent.train(env, epoch * num_episodes + episode, test=FLAGS.test)
            if success:
                print("End Goal Achieved\n")
                successful_episodes += 1

        # Save agent (only in single-process training runs)
        if epoch % save_freq == 0 and not FLAGS.test and FLAGS.threadings == 1:
            agent.save_model(epoch * num_episodes)

        # Guard against a zero episode count instead of dividing by zero.
        success_rate = successful_episodes / num_episodes * 100 if num_episodes else 0.0
        print("\nEpoch %d, Success Rate %.2f%%" % (epoch, success_rate))
        performance_list.append(success_rate)

        if not FLAGS.test:
            if use_curriculum:
                # Evaluate on the last entry of the curriculum list.
                env.set_goal_range(curriculum_list[-1])

            num_test_episodes = FLAGS.num_test_episodes
            print('\ntesting for %d episodes' % (num_test_episodes))
            success_test = 0
            for episode in range(num_test_episodes):
                success_test += int(agent.train(env, episode, test=True))
            test_success_rate = success_test / num_test_episodes * 100 if num_test_episodes else 0.0
            print('testing accuracy: %.2f%%' % (test_success_rate))
            # NOTE(review): this stores the raw success count, whereas
            # performance_list stores a percentage -- confirm which one
            # save_performance / the plot are meant to receive.
            test_performance_list.append(success_test)

    # NOTE(review): confirm the figures are meant to be written once after
    # the last epoch rather than refreshed after every epoch.
    if plot_figure:
        save_plot_figure(performance_list)
        save_plot_figure(test_performance_list, name='test-performance.jpg')

    save_performance(performance_list, test_performance_list, FLAGS=FLAGS, thread_num=num)
    if FLAGS.save_experience:
        agent.save_experience()
80+
81+
82+
83+
def worker(agent_params, env_params, FLAGS, i):
    """Seed, build and run one experiment; used as a ``Process`` target.

    Args:
        agent_params: Agent configuration passed to ``Agent``.
        env_params: Environment configuration passed to ``Environment``.
        FLAGS: Parsed options; ``FLAGS.seed`` is overwritten with the seed
            chosen here.
        i: Index of this worker, forwarded to ``run_HAC`` as ``num``.
    """
    # The random offset alone only spans 101 values, so workers launched in
    # the same second could end up with the same seed. Shifting by a
    # per-worker multiple of 101 keeps their seeds distinct.
    seed = int(time.time()) + random.randint(0, 100) + 101 * i
    set_global_seed(seed)
    FLAGS.seed = seed

    env = Environment(env_params, FLAGS)
    agent = Agent(FLAGS, env, agent_params)
    run_HAC(FLAGS, env, agent, plot_figure=False, num=i)
90+
91+
92+
# Parsed at module level (not inside main) because run_HAC reads
# `env_params` as a module-level name, including inside worker processes.
FLAGS = parse_options()
agent_params, env_params = design_agent_and_env(FLAGS)


def main():
    """Run one experiment in-process, or several in parallel processes.

    With ``FLAGS.threadings == 1`` a single seeded run is executed with
    plotting enabled. Otherwise ``FLAGS.threadings`` processes are started,
    each running ``worker``, and this function waits for all of them.

    Raises:
        ValueError: If ``FLAGS.threadings`` is smaller than 1.
    """
    if FLAGS.threadings < 1:
        raise ValueError("Threadings should be at least 1!")

    if FLAGS.threadings == 1:
        seed = int(time.time()) + random.randint(0, 100)
        set_global_seed(seed)
        FLAGS.seed = seed
        env = Environment(env_params, FLAGS)
        agent = Agent(FLAGS, env, agent_params)
        run_HAC(FLAGS, env, agent, plot_figure=True)
        return

    # parallel run
    processes = []
    for i in range(FLAGS.threadings):
        p = Process(target=worker, args=(agent_params, env_params, FLAGS, i))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


# Without this guard, start methods that re-import the main module in each
# child (e.g. "spawn") would try to launch the workers again from the child.
if __name__ == "__main__":
    main()
113+
114+
115+
116+
117+
118+
119+
120+
121+
122+

0 commit comments

Comments
 (0)