1+ from skimage import transform
2+ import numpy as np
3+ import torch
4+ from PIL import Image
5+ import torch .nn as nn
6+ import torchgeometry as tgm
7+ import argparse
8+ from torch .utils .data import DataLoader , TensorDataset
9+ import csv
10+
11+ def add_arguments ():
12+ parser = argparse .ArgumentParser ()
13+ parser .add_argument ("--dataset" , default = 'dataset34' , help = "choose a data set to use" )
14+ parser .add_argument ("--model" , default = 'resnet2_ood_softmax_test' )
15+ parser .add_argument ("--model_id" , default = 'dataset34_resnet2_ood_softmax_test_test' )
16+ parser .add_argument ("--time" , default = True )
17+ parser .add_argument ("--nn" , default = True )
18+ parser .add_argument ("--test" , default = True )
19+ parser .add_argument ("--time_steps" , type = int , default = 1 )
20+ parser .add_argument ("--mass_violation" , type = bool , default = True )
21+ parser .add_argument ("--factor" , type = int , default = 4 )
22+ parser .add_argument ("--time_sr" , default = False )
23+ #args for model loading
24+ return parser .parse_args ()
25+
26+ def main_scoring (args ):
27+ #n = 24
28+ input_val = torch .load ('./data/test/' + args .dataset + '/input_test.pt' )
29+ target_val = torch .load ('./data/test/' + args .dataset + '/target_test.pt' )
30+ #target_val = torch.load('./data/test/dataset28/target_test.pt')
31+ val_data = TensorDataset (input_val , target_val )
32+ pred = np .zeros (target_val .shape )
33+ print (pred .shape )
34+ factor = args .factor
35+ mse = 0
36+ mae = 0
37+ ssim = 0
38+ mass_violation = 0
39+ l2_crit = nn .MSELoss ()
40+ l1_crit = nn .L1Loss ()
41+ ssim_criterion = tgm .losses .SSIM (window_size = 11 , max_val = 130.83 , reduction = 'mean' )
42+ if args .nn :
43+
44+ pred = torch .load ('./data/prediction/' + args .dataset + '_' + args .model_id + '_test.pt' )
45+ pred = pred .detach ().cpu ().numpy ()
46+ print (pred .shape )
47+ for i ,(lr , hr ) in enumerate (val_data ):
48+ im = lr .numpy ()
49+ if args .time :
50+ #print(hr.shape)
51+ for j in range (args .time_steps ):
52+ if args .model == 'bilinear' :
53+ pred [i ,j ,0 ,:,:] = np .array (Image .fromarray (im [j ,0 ,...]).resize ((4 * lr .shape [2 ],4 * lr .shape [2 ]), Image .BILINEAR ))
54+ elif args .model == 'bicubic' :
55+ pred [i ,j ,0 ,:,:] = np .array (Image .fromarray (im [j ,0 ,...]).resize ((factor * lr .shape [2 ],factor * lr .shape [2 ]), Image .BICUBIC ))
56+ elif args .model == 'bicubic_frame' :
57+ if j == 0 :
58+ pred [i ,j ,0 ,:,:] = np .array (Image .fromarray (im [0 ,0 ,...]).resize ((4 * lr .shape [2 ],4 * lr .shape [2 ]), Image .BICUBIC ))
59+ elif j == 2 :
60+ pred [i ,j ,0 ,:,:] = np .array (Image .fromarray (im [1 ,0 ,...]).resize ((4 * lr .shape [2 ],4 * lr .shape [2 ]), Image .BICUBIC ))
61+ else :
62+ pred [i ,j ,0 ,:,:] = np .array (Image .fromarray (0.5 * (im [1 ,0 ,...]+ im [0 ,0 ,...])).resize ((4 * lr .shape [2 ],4 * lr .shape [2 ]), Image .BICUBIC ))
63+ elif args .model == 'kronecker' :
64+ pred [i ,j ,0 ,:,:] = np .kron (im [j ,0 ,...], np .ones ((4 ,4 )))
65+ elif args .model == 'frame_inter' :
66+ print (pred .shape , im .shape )
67+ pred [i ,j ,0 ,... ] = 0.5 * (im [0 ,0 ,...]+ im [1 ,0 ,...])
68+ '''
69+ mse_loss = l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
70+
71+ print(i, mse_loss)'''
72+ mse += l2_crit (torch .Tensor (pred [i ,j ,...]), hr [j ,...]).item ()
73+ mae += l1_crit (torch .Tensor (pred [i ,j ,...]), hr [j ,...]).item ()
74+ ssim += ssim_criterion (torch .Tensor (pred [i ,j :j + 1 ,...]), hr [j :j + 1 ,...]).item ()
75+ if args .mass_violation :
76+ if args .time_sr :
77+ if j == 0 :
78+ mass_violation += np .mean ( np .abs (transform .downscale_local_mean (pred [i ,j ,...], (1 ,args .factor ,args .factor )) - im [0 ,...]))
79+ elif j == 2 :
80+ mass_violation += np .mean ( np .abs (transform .downscale_local_mean (pred [i ,j ,...], (1 ,args .factor ,args .factor )) - im [1 ,...]))
81+ else :
82+ mass_violation += np .mean ( np .abs (transform .downscale_local_mean (pred [i ,j ,...], (1 ,args .factor ,args .factor )) - im [j ,...]))
83+ #print(l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item())
84+ elif args .model == 'frame_inter' :
85+ pred [i ,... ] = 0.5 * (im [0 :1 ,...]+ im [1 :2 ,...])
86+ mse += l2_crit (torch .Tensor (pred [i ,...]), hr ).item ()
87+ mae += l1_crit (torch .Tensor (pred [i ,...]), hr ).item ()
88+ ssim += ssim_criterion (torch .Tensor (pred [i ,...]), hr ).item ()
89+ #print(l1_crit(torch.Tensor(pred[i,...]), hr).item())
90+ else :
91+ if args .model == 'bilinear' :
92+ pred [i ,:,:] = np .array (Image .fromarray (im ).resize ((4 * lr .shape [1 ],4 * lr .shape [1 ]), Image .BILINEAR ))
93+ elif args .model == 'bicubic' :
94+ pred [i ,:,:] = np .array (Image .fromarray (im ).resize ((4 * lr .shape [1 ],4 * lr .shape [1 ]), Image .BICUBIC ))
95+ #elif args.model == 'kronecker':
96+
97+ mse += l2_crit (torch .Tensor (pred [i ,:,:]), hr ).item ()
98+ mae += l1_crit (torch .Tensor (pred [i ,:,:]), hr ).item ()
99+ ssim += ssim_criterion (torch .Tensor (pred [i ,:,:]).unsqueeze (0 ), hr .unsqueeze (0 )).item ()
100+
101+
102+ #torch.save(torch.Tensor(pred[:128,:,:]).unsqueeze(1), './data/prediction/'+args.dataset+'_'+args.model_id+'_prediction.pt')
103+ if args .time :
104+ print (input_val .shape [0 ])
105+ mse *= 1 / (input_val .shape [0 ]* args .time_steps )
106+ mae *= 1 / (input_val .shape [0 ]* args .time_steps )
107+ ssim *= 1 / (input_val .shape [0 ]* args .time_steps )
108+ if args .mass_violation :
109+ if args .time_sr :
110+ mass_violation *= 1 / (input_val .shape [0 ]* args .time_steps )
111+ else :
112+ mass_violation *= 1 / (input_val .shape [0 ]* 2 )
113+ else :
114+ mse *= 1 / input_val .shape [0 ]
115+ mae *= 1 / input_val .shape [0 ]
116+ ssim *= 1 / input_val .shape [0 ]
117+ psnr = calculate_pnsr (mse , target_val .max () )
118+ scores = {'MSE' :mse , 'RMSE' :torch .sqrt (torch .Tensor ([mse ])), 'PSNR' : psnr , 'MAE' :mae , 'SSIM' :1 - ssim , 'Mass_violation' : mass_violation }
119+ print (scores )
120+ create_report (scores , args )
121+ #np.save('./data/prediction/bic.npy', pred)
122+
123+
124+ def calculate_pnsr (mse , max_val ):
125+ return 20 * torch .log10 (max_val / torch .sqrt (torch .Tensor ([mse ])))
126+
127+ def create_report (scores , args ):
128+ args_dict = args_to_dict (args )
129+ #combine scorees and args dict
130+ args_scores_dict = args_dict | scores
131+ #save dict
132+ save_dict (args_scores_dict , args )
133+
134+ def args_to_dict (args ):
135+ return vars (args )
136+
137+
138+ def save_dict (dictionary , args ):
139+ w = csv .writer (open ('./data/score_log/' + args .model_id + '.csv' , 'w' ))
140+ # loop over dictionary keys and values
141+ for key , val in dictionary .items ():
142+ # write every key and value to file
143+ w .writerow ([key , val ])
144+
145+ if __name__ == '__main__' :
146+ args = add_arguments ()
147+ main_scoring (args )
0 commit comments