Skip to content

Commit c7367c5

Browse files
committed
Trained the actor-critic model for the targets=4, sensors=1 case; modified files in code/
1 parent 1846804 commit c7367c5

File tree

6 files changed

+28
-28
lines changed

6 files changed

+28
-28
lines changed
0 Bytes
Binary file not shown.
622 Bytes
Binary file not shown.

code/envs/robot_target_tracking_env.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self):
4848
#self.model.to(self.device)
4949

5050

51-
def env_parametrization(self, num_targets=2, num_sensors=1, target_motion_omegas=None, meas_model='range'):
51+
def env_parametrization(self, num_targets=4, num_sensors=1, target_motion_omegas=None, meas_model='range'):
5252
"""
5353
Function for parametrizing the environment
5454
"""
@@ -112,10 +112,10 @@ def env_parametrization(self, num_targets=2, num_sensors=1, target_motion_omegas
112112
self.y1_list.append(float(self.true_targets_pos[0, 1]))
113113
self.x2_list.append(float(self.true_targets_pos[1, 0]))
114114
self.y2_list.append(float(self.true_targets_pos[1, 1]))
115-
#self.x3_list.append(float(self.true_targets_pos[2, 0]))
116-
#self.y3_list.append(float(self.true_targets_pos[2, 1]))
117-
#self.x4_list.append(float(self.true_targets_pos[3, 0]))
118-
#self.y4_list.append(float(self.true_targets_pos[3, 1]))
115+
self.x3_list.append(float(self.true_targets_pos[2, 0]))
116+
self.y3_list.append(float(self.true_targets_pos[2, 1]))
117+
self.x4_list.append(float(self.true_targets_pos[3, 0]))
118+
self.y4_list.append(float(self.true_targets_pos[3, 1]))
119119

120120
self.meas_model = meas_model
121121
if self.meas_model == 'bearing':
@@ -145,10 +145,10 @@ def step(self, action, step_size):
145145
self.y1_list.append(float(self.true_targets_pos[0, 1]))
146146
self.x2_list.append(float(self.true_targets_pos[1, 0]))
147147
self.y2_list.append(float(self.true_targets_pos[1, 1]))
148-
#self.x3_list.append(float(self.true_targets_pos[2, 0]))
149-
#self.y3_list.append(float(self.true_targets_pos[2, 1]))
150-
#self.x4_list.append(float(self.true_targets_pos[3, 0]))
151-
#self.y4_list.append(float(self.true_targets_pos[3, 1]))
148+
self.x3_list.append(float(self.true_targets_pos[2, 0]))
149+
self.y3_list.append(float(self.true_targets_pos[2, 1]))
150+
self.x4_list.append(float(self.true_targets_pos[3, 0]))
151+
self.y4_list.append(float(self.true_targets_pos[3, 1]))
152152

153153
self.heatmap = torch.zeros(self.len_workspace, self.len_workspace)
154154
for index in range(0, self.num_targets):
@@ -218,10 +218,10 @@ def reset(self, **kwargs):
218218
self.y1_list.append(float(self.true_targets_pos[0, 1]))
219219
self.x2_list.append(float(self.true_targets_pos[1, 0]))
220220
self.y2_list.append(float(self.true_targets_pos[1, 1]))
221-
#self.x3_list.append(float(self.true_targets_pos[2, 0]))
222-
#self.y3_list.append(float(self.true_targets_pos[2, 1]))
223-
#self.x4_list.append(float(self.true_targets_pos[3, 0]))
224-
#self.y4_list.append(float(self.true_targets_pos[3, 1]))
221+
self.x3_list.append(float(self.true_targets_pos[2, 0]))
222+
self.y3_list.append(float(self.true_targets_pos[2, 1]))
223+
self.x4_list.append(float(self.true_targets_pos[3, 0]))
224+
self.y4_list.append(float(self.true_targets_pos[3, 1]))
225225

226226
self.heatmap = torch.zeros(self.len_workspace, self.len_workspace)
227227
for index in range(0, self.num_targets):
@@ -273,10 +273,10 @@ def render(self):
273273
plt.plot(self.x1_list[len(self.x1_list) - 1], self.y1_list[len(self.y1_list) - 1], 'o', c='b', marker='*')
274274
plt.plot(self.x2_list, self.y2_list, 'b--')
275275
plt.plot(self.x2_list[len(self.x2_list) - 1], self.y2_list[len(self.y2_list) - 1], 'o', c='b', marker='*')
276-
#plt.plot(self.x3_list, self.y3_list, 'b--')
277-
#plt.plot(self.x3_list[len(self.x3_list) - 1], self.y3_list[len(self.y3_list) - 1], 'o', c='b', marker='*')
278-
#plt.plot(self.x4_list, self.y4_list, 'b--')
279-
#plt.plot(self.x4_list[len(self.x4_list) - 1], self.y4_list[len(self.y4_list) - 1], 'o', c='b', marker='*')
276+
plt.plot(self.x3_list, self.y3_list, 'b--')
277+
plt.plot(self.x3_list[len(self.x3_list) - 1], self.y3_list[len(self.y3_list) - 1], 'o', c='b', marker='*')
278+
plt.plot(self.x4_list, self.y4_list, 'b--')
279+
plt.plot(self.x4_list[len(self.x4_list) - 1], self.y4_list[len(self.y4_list) - 1], 'o', c='b', marker='*')
280280
if(len(self.robot_movement_x) < 8):
281281
plt.plot(self.robot_movement_x, self.robot_movement_y, 'r--')
282282
else:

code/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
action_dim = env.action_space.shape[0]
3838
max_action = float(env.action_space.high[0])
3939
policy = TD3(0.0005, state_dim, 2, max_action)
40-
policy.load_actor("/home/arpitdec5/Desktop/robot_target_tracking/", "model_sensors_1_targets_2")
40+
policy.load_actor("/home/arpitdec5/Desktop/robot_target_tracking/", "model_sensors_1_targets_4")
4141

4242
# eval loop
4343
state = env.reset()

code/td3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ class Actor(nn.Module):
1515
def __init__(self, state_dim, action_dim, max_action):
1616
super(Actor, self).__init__()
1717

18-
self.l1 = nn.Linear(state_dim, 800)
19-
self.l2 = nn.Linear(800, 800)
20-
self.l3 = nn.Linear(800, action_dim)
18+
self.l1 = nn.Linear(state_dim, 1024)
19+
self.l2 = nn.Linear(1024, 1024)
20+
self.l3 = nn.Linear(1024, action_dim)
2121

2222
self.max_action = max_action
2323

@@ -33,9 +33,9 @@ class Critic(nn.Module):
3333
def __init__(self, state_dim, action_dim):
3434
super(Critic, self).__init__()
3535

36-
self.l1 = nn.Linear(state_dim + action_dim, 800)
37-
self.l2 = nn.Linear(800, 800)
38-
self.l3 = nn.Linear(800, action_dim)
36+
self.l1 = nn.Linear(state_dim + action_dim, 1024)
37+
self.l2 = nn.Linear(1024, 1024)
38+
self.l3 = nn.Linear(1024, action_dim)
3939

4040
def forward(self, state, action):
4141
q = F.relu(self.l1(torch.cat([state, action], 1)))

code/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434

3535
# constants
3636
lr = 0.0005
37-
epochs = 2000
38-
iters = 300
37+
epochs = 4000
38+
iters = 600
3939

4040
# create TD3 object
4141
state_dim = env.observation_space.shape[0]
@@ -73,7 +73,7 @@
7373

7474
# save actor and critic models
7575
if(epoch > 1000 and epoch%10==0):
76-
policy.save("/home/arpitdec5/Desktop/robot_target_tracking/", "model_sensors_1_targets_2")
76+
policy.save("/home/arpitdec5/Desktop/robot_target_tracking/", "model_sensors_1_targets_4")
7777

7878
# print reward
7979
print()
@@ -101,4 +101,4 @@
101101
plt.plot(m_e, m_r, c='orange', label='Mean Reward')
102102
#plt.plot(g_e, g_r, c='red', label='Greedy Algorithm')
103103
plt.legend()
104-
plt.savefig("/home/arpitdec5/Desktop/robot_target_tracking/reward_sensors_1_targets_2.png")
104+
plt.savefig("/home/arpitdec5/Desktop/robot_target_tracking/reward_sensors_1_targets_4.png")

0 commit comments

Comments
 (0)