Skip to content

Commit bef3644

Browse files
committed
added scoring
1 parent a98cdc2 commit bef3644

File tree

6 files changed

+179
-24
lines changed

6 files changed

+179
-24
lines changed

evaluate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def add_arguments():
5050
parser.add_argument("--mass_violation", type=bool, default=True)
5151
parser.add_argument("--factor", type=int, default=2)
5252
parser.add_argument("--time_sr", default=False)
53+
parser.add_argument("--constraints_window_size", default=4, type=int)
5354

5455
return parser.parse_args()
5556

@@ -81,7 +82,7 @@ def main(args):
8182

8283
#add_string = args.model_id + '_evaluate_training'
8384

84-
create_report(scores, args, add_string)
85+
#create_report(scores, args, add_string)
8586
main_scoring(args)
8687

8788
def load_weights(model, model_id):

main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def add_arguments():
77
parser = argparse.ArgumentParser()
88
parser.add_argument("--dataset", default='dataset34', help="choose a data set to use")
9-
parser.add_argument("--scale", default='minmax_fixed', help="standard, minmax, none")
9+
parser.add_argument("--scale", default='minmax', help="standard, minmax, none")
1010
parser.add_argument("--model", default='motifnet_learnable')
1111
parser.add_argument("--model_id", default='motifbased_exp_ood_softmaxfirst_test')
1212
parser.add_argument("--number_channels", default=64, type=int)
@@ -34,6 +34,7 @@ def add_arguments():
3434
parser.add_argument("--test", default=False, type=bool)
3535
parser.add_argument("--l2_reg", default=False, type=bool)
3636
parser.add_argument("--dim_channels", default=1, type=int)
37+
parser.add_argument("--constraints_window_size", default=4, type=int)
3738

3839

3940
return parser.parse_args()

models.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def forward(self):
4343
# (shape[0], shape[1], shape[2])
4444

4545
if self.positivity:
46-
basis = torch.exp(basis)
47-
#basis = torch.nn.functional.softplus(basis)
46+
#basis = torch.exp(basis)
47+
basis = torch.nn.functional.softplus(basis)
4848

4949
# Normalization
5050
sums = torch.sum(basis, dim=(1, 2))
@@ -364,6 +364,7 @@ def forward(self, x):
364364
out = out/sum_c
365365
out = self.mult_in(out, x[:,0,...])
366366

367+
367368
#print('input', x[0,0,0,0,0])
368369
#print('after mult', out[0,:,0,0].mean())
369370
# (n_batch, 16, 32, 32)
@@ -398,7 +399,7 @@ class ResidualBlock(nn.Module):
398399
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
399400
super(ResidualBlock, self).__init__()
400401
self.conv1 = conv3x3(in_channels, out_channels, stride)
401-
self.relu = nn.ReLU(inplace=False)
402+
self.relu = nn.ReLU(inplace=True)
402403
self.conv2 = conv3x3(out_channels, out_channels)
403404

404405
def forward(self, x):
@@ -444,6 +445,7 @@ def __init__(self, upsampling_factor):
444445
self.pool = torch.nn.AvgPool2d(kernel_size=upsampling_factor)
445446
self.upsampling_factor = upsampling_factor
446447
def forward(self, y, lr):
448+
y = y.clone()
447449
sum_y = self.pool(y)
448450
out =y+ torch.kron(lr-sum_y, torch.ones((self.upsampling_factor,self.upsampling_factor)).to('cuda'))
449451
return out
@@ -461,15 +463,18 @@ def forward(self, y, lr):
461463
return out
462464

463465
class SoftmaxConstraints(nn.Module):
464-
def __init__(self, upsampling_factor, exp_factor=1):
466+
def __init__(self, upsampling_factor, cwindow_size, exp_factor=1):
465467
super(SoftmaxConstraints, self).__init__()
466-
self.pool = torch.nn.AvgPool2d(kernel_size=upsampling_factor)
468+
self.pool = torch.nn.AvgPool2d(kernel_size=cwindow_size)
469+
self.lr_pool = torch.nn.AvgPool2d(kernel_size=int(cwindow_size/upsampling_factor))
467470
self.upsampling_factor = upsampling_factor
471+
self.cwindow_size = cwindow_size
468472
self.exp_factor = exp_factor
469473
def forward(self, y, lr):
470474
y = torch.exp(y*self.exp_factor)
471475
sum_y = self.pool(y)
472-
out = y*torch.kron(lr*1/sum_y, torch.ones((self.upsampling_factor,self.upsampling_factor)).to('cuda'))
476+
lr_sum = self.lr_pool(lr)
477+
out = y*torch.kron(lr_sum*1/sum_y, torch.ones((self.cwindow_size,self.cwindow_size)).to('cuda'))
473478
return out
474479

