3232import os
3333import time
3434import matplotlib .pyplot as plt
35+ from tqdm import tqdm
3536import torch
3637from torch .utils .data import DataLoader
3738
6263
6364NUM_EPOCHS = 6000
6465
65- LEARNING_RATE = 1e-6
66+ LEARNING_RATE = 5e-7
6667MOMENTUM = 0.9
6768BATCH_SIZE = 10
6869
@@ -73,12 +74,13 @@ def train():
7374
7475 model .train ()
7576
76- for epoch in range (NUM_EPOCHS ):
77+ for epoch in tqdm ( range (NUM_EPOCHS ) ):
7778 loss_f = 0
7879 t_start = time .time ()
7980 i = 0
81+ # print(len(train_dataloader))
8082 for batch in train_dataloader :
81- i += 1
83+ # print(i); i+=1
8284 input_tensor = torch .autograd .Variable (batch ['image' ])
8385 target_tensor = torch .autograd .Variable (batch ['mask' ])
8486 img = np .transpose (target_tensor .numpy (),(1 ,2 ,0 ))
@@ -119,6 +121,7 @@ def train():
119121 mask_dir = os .path .join (data_root , args .mask_dir )
120122
121123 CUDA = args .gpu is not None
124+ print ("CUDA" ,CUDA )
122125 GPU_ID = args .gpu
123126 print ('GPU' ,GPU_ID )
124127
@@ -148,13 +151,15 @@ def train():
148151 print ('MODEL' )
149152 model = SegNet (input_channels = NUM_INPUT_CHANNELS ,
150153 output_channels = NUM_OUTPUT_CHANNELS )
154+ model .init_vgg_weigts ()
151155 print ('STATE_DICT' )
152156 class_weights = 1.0 / train_dataset .get_class_probability ()
153157 print ('class_weights' ,len (class_weights ))
154158 criterion = torch .nn .CrossEntropyLoss (weight = class_weights )
155159
156160
157161 if args .checkpoint :
162+ print ('Loading Checkpoint' )
158163 if GPU_ID is None :
159164 model .load_state_dict (torch .load (args .checkpoint ,map_location = 'cpu' ))
160165 else :
0 commit comments