Skip to content

Commit af676da

Browse files
committed
28 11
1 parent bef3644 commit af676da

File tree

6 files changed

+292
-37
lines changed

6 files changed

+292
-37
lines changed

evaluate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def add_arguments():
5151
parser.add_argument("--factor", type=int, default=2)
5252
parser.add_argument("--time_sr", default=False)
5353
parser.add_argument("--constraints_window_size", default=4, type=int)
54+
parser.add_argument("--ensemble", default=False)
5455

5556
return parser.parse_args()
5657

mean_scores.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
import argparse
import csv

import numpy as np
2+
def add_arguments():
    """Parse command-line arguments for the score-averaging script."""
    arg_parser = argparse.ArgumentParser()
    # Base identifier of the per-run score CSVs (files are named <model_id>_<run>).
    arg_parser.add_argument("--model_id", default='test')
    return arg_parser.parse_args()
7+
#load five csvs and calculate means and stds


def main(args):
    """Aggregate per-run score CSVs into per-metric mean and std.

    Reads ``<model_id>_0`` .. ``<model_id>_4`` CSV files (rows of the form
    ``metric,value``), collects the values of each known metric across runs,
    and writes ``<metric>_mean`` / ``<metric>_std`` rows via save_dict.
    """
    number_runs = 5
    metrics = ['MSE', 'RMSE', 'PSNR', 'MAE', 'SSIM', 'Mass_violation', 'Mean bias', 'MS SSIM', 'Pearson corr', 'CRPS']
    # metric name -> list of per-run values
    # (original code initialized `dict_lists` but appended to an undefined
    # `means` dict and iterated an undefined `state_dict` -> NameError)
    means = {}
    for metric in metrics:
        means[metric] = []
    for i in range(number_runs):
        filename = args.model_id + '_' + str(i)
        with open(filename, 'r') as data:
            for line in csv.reader(data):
                # skip blank rows; convert CSV strings to floats so the
                # numpy aggregation below operates on numbers
                if line and line[0] in metrics:
                    means[line[0]].append(float(line[1]))

    # collapse each metric's list of run values into mean and std
    dict_lists = {}
    for metric, values in means.items():
        dict_lists[metric + '_mean'] = np.mean(np.array(values))
        dict_lists[metric + '_std'] = np.std(np.array(values))

    #save mean+std dict as csv
    save_dict(dict_lists, args)
def save_dict(dictionary, args):
    """Write each key/value pair of *dictionary* as one CSV row.

    Output goes to ./data/score_log/<model_id>_means.csv; the directory
    is assumed to exist (NOTE(review): confirm against caller).
    """
    # use a context manager so the handle is flushed and closed (the
    # original leaked the file object); newline='' is the documented
    # requirement for files passed to csv.writer
    with open('./data/score_log/' + args.model_id + '_means.csv', 'w', newline='') as f:
        w = csv.writer(f)
        # loop over dictionary keys and values
        for key, val in dictionary.items():
            # write every key and value to file
            w.writerow([key, val])
39+
if __name__ == '__main__':
    # Parse CLI arguments and run the aggregation in one step.
    main(add_arguments())