475480

@@ -624,33 +629,33 @@ def forward(self, x, z=None):
624629
return out
625630

626631
class ResNet2(nn.Module):
627-
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1):
632+
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=2):
628633
super(ResNet2, self).__init__()
629634
# First layer
630635
if noise:
631636
self.conv_trans0 = nn.ConvTranspose2d(100, 1, kernel_size=(32,32), padding=0, stride=1)
632-
self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False))
637+
self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
633638
else:
634-
self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False))
639+
self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
635640
#Residual Blocks
636641
self.res_blocks = nn.ModuleList()
637642
for k in range(number_residual_blocks):
638643
self.res_blocks.append(ResidualBlock(number_channels, number_channels))
639644
# Second conv layer post residual blocks
640645
self.conv2 = nn.Sequential(
641-
nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False))
646+
nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
642647
# Upsampling layers
643648
self.upsampling = nn.ModuleList()
644649
for k in range(int(np.rint(np.log2(upsampling_factor)))):
645650
self.upsampling.append(nn.ConvTranspose2d(number_channels, number_channels, kernel_size=2, padding=0, stride=2) )
646651
# Next layer after upsampling
647-
self.conv3 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False))
652+
self.conv3 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
648653
# Final output layer
649654
self.conv4 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0)
650655
#optional renormalization layer
651656
self.is_constraints = False
652657
if constraints == 'softmax':
653-
self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor)
658+
self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
654659
self.is_constraints = True
655660
elif constraints == 'enforce_op':
656661
self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)

