
Commit 8640e51: corrections

1 parent cb9815f

File tree: 6 files changed, +143 -47 lines


README.md (14 additions, 13 deletions)

````diff
@@ -4,6 +4,16 @@ This code belongs to a paper currently under review, a preprint can be found at:
 
 Abstract: *The availability of reliable, high-resolution climate and weather data is important to inform long-term decisions on climate adaptation and mitigation and to guide rapid responses to extreme events. Forecasting models are limited by computational costs and, therefore, often generate coarse-resolution predictions. Statistical downscaling can provide an efficient method of upsampling low-resolution data. In this field, deep learning has been applied successfully, often using image super-resolution methods from computer vision. However, despite achieving visually compelling results in some cases, such models frequently violate conservation laws when predicting physical variables. In order to conserve physical quantities, we develop methods that guarantee physical constraints are satisfied by a deep learning downscaling model while also improving their performance according to traditional metrics. We compare different constraining approaches and demonstrate their applicability across different neural architectures as well as a variety of climate and weather data sets. While our novel methodologies enable faster and more accurate climate predictions, we also show how they can improve super-resolution for satellite data and standard data sets.*
 
+## Setup
+
+Clone the repository and install the requirements
+```sh
+$ git clone https://github.com/RolnickLab/constrained-downscaling.git
+$ cd constrained-downscaling
+$ conda env create -f requirements.yml
+$ conda activate constrained-ds
+```
+
 ## Get the data
 
 One of our main data sets, ERA5 total column water, 4x upsampling, can be downloaded in ML-ready form at: https://drive.google.com/file/d/1IENhP1-aTYyqOkRcnmCIvxXkvUW2Qbdx/view?usp=sharing
@@ -22,38 +32,29 @@ $ rm era5_sr_data.zip
 
 Other data sets are available upon request from the author or can be generated from public sources for ERA5 (https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-single-levels?tab=form) and NorESM (https://esgf-index1.ceda.ac.uk/search/cmip6-ceda/) data.
 
-## Setup
-
-Clone the repository and install the requirements
-```sh
-$ git clone https://github.com/RolnickLab/constrained-downscaling.git
-$ cd constrained-downscaling
-$ conda env create -f requirements.yml
-$ conda activate constrained_ds
-```
 
 ## Run training
 
 To run our standard CNN without constraints, run
 
 ```sh
-$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_noconstraints --constraints none
+$ python main.py --dataset era5_sr_data --model cnn --model_id twc_cnn_noconstraints --constraints none
 ```
 
 To run with softmax constraining (hard constraining), run
 
 ```sh
-$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_softmaxconstraints --constraints softmax
+$ python main.py --dataset era5_sr_data --model cnn --model_id twc_cnn_softmaxconstraints --constraints softmax
 ```
 
 To run with soft constraining with a factor alpha, run
 
 ```sh
-$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_softconstraints --constraints soft --loss mass_constraints --alpha 0.99
+$ python main.py --dataset era5_sr_data --model cnn --model_id twc_cnn_softconstraints --constraints soft --loss mass_constraints --alpha 0.99
 ```
 
 For other setups:
---model can be either cnn, gan, convgru, flowconvgru
+--model can be either cnn, gan, convgru, flowconvgru (last two require different data sets)
 --constraints can be none, softmax, gh, mult, add, soft
 Other arguments are --epochs, --lr (learning rate), --number_residual_blocks, --weight_decay
````

main.py (6 additions, 0 deletions)

```diff
@@ -2,6 +2,7 @@
 from utils import load_data
 import numpy as np
 import argparse
+import os
 import torch
 
 def add_arguments():
@@ -22,10 +23,15 @@ def add_arguments():
     parser.add_argument("--alpha", default=0.99, type=float)
     parser.add_argument("--test_val_train", default="val")
     parser.add_argument("--training_evalonly", default="training")
+    parser.add_argument("--dim_channels", default=1, type=int)
     return parser.parse_args()
 
 def main(args):
     #load data
+    if not os.path.exists('./models'):
+        os.makedirs('./models')
+    if not os.path.exists('./data/prediction'):
+        os.makedirs('./data/prediction')
     if args.training_evalonly == 'training':
         data = load_data(args)
         #run training
```
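
A brief aside on the directory bootstrap added here: since Python 3.2, `os.makedirs` takes an `exist_ok` flag, so the explicit `os.path.exists` guards could be collapsed. A minimal equivalent sketch:

```python
import os

# Same bootstrap as the commit adds to main(), one call per directory;
# exist_ok=True makes makedirs a no-op if the directory already exists.
os.makedirs('./models', exist_ok=True)
os.makedirs('./data/prediction', exist_ok=True)
```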

models.py (4 additions, 7 deletions)

```diff
@@ -79,18 +79,15 @@ def forward(self, y, lr):
         return out
 
 class SoftmaxConstraints(nn.Module):
-    def __init__(self, upsampling_factor, cwindow_size, exp_factor=1):
+    def __init__(self, upsampling_factor, exp_factor=1):
         super(SoftmaxConstraints, self).__init__()
-        self.pool = torch.nn.AvgPool2d(kernel_size=cwindow_size)
-        self.lr_pool = torch.nn.AvgPool2d(kernel_size=int(cwindow_size/upsampling_factor))
         self.upsampling_factor = upsampling_factor
-        self.cwindow_size = cwindow_size
         self.exp_factor = exp_factor
     def forward(self, y, lr):
         y = torch.exp(y*self.exp_factor)
         sum_y = self.pool(y)
         lr_sum = self.lr_pool(lr)
-        out = y*torch.kron(lr_sum*1/sum_y, torch.ones((self.cwindow_size,self.cwindow_size)).to('cuda'))
+        out = y*torch.kron(lr_sum*1/sum_y, torch.ones((self.upsampling_factor,self.upsampling_factor)).to('cuda'))
         return out
 
 
@@ -110,7 +107,7 @@ def forward(self, y):
 
 
 class ResNet(nn.Module):
-    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=4):
+    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1):
         super(ResNet, self).__init__()
         # First layer
         if noise:
@@ -136,7 +133,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
         #optional renormalization layer
         self.is_constraints = False
         if constraints == 'softmax':
-            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
+            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor)
             self.is_constraints = True
         elif constraints == 'enforce_op':
             self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)
```
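
Note that after this hunk `forward` still calls `self.pool` and `self.lr_pool`, whose `cwindow_size` definitions are deleted above; presumably both are now derived from the upsampling factor alone. A minimal self-contained sketch of the constraint layer under that assumption (the pool definition below is a reconstruction, not part of this commit):

```python
import torch
import torch.nn as nn

class SoftmaxConstraintsSketch(nn.Module):
    """Sketch of the softmax (hard) constraint after this commit.

    Assumption: the averaging window now equals the upsampling factor,
    so the low-res field lr needs no pooling of its own.
    """
    def __init__(self, upsampling_factor, exp_factor=1):
        super().__init__()
        self.upsampling_factor = upsampling_factor
        self.exp_factor = exp_factor
        # assumed replacement for the deleted cwindow_size pools
        self.pool = nn.AvgPool2d(kernel_size=upsampling_factor)

    def forward(self, y, lr):
        # exponentiate so every output cell is positive
        y = torch.exp(y * self.exp_factor)
        # mean of each upsampling_factor x upsampling_factor output window
        window_mean = self.pool(y)
        # rescale each window so its mean equals the low-res cell;
        # kron broadcasts the per-window ratio back to full resolution
        ones = torch.ones(
            (self.upsampling_factor, self.upsampling_factor), device=y.device
        )
        return y * torch.kron(lr / window_mean, ones)
```

Average-pooling the returned field with the same kernel reproduces `lr` up to floating-point error, which is the conservation property that the softmax (hard) constraining option in the README refers to.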

requirements.yml (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-name: condtrained_ds
+name: constrained-ds
 channels:
 - pytorch
 - nvidia
```

training.py (108 additions, 15 deletions)

```diff
@@ -1,4 +1,4 @@
-from utils import process_for_training, is_gan, is_noisegan, load_model, get_optimizer, get_criterion, process_for_eval, get_loss, load_data
+from utils import process_for_training, is_gan, load_model, get_optimizer, get_criterion, process_for_eval, get_loss, load_data
 import models
 import numpy as np
 from tqdm import tqdm
@@ -7,7 +7,9 @@
 import torchgeometry as tgm
 import csv
 import numpy as np
-from scoring import main_scoring
+from torch.utils.data import DataLoader, TensorDataset
+from torchmetrics.functional import multiscale_structural_similarity_index_measure, structural_similarity_index_measure
+from skimage import transform
 device = 'cuda'
 
 def run_training(args, data):
@@ -27,7 +29,6 @@ def run_training(args, data):
         running_loss = 0
         running_discr_loss = 0
         running_adv_loss = 0
-        running_mass_loss = 0
         for (inputs, targets) in data[0]:
             inputs, targets = process_for_training(inputs, targets)
             if is_gan(args):
@@ -37,7 +38,6 @@ def run_training(args, data):
             else:
                 loss = optimizer_step(model, optimizer, criterion, inputs, targets, data[0], args)
             running_loss += loss
-            running_mass_loss += mass_loss
         loss = running_loss/len(data[0])
         if is_gan(args):
             dicsr_loss = running_discr_loss/len(data)
@@ -51,11 +51,10 @@ def run_training(args, data):
             val_loss = validate_model(model, criterion, data[1], best, epoch, args, discriminator_model, criterion_discr)
         else:
             val_loss = validate_model(model, criterion, data[1], best, epoch, args)
-        val_losses.append(val_loss)
         print('Val loss: {:.5f}'.format(val_loss))
         checkpoint(model, val_loss, best, args, epoch)
         best = np.minimum(best, val_loss)
-    data = load_data(args.test_val_train, args)
+    data = load_data(args)
     scores = evaluate_model( data, args)
 
 
@@ -110,7 +109,7 @@ def validate_model(model, criterion, data, best, epoch, args, discriminator_mode
                 adversarial_loss = criterion_discr(fake_output.detach(), real_label)
                 loss += args.adv_factor * adversarial_loss
             else:
-                outputs = model(inputs)
+                outputs = model(inputs)
                 loss = get_loss(outputs, targets, inputs, args)
             running_loss += loss.item()
     loss = running_loss/len(data)
@@ -120,12 +119,11 @@ def validate_model(model, criterion, data, best, epoch, args, discriminator_mode
 Tensor = torch.cuda.FloatTensor
 
 def checkpoint(model, val_loss, best, args, epoch):
-    print(val_loss, best)
     if val_loss < best:
         checkpoint = {'model': model,'state_dict': model.state_dict()}
         torch.save(checkpoint, './models/'+args.model_id+'.pth')
 
-def evaluate_model(data, args, add_string=None):
+def evaluate_model(data, args):
     model = load_model(args)
     load_weights(model, args.model_id)
     model.eval()
@@ -149,38 +147,133 @@ def evaluate_model(data, args):
     else:
         torch.save(full_pred, './data/prediction/'+args.dataset+'_'+args.model_id+ '_' + args.test_val_train+'.pt')
     calculate_scores(args)
+
+def calculate_scores(args):
+    input_val = torch.load('./data/'+args.dataset+'/'+ args.test_val_train+'/input_'+ args.test_val_train+'.pt')
+    target_val = torch.load('./data/'+args.dataset+'/'+ args.test_val_train+'/target_'+ args.test_val_train+'.pt')
+    val_data = TensorDataset(input_val, target_val)
+    pred = np.zeros(target_val.shape)
+    max_val = target_val.max()
+    min_val = target_val.min()
+    mse = 0
+    mae = 0
+    ssim = 0
+    mean_bias = 0
+    mean_abs_bias = 0
+    mass_violation = 0
+    ms_ssim = 0
+    corr = 0
+    crps = 0
+    neg_mean = 0
+    neg_num = 0
+
+    l2_crit = nn.MSELoss()
+    l1_crit = nn.L1Loss()
+
+    if args.model == 'gan':
+        en_pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+ '_' + args.test_val_train+'_ensemble.pt')
+        pred = torch.mean(en_pred, dim=1)
+        en_pred = en_pred.detach().cpu().numpy()
+    else:
+        pred = torch.load('./data/prediction/'+args.dataset+'_'+args.model_id+ '_' + args.test_val_train+'.pt')
+
+    pred = pred.detach().cpu().numpy()
+    j = 0
+    for i,(lr, hr) in enumerate(val_data):
+        im = lr.numpy()
+        mse += l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
+        mae += l1_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()
+        mean_bias += torch.mean( hr[j,...]-torch.Tensor(pred[i,j,...]))
+        mean_abs_bias += torch.abs(torch.mean( hr[j,...]-torch.Tensor(pred[i,j,...])))
+        corr += pearsonr(torch.Tensor(pred[i,j,...]).flatten(), hr[j,...].flatten())
+        ms_ssim += multiscale_structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...], data_range=max_val-min_val, kernel_size=11, betas=(0.2856, 0.3001, 0.2363))
+        ssim += structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...], data_range=max_val-min_val, kernel_size=11)
+        neg_num += np.sum(pred[i,j,...] < 0)
+        neg_mean += np.sum(pred[pred < 0])/(pred.shape[-1]*pred.shape[-1])
+        if args.model == 'gan':
+            crps_ens = crps_ensemble(hr[j,0,...].numpy(), en_pred[i,:,j,0,...])
+            crps += crps_ens
+
+        mass_violation += np.mean( np.abs(transform.downscale_local_mean(pred[i,j,...], (1,args.upsampling_factor,args.upsampling_factor)) -im[j,...]))
+
+    mse *= 1/input_val.shape[0]
+    mae *= 1/input_val.shape[0]
+    ssim *= 1/input_val.shape[0]
+    mean_bias *= 1/input_val.shape[0]
+    mean_abs_bias *= 1/input_val.shape[0]
+    corr *= 1/input_val.shape[0]
+    ms_ssim *= 1/input_val.shape[0]
+    crps *= 1/input_val.shape[0]
+    neg_mean *= 1/input_val.shape[0]
+    mass_violation *= 1/input_val.shape[0]
+    psnr = calculate_pnsr(mse, target_val.max() )
+    rmse = torch.sqrt(torch.Tensor([mse])).numpy()[0]
+    ssim = float(ssim.numpy())
+    ms_ssim = float(ms_ssim.numpy())
+    psnr = psnr.numpy()
+    corr = float(corr.numpy())
+    mean_bias = float(mean_bias.numpy())
+    mean_abs_bias = float(mean_abs_bias.numpy())
+    scores = {'MSE':mse, 'RMSE':rmse, 'PSNR': psnr[0], 'MAE':mae, 'SSIM':ssim, 'MS SSIM': ms_ssim, 'Pearson corr': corr, 'Mean bias': mean_bias, 'Mean abs bias': mean_abs_bias, 'Mass_violation': mass_violation, 'neg mean': neg_mean, 'neg num': neg_num,'CRPS': crps}
+    print(scores)
+    create_report(scores, args)
 
 
 def calculate_pnsr(mse, max_val):
     return 20 * torch.log10(max_val / torch.sqrt(torch.Tensor([mse])))
 
-def create_report(scores, args, add_string=None):
+def create_report(scores, args):
     args_dict = args_to_dict(args)
     #combine scores and args dict
     args_scores_dict = args_dict | scores
     #save dict
-    save_dict(args_scores_dict, args, add_string)
+    save_dict(args_scores_dict, args)
 
 def args_to_dict(args):
     return vars(args)
 
 
 def save_dict(dictionary, args):
-
-    w = csv.writer(open('./data/score_log/'+args.model_id+'.csv', 'w'))
-
+    w = csv.writer(open('./data/score_log/'+args.model_id+'.csv', 'w'))
     # loop over dictionary keys and values
     for key, val in dictionary.items():
         # write every key and value to file
         w.writerow([key, val])
 
 def load_weights(model, model_id):
-    PATH = '/home/harder/constraint_generative_ml/models/'+model_id+'.pth'
+    PATH = '/home/harder/constrained-downscaling/models/'+model_id+'.pth'
     checkpoint = torch.load(PATH) # ie, model_best.pth.tar
     model.load_state_dict(checkpoint['state_dict'])
     model.to('cuda')
     return model
 
+def pearsonr(x, y):
+    mean_x = torch.mean(x)
+    mean_y = torch.mean(y)
+    xm = x.sub(mean_x)
+    ym = y.sub(mean_y)
+    r_num = xm.dot(ym)
+    r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
+    r_val = r_num / r_den
+    return r_val
+
+def crps_ensemble(observation, forecasts):
+    fc = forecasts.copy()
+    fc.sort(axis=0)
+    obs = observation
+    fc_below = fc<obs[None,...]
+    crps = np.zeros_like(obs)
+    for i in range(fc.shape[0]):
+        below = fc_below[i,...]
+        weight = ((i+1)**2 - i**2) / fc.shape[0]**2
+        crps[below] += weight * (obs[below]-fc[i,...][below])
+
+    for i in range(fc.shape[0]-1,-1,-1):
+        above = ~fc_below[i,...]
+        k = fc.shape[0]-1-i
+        weight = ((k+1)**2 - k**2) / fc.shape[0]**2
+        crps[above] += weight * (fc[i,...][above]-obs[above])
+    return np.mean(crps)
 
 
 
```
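
Two of the additions above are easy to sanity-check in isolation. First, the mass-violation metric in `calculate_scores`: pool the prediction back to the input grid and take the mean absolute gap. A standalone sketch with illustrative shapes (the sizes are assumptions, not taken from the repository's data files):

```python
import numpy as np
from skimage import transform

upsampling_factor = 4
pred = np.random.rand(1, 128, 128)    # (channel, H, W) high-res prediction
lr_input = np.random.rand(1, 32, 32)  # matching low-res input

# downscale_local_mean averages each (1, 4, 4) block, inverting the
# upsampling; a perfectly mass-conserving prediction yields 0.
coarse = transform.downscale_local_mean(pred, (1, upsampling_factor, upsampling_factor))
mass_violation = np.mean(np.abs(coarse - lr_input))
print(f"mass violation: {mass_violation:.4f}")
```

Second, the new `pearsonr` and `crps_ensemble` helpers are self-contained; a small smoke test on random data (8 ensemble members and 32x32 fields are illustrative choices):

```python
import numpy as np
import torch

x = torch.rand(1024)
y = x + 0.1 * torch.randn(1024)
print(f"Pearson r: {pearsonr(x, y).item():.3f}")  # close to 1 for correlated inputs

obs = np.random.rand(32, 32)      # one observed high-res field
ens = np.random.rand(8, 32, 32)   # 8 ensemble forecasts of that field
print(f"CRPS: {crps_ensemble(obs, ens):.4f}")  # lower is better
```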
utils.py (10 additions, 11 deletions)

```diff
@@ -2,20 +2,18 @@
 import torch.optim as optim
 import torch.nn as nn
 import models
-import Learnable_basis
 from torch.utils.data import DataLoader, TensorDataset
 device = 'cuda'
 
 def load_data(args):
-    input_train = torch.load('./data/train/input_train.pt')
-    target_train = torch.load('./data/train/target_train.pt')
-
+    input_train = torch.load('./data/'+args.dataset+'/train/input_train.pt')
+    target_train = torch.load('./data/'+args.dataset+'/train/target_train.pt')
     if args.test_val_train == 'test':
-        input_val = torch.load('./data/test/input_test.pt')
-        target_val = torch.load('./data/test/target_test.pt')
+        input_val = torch.load('./data/'+args.dataset+'/test/input_test.pt')
+        target_val = torch.load('./data/'+args.dataset+'/test/target_test.pt')
     elif args.test_val_train == 'val':
-        input_val = torch.load('./data/val/input_val.pt')
-        target_val = torch.load('./data/val/target_val.pt')
+        input_val = torch.load('./data/'+args.dataset+'/val/input_val.pt')
+        target_val = torch.load('./data/'+args.dataset+'/val/target_val.pt')
     elif args.test_val_train == 'train':
         input_val = input_train
         target_val = target_train
@@ -59,9 +57,10 @@ def load_model(args, discriminator=False):
     elif args.model == 'flowconvgru':
         model = models.TimeEndToEndModel( number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, time_steps=3, constraints=args.constraints)
     elif args.model == 'gan':
-        model = models.ResNet(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=args.dim_channels)
+        model = models.ResNet(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=(args.model=='gan'), constraints=args.constraints, dim=args.dim_channels)
     elif args.model == 'cnn':
-        model = models.ResNet(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=args.noise, constraints=args.constraints, dim=args.dim_channels, cwindow_size= args.constraints_window_size)
+        model = models.ResNet(number_channels=args.number_channels, number_residual_blocks=args.number_residual_blocks, upsampling_factor=args.upsampling_factor, noise=(args.model=='gan'), constraints=args.constraints, dim=args.dim_channels)
+    model = model.to(device)
     return model
 
 def get_optimizer(args, model):
@@ -91,7 +90,7 @@ def process_for_training(inputs, targets):
     return inputs, targets
 
 def process_for_eval(outputs, targets, mean, std, max_val, args):
-    if args.gan:
+    if args.model == 'gan':
         outputs[:,:,0,0,...] = outputs[:,0,0,...]*(max_val[0].to(device)-min_val[0].to(device))+min_val[0].to(device)
         targets[:,0,0,...] = targets[:,0,0,...]*(max_val[0].to(device)-min_val[0].to(device))+min_val[0].to(device)
     else:
```
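
The `load_data` refactor above makes every split dataset-relative. A hypothetical helper that mirrors the concatenated paths (the directory layout comes from the diff; the function itself is illustrative, not in the repository):

```python
import os

# Layout implied by the new torch.load calls, e.g. for --dataset era5_sr_data:
#   ./data/era5_sr_data/train/input_train.pt
#   ./data/era5_sr_data/val/input_val.pt
#   ./data/era5_sr_data/test/target_test.pt
def data_path(dataset, split, kind):
    assert split in ('train', 'val', 'test') and kind in ('input', 'target')
    return os.path.join('./data', dataset, split, f'{kind}_{split}.pt')

print(data_path('era5_sr_data', 'val', 'input'))  # ./data/era5_sr_data/val/input_val.pt
```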