models.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def __init__(self, upsampling_factor):
434434
self.pool = torch.nn.AvgPool2d(kernel_size=upsampling_factor)
435435
self.upsampling_factor = upsampling_factor
436436
def forward(self, y, lr):
437+
y = y.clone()
437438
out = self.pool(y)
438439
out = y*torch.kron(lr*1/out, torch.ones((self.upsampling_factor,self.upsampling_factor)).to('cuda'))
439440
return out
@@ -700,10 +701,84 @@ def forward(self, x, mr=None, z=None):
700701
# out[:,0,i,...] = self.constraints(out, x[:,0,i,...])
701702
#out[:,0,:,:] *= 16
702703
out = out.unsqueeze(1)
704+
return out
705+
706+
class ResNet3(nn.Module):
    """Super-resolution ResNet variant with a stride-3 transposed-conv upsampler.

    NOTE(review): unlike the sibling ResNet, the no-noise forward path runs the
    upsampling layers *before* the residual blocks; the noise path runs them
    after. Also, only the no-noise path adds the extra channel dim via
    unsqueeze(1) before returning — confirm this asymmetry is intentional.
    """

    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=2):
        super(ResNet3, self).__init__()
        # First layer
        if noise:
            # projects the noise input z (100 channels) down to a single
            # 32x32 feature map that is concatenated with the data input.
            # NOTE(review): the concatenation adds one channel, yet conv1
            # still expects `dim` input channels — confirm `dim` accounts
            # for the extra noise channel when noise=True.
            self.conv_trans0 = nn.ConvTranspose2d(100, 1, kernel_size=(32,32), padding=0, stride=1)
            self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
        #Residual Blocks
        self.res_blocks = nn.ModuleList()
        for k in range(number_residual_blocks):
            self.res_blocks.append(ResidualBlock(number_channels, number_channels))
        # Second conv layer post residual blocks
        self.conv2 = nn.Sequential(
            nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
        # Upsampling layers
        self.upsampling = nn.ModuleList()
        # single transposed conv with stride 3: a fixed 3x spatial upsampling,
        # independent of the upsampling_factor argument — TODO confirm intended
        for k in range(1):
            self.upsampling.append(nn.ConvTranspose2d(number_channels, number_channels, kernel_size=3, padding=0, stride=3) )
        # Next layer after upper sampling
        self.conv3 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
        # Final output layer
        self.conv4 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0)
        #optional renormalization layer enforcing downscale-consistency
        self.is_constraints = False
        if constraints == 'softmax':
            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
            self.is_constraints = True
        elif constraints == 'enforce_op':
            self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)
            self.is_constraints = True
        elif constraints == 'add':
            self.constraints = AddDownscaleConstraints(upsampling_factor=upsampling_factor)
            self.is_constraints = True
        elif constraints == 'mult':
            self.constraints = MultDownscaleConstraints(upsampling_factor=upsampling_factor)
            self.is_constraints = True

        self.dim = dim
        self.noise = noise

    def forward(self, x, mr=None, z=None):
        # x: input batch; first index along dim 1 is the data slice used below.
        # mr is unused here; z is the noise tensor (required when noise=True).
        if self.noise:
            out = self.conv_trans0(z)
            out = self.conv1(torch.cat(( x[:,0,...],out), dim=1))
            for layer in self.res_blocks:
                out = layer(out)
            out = self.conv2(out)
            for layer in self.upsampling:
                out = layer(out)
            out = self.conv3(out)
            out = self.conv4(out)
            if self.is_constraints:
                # renormalize against the low-resolution input slice
                out = self.constraints(out, x[:,0,...])
            return out
        else:
            #print(x.shape)
            out = self.conv1(x[:,0,...])
            # upsampling happens before the residual blocks on this path
            for layer in self.upsampling:
                out = layer(out)
            out = self.conv2(out)
            for layer in self.res_blocks:
                out = layer(out)
            out = self.conv3(out)
            out = self.conv4(out)
            if self.is_constraints:
                # in-place write of the constrained result into out
                out[:,...] = self.constraints(out, x[:,0,...])
            #for i in range(self.dim):
            #    out[:,0,i,...] = self.constraints(out, x[:,0,i,...])
            #out[:,0,:,:] *= 16
            # restore the singleton channel axis expected by callers
            out = out.unsqueeze(1)
            return out
704779

705780
class ResNetNoise(nn.Module):
706-
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1):
781+
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=2):
707782
super(ResNetNoise, self).__init__()
708783
# First layer
709784

@@ -728,7 +803,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
728803
#optional renomralization layer
729804
self.is_constraints = False
730805
if constraints == 'softmax':
731-
self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor)
806+
self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
732807
self.is_constraints = True
733808
elif constraints == 'enforce_op':
734809
self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)
@@ -817,7 +892,7 @@ def forward(self, x):
817892

818893

819894
class ResNet2Up(nn.Module):
820-
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, output_mr=False):
895+
def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, output_mr=False, cwindow_size= 2):
821896
super(ResNet2Up, self).__init__()
822897
#PART I
823898
self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
@@ -833,7 +908,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
833908

834909
self.is_constraints = False
835910
if constraints == 'softmax':
836-
self.constraints = SoftmaxConstraints(upsampling_factor=2)
911+
self.constraints = SoftmaxConstraints(upsampling_factor=2, cwindow_size=cwindow_size)
837912
self.is_constraints = True
838913
elif constraints == 'enforce_op':
839914
self.constraints = EnforcementOperator(upsampling_factor=2)
@@ -844,7 +919,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
844919

