Skip to content

Commit c050f5e

Browse files
authored
Merge pull request #30 from ppjerry/FixDataTransformation
fix the random transformation on both image and target
2 parents 0ba46df + 4b3580f commit c050f5e

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

Data_Loader.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torchvision
77
from skimage import io
88
from torch.utils.data import Dataset
9+
import random
10+
import numpy as np
911

1012

1113
class Images_Dataset(Dataset):
@@ -80,6 +82,7 @@ def __init__(self, images_dir, labels_dir,transformI = None, transformM = None):
8082
self.lx = torchvision.transforms.Compose([
8183
# torchvision.transforms.Resize((128,128)),
8284
torchvision.transforms.CenterCrop(96),
85+
torchvision.transforms.RandomRotation((-10,10)),
8386
torchvision.transforms.Grayscale(),
8487
torchvision.transforms.ToTensor(),
8588
#torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
@@ -93,5 +96,19 @@ def __getitem__(self, i):
9396
i1 = Image.open(self.images_dir + self.images[i])
9497
l1 = Image.open(self.labels_dir + self.labels[i])
9598

96-
return self.tx(i1), self.lx(l1)
99+
seed=np.random.randint(0,2**32) # make a seed with numpy generator
100+
101+
# apply this seed to img tranfsorms
102+
random.seed(seed)
103+
torch.manual_seed(seed)
104+
img = self.tx(i1)
105+
106+
# apply this seed to target/label tranfsorms
107+
random.seed(seed)
108+
torch.manual_seed(seed)
109+
label = self.lx(l1)
110+
111+
112+
113+
return img, label
97114

0 commit comments

Comments
 (0)