scoring.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from skimage import transform
2+
import numpy as np
3+
import torch
4+
from PIL import Image
5+
import torch.nn as nn
6+
import torchgeometry as tgm
7+
import argparse
8+
from torch.utils.data import DataLoader, TensorDataset
9+
import csv
10+
11+
def add_arguments():
12+
parser = argparse.ArgumentParser()
13+
parser.add_argument("--dataset", default='dataset34', help="choose a data set to use")
14+
parser.add_argument("--model", default='resnet2_ood_softmax_test')
15+
parser.add_argument("--model_id", default='dataset34_resnet2_ood_softmax_test_test')
16+
parser.add_argument("--time", default=True)
17+
parser.add_argument("--nn", default=True)
18+
parser.add_argument("--test", default=True)
19+
parser.add_argument("--time_steps", type=int, default=1)
20+
parser.add_argument("--mass_violation", type=bool, default=True)
21+
parser.add_argument("--factor", type=int, default=4)
22+
parser.add_argument("--time_sr", default=False)
23+
#args for model loading
24+
return parser.parse_args()
25+
26+
def main_scoring(args):
27+
#n = 24
28+
input_val = torch.load('./data/test/'+args.dataset+'/input_test.pt')
29+
target_val = torch.load('./data/test/'+args.dataset+'/target_test.pt')
30+
#target_val = torch.load('./data/test/dataset28/target_test.pt')
31+
val_data = TensorDataset(input_val, target_val)
32+
pred = np.zeros(target_val.shape)
33+
print(pred.shape)
34+
factor = args.factor
35+
mse = 0
36+
mae = 0
37+
ssim = 0
38+
mass_violation = 0
39+
l2_crit = nn.MSELoss()
40+
l1_crit = nn.L1Loss()
41+
ssim_criterion = tgm.losses.SSIM(window_size=11, max_val=130.83, reduction='mean')
42+
if args.nn:
43+
44+
pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+'_test.pt')
45+
pred = pred.detach().cpu().numpy()
46+
print(pred.shape)
47+
for i,(lr, hr) in enumerate(val_data):
48+
im = lr.numpy()
49+
if args.time:
50+
#print(hr.shape)
51+
for j in range(args.time_steps):
52+
if args.model == 'bilinear':
53+
pred[i,j,0,:,:] = np.array(Image.fromarray(im[j,0,...]).resize((4*lr.shape[2],4*lr.shape[2]), Image.BILINEAR))
54+
elif args.model == 'bicubic':
55+
pred[i,j,0,:,:] = np.array(Image.fromarray(im[j,0,...]).resize((factor*lr.shape[2],factor*lr.shape[2]), Image.BICUBIC))
56+
elif args.model == 'bicubic_frame':
57+
if j == 0:
58+
pred[i,j,0,:,:] = np.array(Image.fromarray(im[0,0,...]).resize((4*lr.shape[2],4*lr.shape[2]), Image.BICUBIC))
59+
elif j == 2:
60+
pred[i,j,0,:,:] = np.array(Image.fromarray(im[1,0,...]).resize((4*lr.shape[2],4*lr.shape[2]), Image.BICUBIC))
61+
else:
62+
pred[i,j,0,:,:] =np.array(Image.fromarray(0.5*(im[1,0,...]+im[0,0,...])).resize((4*lr.shape[2],4*lr.shape[2]), Image.BICUBIC))
63+
elif args.model == 'kronecker':
64+
pred[i,j,0,:,:] = np.kron(im[j,0,...], np.ones((4,4)))
65+
elif args.model=='frame_inter':
66+
print(pred.shape, im.shape)
67+
pred[i,j,0,... ] = 0.5*(im[0,0,...]+im[1,0,...])
68+
'''
69+
mse_loss = l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
70+
71+
print(i, mse_loss)'''
72+
mse += l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
73+
mae += l1_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
74+
ssim += ssim_criterion(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...]).item()
75+
if args.mass_violation:
76+
if args.time_sr:
77+
if j==0:
78+
mass_violation += np.mean( np.abs(transform.downscale_local_mean(pred[i,j,...], (1,args.factor,args.factor)) -im[0,...]))
79+
elif j==2:
80+
mass_violation += np.mean( np.abs(transform.downscale_local_mean(pred[i,j,...], (1,args.factor,args.factor)) -im[1,...]))
81+
else:
82+
mass_violation += np.mean( np.abs(transform.downscale_local_mean(pred[i,j,...], (1,args.factor,args.factor)) -im[j,...]))
83+
#print(l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item())
84+
elif args.model=='frame_inter':
85+
pred[i,... ] = 0.5*(im[0:1,...]+im[1:2,...])
86+
mse += l2_crit(torch.Tensor(pred[i,...]), hr).item()
87+
mae += l1_crit(torch.Tensor(pred[i,...]), hr).item()
88+
ssim += ssim_criterion(torch.Tensor(pred[i,...]), hr).item()
89+
#print(l1_crit(torch.Tensor(pred[i,...]), hr).item())
90+
else:
91+
if args.model == 'bilinear':
92+
pred[i,:,:] = np.array(Image.fromarray(im).resize((4*lr.shape[1],4*lr.shape[1]), Image.BILINEAR))
93+
elif args.model == 'bicubic':
94+
pred[i,:,:] = np.array(Image.fromarray(im).resize((4*lr.shape[1],4*lr.shape[1]), Image.BICUBIC))
95+
#elif args.model == 'kronecker':
96+
97+
mse += l2_crit(torch.Tensor(pred[i,:,:]), hr).item()
98+
mae += l1_crit(torch.Tensor(pred[i,:,:]), hr).item()
99+
ssim += ssim_criterion(torch.Tensor(pred[i,:,:]).unsqueeze(0), hr.unsqueeze(0)).item()
100+
101+
102+
#torch.save(torch.Tensor(pred[:128,:,:]).unsqueeze(1), './data/prediction/'+args.dataset+'_'+args.model_id+'_prediction.pt')
103+
if args.time:
104+
print(input_val.shape[0])
105+
mse *= 1/(input_val.shape[0]*args.time_steps)
106+
mae *= 1/(input_val.shape[0]*args.time_steps)
107+
ssim *= 1/(input_val.shape[0]*args.time_steps)
108+
if args.mass_violation:
109+
if args.time_sr:
110+
mass_violation *= 1/(input_val.shape[0]*args.time_steps)
111+
else:
112+
mass_violation *= 1/(input_val.shape[0]*2)
113+
else:
114+
mse *= 1/input_val.shape[0]
115+
mae *= 1/input_val.shape[0]
116+
ssim *= 1/input_val.shape[0]
117+
psnr = calculate_pnsr(mse, target_val.max() )
118+
scores = {'MSE':mse, 'RMSE':torch.sqrt(torch.Tensor([mse])), 'PSNR': psnr, 'MAE':mae, 'SSIM':1-ssim, 'Mass_violation': mass_violation}
119+
print(scores)
120+
create_report(scores, args)
121+
#np.save('./data/prediction/bic.npy', pred)
122+
123+
124+
def calculate_pnsr(mse, max_val):
125+
return 20 * torch.log10(max_val / torch.sqrt(torch.Tensor([mse])))
126+
127+
def create_report(scores, args):
128+
args_dict = args_to_dict(args)
129+
#combine scores and args dict
130+
args_scores_dict = args_dict | scores
131+
#save dict
132+
save_dict(args_scores_dict, args)
133+
134+
def args_to_dict(args):
135+
return vars(args)
136+
137+
138+
def save_dict(dictionary, args):
139+
w = csv.writer(open('./data/score_log/'+args.model_id+'.csv', 'w'))
140+
# loop over dictionary keys and values
141+
for key, val in dictionary.items():
142+
# write every key and value to file
143+
w.writerow([key, val])
144+
145+
if __name__ == '__main__':
146+
args = add_arguments()
147+
main_scoring(args)