845920
self.is_single_constraints = False
846921
if constraints == 'softmax_single':
847-
self.constraints = SoftmaxConstraints(upsampling_factor=4)
922+
self.constraints = SoftmaxConstraints(upsampling_factor=4, cwindow_size=cwindow_size)
848923
self.is_single_constraints = True
849924
elif constraints == 'enforce_op_single':
850925
self.constraints = EnforcementOperator(upsampling_factor=4)
@@ -865,7 +940,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
865940
self.conv23 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
866941
self.conv24 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0)
867942
if constraints == 'softmax':
868-
self.constraints2 = SoftmaxConstraints(upsampling_factor=2)
943+
self.constraints2 = SoftmaxConstraints(upsampling_factor=2, cwindow_size=cwindow_size)
869944
elif constraints == 'enforce_op':
870945
self.constraints2 = EnforcementOperator(upsampling_factor=2)
871946
elif constraints == 'mult':

scoring.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import argparse
88
from torch.utils.data import DataLoader, TensorDataset
99
import csv
10+
from torchmetrics.functional import multiscale_structural_similarity_index_measure, structural_similarity_index_measure
11+
1012

1113
def add_arguments():
1214
parser = argparse.ArgumentParser()
@@ -20,28 +22,40 @@ def add_arguments():
2022
parser.add_argument("--mass_violation", type=bool, default=True)
2123
parser.add_argument("--factor", type=int, default=4)
2224
parser.add_argument("--time_sr", default=False)
25+
2326
#args for model loading
2427
return parser.parse_args()
2528

