Skip to content

Commit e6243ae

Browse files
committed
final model
1 parent 40e36c8 commit e6243ae

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

dataset.py

100755100644
File mode changed.

train.py

100755100644
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import os
3333
import time
3434
import matplotlib.pyplot as plt
35+
from tqdm import tqdm
3536
import torch
3637
from torch.utils.data import DataLoader
3738

@@ -62,7 +63,7 @@
6263

6364
NUM_EPOCHS = 6000
6465

65-
LEARNING_RATE = 1e-6
66+
LEARNING_RATE = 5e-7
6667
MOMENTUM = 0.9
6768
BATCH_SIZE = 10
6869

@@ -73,12 +74,13 @@ def train():
7374

7475
model.train()
7576

76-
for epoch in range(NUM_EPOCHS):
77+
for epoch in tqdm(range(NUM_EPOCHS)):
7778
loss_f = 0
7879
t_start = time.time()
7980
i=0
81+
# print(len(train_dataloader))
8082
for batch in train_dataloader:
81-
i+=1
83+
# print(i);i+=1
8284
input_tensor = torch.autograd.Variable(batch['image'])
8385
target_tensor = torch.autograd.Variable(batch['mask'])
8486
img = np.transpose(target_tensor.numpy(),(1,2,0))
@@ -119,6 +121,7 @@ def train():
119121
mask_dir = os.path.join(data_root, args.mask_dir)
120122

121123
CUDA = args.gpu is not None
124+
print("CUDA",CUDA)
122125
GPU_ID = args.gpu
123126
print('GPU',GPU_ID)
124127

@@ -148,13 +151,15 @@ def train():
148151
print('MODEL')
149152
model = SegNet(input_channels=NUM_INPUT_CHANNELS,
150153
output_channels=NUM_OUTPUT_CHANNELS)
154+
model.init_vgg_weigts()
151155
print('STATE_DICT')
152156
class_weights = 1.0/train_dataset.get_class_probability()
153157
print('class_weights',len(class_weights))
154158
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
155159

156160

157161
if args.checkpoint:
162+
print('Loading Checkpoint')
158163
if GPU_ID is None:
159164
model.load_state_dict(torch.load(args.checkpoint,map_location='cpu'))
160165
else:

0 commit comments

Comments
 (0)