
Commit d20bd5e

add data_loader and CAN model
1 parent 716a213 commit d20bd5e

File tree

7 files changed: +181 -48 lines

configs/fcn8s_pascal.yml

Lines changed: 0 additions & 26 deletions
This file was deleted.

configs/frrnB_cityscapes.yml

Lines changed: 0 additions & 22 deletions
This file was deleted.

configs/isprs_potsdam.yml

Lines changed: 26 additions & 0 deletions
model:
    arch: CAN
data:
    dataset: my
    train_split: train
    val_split: val
    img_rows: 512
    img_cols: 512
    path: your/path/to/ISPRS/Potsdam/dataset/
training:
    train_iters: 300000
    batch_size: 6
    val_interval: 1000
    n_workers: 16
    print_interval: 50
    optimizer:
        name: 'sgd'
        lr: 1.0e-3
        weight_decay: 0.0005
        momentum: 0.99

    loss:
        name: 'cross_entropy'
        size_average: True
    lr_schedule:
    resume: CAN_my_best_model.pkl
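
For orientation, this config follows the repository's existing YAML layout (model / data / training sections). A minimal sketch of reading it with PyYAML, assuming the usual train.py --config workflow of this codebase; the printed values are taken from the file above:

import yaml

with open("configs/isprs_potsdam.yml") as fp:
    cfg = yaml.safe_load(fp)

print(cfg["model"]["arch"])           # CAN
print(cfg["data"]["dataset"])         # "my" -> resolved by get_loader below
print(cfg["training"]["batch_size"])  # 6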

configs/isprs_vaihingen.yml

Lines changed: 26 additions & 0 deletions
model:
    arch: CAN
data:
    dataset: my
    train_split: train
    val_split: val
    img_rows: 512
    img_cols: 512
    path: your/path/to/ISPRS/Vaihingen/dataset/
training:
    train_iters: 300000
    batch_size: 6
    val_interval: 1000
    n_workers: 16
    print_interval: 50
    optimizer:
        name: 'sgd'
        lr: 1.0e-3
        weight_decay: 0.0005
        momentum: 0.99

    loss:
        name: 'cross_entropy'
        size_average: True
    lr_schedule:
    resume: CAN_my_best_model.pkl

ptsemseg/loader/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,8 @@
from ptsemseg.loader.sunrgbd_loader import SUNRGBDLoader
from ptsemseg.loader.mapillary_vistas_loader import mapillaryVistasLoader

+from ptsemseg.loader.my_loader import myLoader
+

def get_loader(name):
    """get_loader

@@ -24,4 +26,5 @@ def get_loader(name):
        "nyuv2": NYUv2Loader,
        "sunrgbd": SUNRGBDLoader,
        "vistas": mapillaryVistasLoader,
+        "my": myLoader,
    }[name]
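
With this registration, the new dataset can be requested by name. A minimal usage sketch; the root path is the placeholder from the configs above, and the keyword arguments follow the myLoader signature shown in the next file:

from ptsemseg.loader import get_loader

loader_cls = get_loader("my")          # resolves to myLoader
train_set = loader_cls(
    root="your/path/to/ISPRS/Potsdam/dataset/",  # placeholder path
    split="train",
    is_transform=True,
    img_size=(512, 512),
)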

ptsemseg/loader/my_loader.py

Lines changed: 122 additions & 0 deletions
import os
import collections
import torch
import torchvision
import numpy as np
import scipy.misc as m
import matplotlib.pyplot as plt

from torch.utils import data
from ptsemseg.augmentations import *

import cv2 as cv
from torchvision import transforms


class myLoader(data.Dataset):
    # Loader for the ISPRS (Potsdam / Vaihingen) tiles referenced by the configs above;
    # expects <root>/<split> image folders with matching <split>_labels folders.
    def __init__(
        self,
        root,
        split="train",
        is_transform=False,
        img_size=512,
        augmentations=None,
        img_norm=True,
    ):
        self.root = root
        self.split = split
        self.img_size = (
            img_size if isinstance(img_size, tuple) else (img_size, img_size)
        )
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.img_norm = img_norm
        self.mean = np.array([115.3165639, 83.02458143, 81.95442675])
        self.n_classes = 6
        self.files = collections.defaultdict(list)

        for split in ["train", "test", "val"]:
            file_list = os.listdir(root + "/" + split)
            self.files[split] = file_list

        # Per-channel dataset statistics, kept for reference:
        # self.tf = transforms.Compose([transforms.ToTensor(),
        #     transforms.Normalize([0.45222182, 0.32558659, 0.32138991],
        #                          [0.21074223, 0.14708663, 0.14242824])])
        # self.tf_no_train = transforms.Compose([transforms.ToTensor(),
        #     transforms.Normalize([0.45222182, 0.32558659, 0.32138991], [1, 1, 1])])
        self.tf = transforms.ToTensor()
        self.tf_no_train = transforms.ToTensor()

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        img_name = self.files[self.split][index]
        img_path = self.root + "/" + self.split + "/" + img_name
        lbl_path = self.root + "/" + self.split + "_labels/" + img_name

        img = cv.cvtColor(cv.imread(img_path, -1), cv.COLOR_BGR2RGB)
        lbl = cv.imread(lbl_path, -1)

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        if self.img_size == ('same', 'same'):
            pass
        else:
            # OpenCV resize takes (width, height); nearest-neighbor keeps label values discrete
            img = cv.resize(img, (self.img_size[1], self.img_size[0]))
            lbl = cv.resize(lbl, (self.img_size[1], self.img_size[0]),
                            interpolation=cv.INTER_NEAREST)

        if self.split == "train":
            img = self.tf(img)
        else:
            img = self.tf_no_train(img)
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def decode_segmap(self, temp, plot=False):
        # ISPRS color legend: impervious surfaces, building, low vegetation,
        # tree, car, clutter/background.
        Imps = [255, 255, 255]
        Building = [0, 0, 255]
        Lowvg = [0, 255, 255]
        Tree = [0, 255, 0]
        Car = [255, 255, 0]
        bg = [255, 0, 0]

        label_colours = np.array(
            [
                Imps,
                Building,
                Lowvg,
                Tree,
                Car,
                bg,
            ]
        )
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for l in range(0, self.n_classes):
            r[temp == l] = label_colours[l, 0]
            g[temp == l] = label_colours[l, 1]
            b[temp == l] = label_colours[l, 2]
        rgb = np.zeros((temp.shape[0], temp.shape[1], 3), dtype=np.uint8)
        rgb[:, :, 0] = r
        rgb[:, :, 1] = g
        rgb[:, :, 2] = b
        return rgb
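
A short sketch of how myLoader might be wrapped for training; batch_size and num_workers mirror the config values, and the DataLoader wrapping is standard PyTorch rather than anything specific to this commit:

from torch.utils.data import DataLoader

train_set = myLoader(
    root="your/path/to/ISPRS/Potsdam/dataset/",  # placeholder path
    split="train",
    is_transform=True,
    img_size=(512, 512),
)
train_loader = DataLoader(train_set, batch_size=6, shuffle=True, num_workers=16)

images, labels = next(iter(train_loader))  # images: (6, 3, 512, 512), labels: (6, 512, 512)

# decode_segmap maps a label map (class indices 0-5) back to an RGB image
# using the ISPRS color legend defined in the class.
rgb = train_set.decode_segmap(labels[0].numpy())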

ptsemseg/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@
from ptsemseg.models.icnet import icnet
from ptsemseg.models.linknet import linknet
from ptsemseg.models.frrn import frrn
+from ptsemseg.models.CAN import CAN50


def get_model(model_dict, n_classes, version=None):

@@ -41,6 +42,8 @@ def get_model(model_dict, n_classes, version=None):
    elif name == "icnetBN":
        model = model(n_classes=n_classes, **param_dict)

+    elif name == "CAN":
+        model = model(n_classes=n_classes, **param_dict)
    else:
        model = model(n_classes=n_classes, **param_dict)

@@ -61,6 +64,7 @@ def _get_model_instance(name):
            "linknet": linknet,
            "frrnA": frrn,
            "frrnB": frrn,
+            "CAN": CAN50,
        }[name]
    except:
        raise ("Model {} not available".format(name))
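
With the "CAN" entry registered, get_model can build the network from the config's model section. A minimal sketch; the elif branch above implies CAN50 is constructed with an n_classes keyword, and 6 is the ISPRS class count used by myLoader:

from ptsemseg.models import get_model

model_dict = {"arch": "CAN"}                # matches "arch: CAN" in the configs
model = get_model(model_dict, n_classes=6)  # instantiates CAN50(n_classes=6)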