2629
def main_scoring(args):
2730
#n = 24
2831
input_val = torch.load('./data/test/'+args.dataset+'/input_test.pt')
2932
target_val = torch.load('./data/test/'+args.dataset+'/target_test.pt')
33+
#target_val = torch.load('./data/test/'+args.dataset+'/target_test.pt')
3034
#target_val = torch.load('./data/test/dataset28/target_test.pt')
3135
val_data = TensorDataset(input_val, target_val)
3236
pred = np.zeros(target_val.shape)
37+
max_val = target_val.max()
38+
min_val = target_val.min()
3339
print(pred.shape)
3440
factor = args.factor
3541
mse = 0
3642
mae = 0
3743
ssim = 0
44+
mean_bias = 0
3845
mass_violation = 0
46+
ms_ssim = 0
47+
corr = 0
48+
crps = 0
49+
3950
l2_crit = nn.MSELoss()
4051
l1_crit = nn.L1Loss()
41-
ssim_criterion = tgm.losses.SSIM(window_size=11, max_val=130.83, reduction='mean')
52+
#ssim_criterion = StructuralSimilarityIndexMeasure() #tgm.losses.SSIM(window_size=11, max_val=max_val, reduction='mean')
4253
if args.nn:
43-
44-
pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+'_test.pt')
54+
if args.ensemble:
55+
en_pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+'_test_ensemble.pt')
56+
pred = torch.mean(en_pred, dim=1)
57+
else:
58+
pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+'_test.pt')
4559
pred = pred.detach().cpu().numpy()
4660
print(pred.shape)
4761
for i,(lr, hr) in enumerate(val_data):
@@ -71,7 +85,13 @@ def main_scoring(args):
7185
print(i, mse_loss)'''
7286
mse += l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
7387
mae += l1_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
74-
ssim += ssim_criterion(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...]).item()
88+
mean_bias += torch.mean( hr[j,...]-torch.abs(torch.Tensor(pred[i,j,...])))
89+
corr += pearsonr(torch.Tensor(pred[i,j,...]).flatten(), hr[j,...].flatten())
90+
ms_ssim += multiscale_structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...], data_range=max_val-min_val, kernel_size=11, betas=(0.2856, 0.3001, 0.2363))#0.0448, 0.2856, 0.3001, 0.2363))
91+
ssim += structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...] , data_range=max_val-min_val, kernel_size=11)#ssim_criterion(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...]).item()
92+
if args.ensemble:
93+
crps_ens = crps_ensemble(hr[j,0,0,...].numpy(), np.swapaxis(np.swapaxis(pred[i,:,j,0,0,...], 0,1),1,2))
94+
crps += np.mean(crps_ens)
7595
if args.mass_violation:
7696
if args.time_sr:
7797
if j==0:
@@ -105,24 +125,81 @@ def main_scoring(args):
105125
mse *= 1/(input_val.shape[0]*args.time_steps)
106126
mae *= 1/(input_val.shape[0]*args.time_steps)
107127
ssim *= 1/(input_val.shape[0]*args.time_steps)
128+
mean_bias *= 1/(input_val.shape[0]*args.time_steps)
129+
corr *= 1/(input_val.shape[0]*args.time_steps)
130+
ms_ssim *= 1/(input_val.shape[0]*args.time_steps)
131+
108132
if args.mass_violation:
109133
if args.time_sr:
110134
mass_violation *= 1/(input_val.shape[0]*args.time_steps)
111135
else:
112-
mass_violation *= 1/(input_val.shape[0]*2)
136+
mass_violation *= 1/(input_val.shape[0]) #what is the 2 doing here?
113137
else:
114138
mse *= 1/input_val.shape[0]
115139
mae *= 1/input_val.shape[0]
116140
ssim *= 1/input_val.shape[0]
117141
psnr = calculate_pnsr(mse, target_val.max() )
118-
scores = {'MSE':mse, 'RMSE':torch.sqrt(torch.Tensor([mse])), 'PSNR': psnr, 'MAE':mae, 'SSIM':1-ssim, 'Mass_violation': mass_violation}
142+
scores = {'MSE':mse, 'RMSE':torch.sqrt(torch.Tensor([mse])), 'PSNR': psnr, 'MAE':mae, 'SSIM':ssim, 'Mass_violation': mass_violation, 'Mean bias': mean_bias, 'MS SSIM': ms_ssim, 'Pearson corr': corr, 'CRPS': crps}
119143
print(scores)
120144
create_report(scores, args)
121145
#np.save('./data/prediction/bic.npy', pred)
122146

123147

124148
def calculate_pnsr(mse, max_val):
    """Peak signal-to-noise ratio (dB) from an MSE and the data's peak value."""
    rmse = torch.sqrt(torch.Tensor([mse]))
    return 20 * torch.log10(max_val / rmse)
150+
def pearsonr(x, y):
    """
    Pearson correlation coefficient of two 1D tensors.

    Mimics `scipy.stats.pearsonr` (value only, no p-value).

    Arguments
    ---------
    x : 1D torch.Tensor
    y : 1D torch.Tensor

    Returns
    -------
    r_val : float
        pearsonr correlation coefficient between x and y

    Scipy docs ref:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html

    Example:
        >>> x = np.random.randn(100)
        >>> y = np.random.randn(100)
        >>> sp_corr = scipy.stats.pearsonr(x, y)[0]
        >>> th_corr = pearsonr(torch.from_numpy(x), torch.from_numpy(y))
        >>> np.allclose(sp_corr, th_corr)
    """
    # center both series
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)
    # covariance term over the product of the L2 norms
    numerator = x_centered.dot(y_centered)
    denominator = torch.norm(x_centered, 2) * torch.norm(y_centered, 2)
    return numerator / denominator
183+
def crps_ensemble(observation, forecasts):
    """Empirical CRPS of an ensemble forecast against observations.

    observation: array of observed values; forecasts: same shape plus a
    trailing ensemble-member axis. Returns an array shaped like observation.
    """
    # sorted copy of the ensemble along the member axis
    members = np.sort(forecasts, axis=-1)
    obs = observation
    below = members < obs[..., None]
    m = members.shape[-1]
    result = np.zeros_like(obs)

    # members strictly below the observation, weight grows with rank
    for i in range(m):
        mask = below[..., i]
        w = ((i + 1) ** 2 - i ** 2) / m ** 2
        result[mask] += w * (obs[mask] - members[..., i][mask])

    # members at/above the observation, weighted from the top down
    for i in range(m - 1, -1, -1):
        mask = ~below[..., i]
        k = m - 1 - i
        w = ((k + 1) ** 2 - k ** 2) / m ** 2
        result[mask] += w * (members[..., i][mask] - obs[mask])

    return result
126203

127204
def create_report(scores, args):
128205
args_dict = args_to_dict(args)

0 commit comments

Comments
 (0)