training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def optimizer_step(model, optimizer, criterion, inputs, targets, tepoch, args, c
121121
else:
122122
outputs = model(inputs)
123123
#print(outputs.shape, targets.shape)
124-
loss = get_loss(outputs, targets, args)#criterion(outputs, targets)
124+
loss = get_loss(outputs, targets, inputs,args)#criterion(outputs, targets)
125125
loss.backward()
126126
optimizer.step()
127127
#print(torch.mean((outputs-targets)**2))
@@ -254,7 +254,7 @@ def validate_model(model, criterion, data, best, patience, epoch, args, discrimi
254254
outputs, coeff = model(inputs)
255255
else:
256256
outputs = model(inputs)
257-
loss = get_loss(outputs, targets, args)
257+
loss = get_loss(outputs, targets, inputs, args)
258258
#print(torch.mean((outputs-targets)**2))
259259
running_loss += loss.item()
260260
#print('val:', loss.item())

utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ def load_data(args):
1515
if args.test:
1616
input_val = torch.load('./data/test/'+args.dataset+'/input_test.pt')
1717
target_val = torch.load('./data/test/'+args.dataset+'/target_test.pt')
18+
#target_val = torch.load('./data/val/dataset28/target_val.pt')
1819
else:
1920
input_val = torch.load('./data/val/'+args.dataset+'/input_val.pt')
21+
#target_val = torch.load('./data/val/dataset28/target_val.pt')
2022
target_val = torch.load('./data/val/'+args.dataset+'/target_val.pt')
21-
#target_val = torch.load('./data/val/'+args.dataset+'/target_val.pt')
2223
#define dimensions
2324
global train_shape_in , train_shape_out, val_shape_in, val_shape_in
2425
train_shape_in = input_train.shape
@@ -125,8 +126,8 @@ def load_model(args, discriminator=False):
125126
model = models.MixtureModel(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=1)
126127
elif args.model == 'gan':
127128
model = models.ResNet2(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=args.dim_channels)
128-
else:
129-
model = models.ResNet2(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=args.dim_channels)
129+
elif args.model == 'resnet2':
130+
model = models.ResNet2(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=args.dim_channels, cwindow_size= args.constraints_window_size)
130131
model.to(device)
131132
return model
132133

@@ -141,14 +142,14 @@ def get_criterion(args, discriminator=False):
141142
criterion = nn.MSELoss()
142143
return criterion
143144

144-
def mass_loss(output, true_value, args):
145+
def mass_loss(output, in_val, args):
145146
ds_out = torch.nn.functional.avg_pool2d(output[:,0,0,:,:], args.upsampling_factor)
146-
ds_true = torch.nn.functional.avg_pool2d(true_value[:,0,0,:,:], args.upsampling_factor)
147-
return torch.nn.functional.mse_loss(ds_out, ds_true)
147+
#ds_true = torch.nn.functional.avg_pool2d(true_value[:,0,0,:,:], args.upsampling_factor)
148+
return torch.nn.functional.mse_loss(ds_out, in_val)
148149

149-
def get_loss(output, true_value, args):
150+
def get_loss(output, true_value, in_val, args):
150151
if args.loss == 'mass_constraints':
151-
return args.alpha*torch.nn.functional.mse_loss(output, true_value) + (1-args.alpha)*mass_loss(output, true_value, args)
152+
return args.alpha*torch.nn.functional.mse_loss(output, true_value) + (1-args.alpha)*mass_loss(output, in_val[:,0,0,...], args)
152153
else:
153154
return torch.nn.functional.mse_loss(output, true_value)
154155

0 commit comments

Comments
 (0)