Commit cb9815f

committed: added requirements
1 parent 048a461 commit cb9815f

9 files changed: +221 −497 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+./ipynb_checkpoints

README.md

Lines changed: 21 additions & 5 deletions
@@ -32,27 +32,41 @@ $ conda env create -f requirements.yml
 $ conda activate constrained_ds
 ```
 
-## Run training
+## Run training
 
 To run our standard CNN without constraints, run
 
 ```sh
-$ python main.py --dataset era5_twc --model_id twc_noconstraints --constraints none
+$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_noconstraints --constraints none
 ```
 
-to run with softmax constraining run
+To run with softmax constraining (hard constraining), run
 
 ```sh
-$ python main.py --dataset era5_twc --model_id twc_smconstraints --constraints softmax
+$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_softmaxconstraints --constraints softmax
 ```
 
+To run with soft constraining with a factor alpha, run
+
+```sh
+$ python main.py --dataset era5_twc --model cnn --model_id twc_cnn_softconstraints --constraints soft --loss mass_constraints --alpha 0.99
+```
+
+For other setups:
+--model can be either cnn, gan, convgru, or flowconvgru
+--constraints can be none, softmax, gh, mult, add, or soft
+other arguments are --epochs, --lr (learning rate), --number_residual_blocks, and --weight_decay
 
 ## Run inference
 
+An example evaluation for the unconstrained model:
+
 ```sh
-$ python evaluatee.py --dataset era5_twc --model_id twc_noconstraints --constraints none
+$ python main.py --training_evalonly evalonly --dataset era5_twc --model cnn --model_id twc_cnn_noconstraints --constraints none
 ```
 
+It produces a CSV file with all metrics on either the validation or the test set.
+
 ## Citation
 
 If you find this repository helpful please consider citing our work
@@ -63,5 +77,7 @@ If you find this repository helpful please consider to cite our work
 publisher = {arXiv},
 year = {2022}
 }
+
+
 
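The soft-constraint run added above trades off the data loss against a constraint penalty through `--alpha` (with `--loss mass_constraints`). The sketch below is only an illustration of how such a blended loss could look, assuming the penalty compares the prediction pooled back to the low-resolution grid against the low-resolution input; the function name, pooling choice, and weighting direction are assumptions, not taken from this commit.

```python
import torch.nn.functional as F


def soft_mass_constraint_loss(pred, target, lr_input, upsampling_factor=4, alpha=0.99):
    """Hypothetical soft mass-constraint loss (not the repository's implementation):
    blend a pixel-wise MSE with a penalty on the mismatch between each low-res
    cell and the mean of the high-res pixels it was upsampled into."""
    mse = F.mse_loss(pred, target)
    # Pool the prediction back down to the low-resolution grid and compare it
    # to the low-resolution input: a per-cell "mass" violation term.
    mass_violation = F.mse_loss(F.avg_pool2d(pred, kernel_size=upsampling_factor), lr_input)
    return alpha * mse + (1.0 - alpha) * mass_violation
```

Under this assumed weighting, an alpha near 1 (matching the new 0.99 default in main.py) keeps the data-fit term dominant and applies the constraint only softly.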

evaluate.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

main.py

Lines changed: 17 additions & 14 deletions
@@ -6,32 +6,35 @@
 
 def add_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", default='era5_twc', help="choose a data set to use")
-    parser.add_argument("--scale", default='minmax', help="standard, minmax, none")
-    parser.add_argument("--model", default='resnet2')
-    parser.add_argument("--model_id", default='test')
+    parser.add_argument("--dataset", default="era5_twc", help="choose a data set to use")
+    parser.add_argument("--model", default="resnet2")
+    parser.add_argument("--model_id", default="test")
     parser.add_argument("--upsampling_factor", default=4, type=int)
-    parser.add_argument("--constraints", default='none')
+    parser.add_argument("--constraints", default="none")
     parser.add_argument("--number_channels", default=32, type=int)
     parser.add_argument("--number_residual_blocks", default=4, type=int)
     parser.add_argument("--lr", default=0.001, help="learning rate", type=float)
-    parser.add_argument("--loss", default='mse')
-    parser.add_argument("--optimizer", default='adam')
+    parser.add_argument("--loss", default="mse")
+    parser.add_argument("--optimizer", default="adam")
     parser.add_argument("--weight_decay", default=1e-9, type=float)
     parser.add_argument("--batch_size", default=64, type=int)
     parser.add_argument("--epochs", default=1, type=int)
-    parser.add_argument("--alpha", default=0.5, type=float)
+    parser.add_argument("--alpha", default=0.99, type=float)
     parser.add_argument("--test_val_train", default="val")
-    parser.add_argument("--eval", default=False, help="run for training or evaluation")
+    parser.add_argument("--training_evalonly", default="training")
     return parser.parse_args()
 
 def main(args):
     #load data
-    data = load_data(args)
-
-    #run training
-    run_training(args, data)
-
+    if args.training_evalonly == 'training':
+        data = load_data(args)
+        #run training
+        run_training(args, data)
+    else:
+        data = load_data(args)
+        #run evaluation
+        evaluate_model(args, data)
+
 if __name__ == '__main__':
     args = add_arguments()
     main(args)
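The new --training_evalonly flag routes execution to evaluate_model, which, per the README change above, writes all metrics for the chosen split to a CSV file. Below is a minimal sketch of what such a function might look like; the checkpoint location, metric set, and data-loader layout are illustrative assumptions, not the code removed with evaluate.py.

```python
import csv

import torch


def evaluate_model(args, data):
    """Hypothetical sketch of the eval-only path: load a trained model,
    compute metrics on the split chosen by --test_val_train, and write
    them to a CSV file. Paths and metrics are assumptions, not taken
    from this commit."""
    model = torch.load(f"checkpoints/{args.model_id}.pt", map_location="cpu")
    model.eval()
    split = data[args.test_val_train]  # e.g. "val" or "test"
    sums = {"mse": 0.0, "mae": 0.0}
    batches = 0
    with torch.no_grad():
        for lr_input, hr_target in split:
            pred = model(lr_input)
            sums["mse"] += torch.mean((pred - hr_target) ** 2).item()
            sums["mae"] += torch.mean(torch.abs(pred - hr_target)).item()
            batches += 1
    # One row of headers, one row of averaged metric values.
    with open(f"{args.model_id}_{args.test_val_train}_metrics.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(sums.keys())
        writer.writerow(round(v / max(batches, 1), 6) for v in sums.values())
```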

models.py

Lines changed: 2 additions & 2 deletions
@@ -109,9 +109,9 @@ def forward(self, y):
         return torch.sum(y, dim=1).unsqueeze(1)
 
 
-class ResNet2(nn.Module):
+class ResNet(nn.Module):
     def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=4):
-        super(ResNet2, self).__init__()
+        super(ResNet, self).__init__()
         # First layer
         if noise:
             self.conv_trans0 = nn.ConvTranspose2d(100, 1, kernel_size=(32,32), padding=0, stride=1)
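The rename keeps the constructor shown above, so downstream code would build the model along these lines; the import path and argument values in this sketch are illustrative assumptions, not taken from this commit.

```python
from models import ResNet  # assumes models.py is importable from the working directory

# Instantiate the renamed ResNet with a softmax (hard) constraint layer,
# mirroring the constructor signature shown in the diff above.
model = ResNet(
    number_channels=32,
    number_residual_blocks=4,
    upsampling_factor=4,
    noise=False,
    constraints='softmax',
    dim=1,
    cwindow_size=4,
)
```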

requirements.yml

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+name: constrained_ds
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+  - conda-forge
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=4.5=1_gnu
+  - backcall=0.2.0=pyhd3eb1b0_0
+  - blas=1.0=mkl
+  - brotli=1.0.9=he6710b0_2
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2022.9.24=ha878542_0
+  - certifi=2022.9.24=pyhd8ed1ab_0
+  - cloudpickle=2.0.0=pyhd3eb1b0_0
+  - cudatoolkit=11.1.1=h6406543_8
+  - cycler=0.11.0=pyhd3eb1b0_0
+  - cytoolz=0.11.0=py39h27cfd23_0
+  - dask-core=2021.10.0=pyhd3eb1b0_0
+  - dbus=1.13.18=hb2f20db_0
+  - debugpy=1.5.1=py39h295c915_0
+  - decorator=5.1.0=pyhd3eb1b0_0
+  - entrypoints=0.3=py39h06a4308_0
+  - expat=2.4.1=h2531618_2
+  - ffmpeg=4.2.2=h20bf706_0
+  - fontconfig=2.13.1=h6c09931_0
+  - fonttools=4.25.0=pyhd3eb1b0_0
+  - freetype=2.10.4=h5ab3b9f_0
+  - fsspec=2021.10.1=pyhd3eb1b0_0
+  - glib=2.69.1=h5202010_0
+  - gmp=6.2.1=h2531618_2
+  - gnutls=3.6.15=he1e5248_0
+  - gst-plugins-base=1.14.0=h8213a91_2
+  - gstreamer=1.14.0=h28cd5cc_2
+  - icu=58.2=he6710b0_3
+  - imageio=2.9.0=pyhd3eb1b0_0
+  - intel-openmp=2021.2.0=h06a4308_610
+  - ipykernel=6.4.1=py39h06a4308_1
+  - ipython=7.29.0=py39hb070fc8_0
+  - ipython_genutils=0.2.0=pyhd3eb1b0_1
+  - jedi=0.18.0=py39h06a4308_1
+  - joblib=1.1.0=pyhd3eb1b0_0
+  - jpeg=9b=h024ee3a_2
+  - jupyter_client=7.1.0=pyhd3eb1b0_0
+  - jupyter_core=4.9.1=py39h06a4308_0
+  - kiwisolver=1.3.1=py39h2531618_0
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.35.1=h7274673_9
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.3.0=h5101ec6_17
+  - libgfortran-ng=7.5.0=ha8ba4b0_17
+  - libgfortran4=7.5.0=ha8ba4b0_17
+  - libgomp=9.3.0=h5101ec6_17
+  - libidn2=2.3.1=h27cfd23_0
+  - libopus=1.3.1=h7b6447c_0
+  - libpng=1.6.37=hbc83047_0
+  - libprotobuf=3.19.1=h4ff587b_0
+  - libsodium=1.0.18=h7b6447c_0
+  - libstdcxx-ng=9.3.0=hd4cf53a_17
+  - libtasn1=4.16.0=h27cfd23_0
+  - libtiff=4.2.0=h85742a9_0
+  - libunistring=0.9.10=h27cfd23_0
+  - libuuid=1.0.3=h7f8727e_2
+  - libuv=1.40.0=h7b6447c_0
+  - libvpx=1.7.0=h439df22_0
+  - libwebp-base=1.2.0=h27cfd23_0
+  - libxcb=1.14=h7b6447c_0
+  - libxml2=2.9.12=h03d6c58_0
+  - locket=0.2.1=py39h06a4308_1
+  - lz4-c=1.9.3=h2531618_0
+  - matplotlib=3.4.3=py39h06a4308_0
+  - matplotlib-base=3.4.3=py39hbbc1b5f_0
+  - matplotlib-inline=0.1.2=pyhd3eb1b0_2
+  - mkl=2021.2.0=h06a4308_296
+  - mkl-service=2.3.0=py39h27cfd23_1
+  - mkl_fft=1.3.0=py39h42c9631_2
+  - mkl_random=1.2.1=py39ha9443f7_2
+  - munkres=1.1.4=py_0
+  - ncurses=6.2=he6710b0_1
+  - nest-asyncio=1.5.1=pyhd3eb1b0_0
+  - nettle=3.7.3=hbbd107a_1
+  - networkx=2.6.3=pyhd3eb1b0_0
+  - ninja=1.10.2=hff7bd54_1
+  - numpy-base=1.20.2=py39hfae3a4d_0
+  - olefile=0.46=py_0
+  - openh264=2.1.0=hd408876_0
+  - openssl=1.1.1s=h7f8727e_0
+  - packaging=21.3=pyhd3eb1b0_0
+  - parso=0.8.3=pyhd3eb1b0_0
+  - partd=1.2.0=pyhd3eb1b0_0
+  - pcre=8.45=h295c915_0
+  - pexpect=4.8.0=pyhd3eb1b0_3
+  - pickleshare=0.7.5=pyhd3eb1b0_1003
+  - pillow=8.2.0=py39he98fc37_0
+  - pip=21.1.3=py39h06a4308_0
+  - prompt-toolkit=3.0.20=pyhd3eb1b0_0
+  - protobuf=3.19.1=py39h295c915_0
+  - ptyprocess=0.7.0=pyhd3eb1b0_2
+  - pygments=2.10.0=pyhd3eb1b0_0
+  - pyparsing=3.0.4=pyhd3eb1b0_0
+  - pyqt=5.9.2=py39h2531618_6
+  - python=3.9.5=h12debd9_4
+  - python-dateutil=2.8.2=pyhd3eb1b0_0
+  - pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0
+  - pywavelets=1.1.1=py39h6323ea4_4
+  - pyyaml=6.0=py39h7f8727e_1
+  - pyzmq=22.3.0=py39h295c915_2
+  - qt=5.9.7=h5867ecd_1
+  - readline=8.1=h27cfd23_0
+  - scikit-image=0.16.2=py39ha9443f7_0
+  - scikit-learn=1.0.2=py39h51133e4_0
+  - scipy=1.6.2=py39had2a1c9_1
+  - setuptools=52.0.0=py39h06a4308_0
+  - sip=4.19.13=py39h2531618_0
+  - six=1.16.0=pyhd3eb1b0_0
+  - sqlite=3.36.0=hc218d9a_0
+  - tensorboardx=2.2=pyhd3eb1b0_0
+  - threadpoolctl=2.2.0=pyh0d69192_0
+  - tk=8.6.10=hbc83047_0
+  - toolz=0.11.2=pyhd3eb1b0_0
+  - torchmetrics=0.10.3=pyhd8ed1ab_0
+  - torchvision=0.2.2=py_3
+  - tornado=6.1=py39h27cfd23_0
+  - tqdm=4.62.3=pyhd3eb1b0_1
+  - traitlets=5.1.1=pyhd3eb1b0_0
+  - typing_extensions=3.10.0.0=pyh06a4308_0
+  - tzdata=2021a=h52ac0ba_0
+  - wcwidth=0.2.5=pyhd3eb1b0_0
+  - wheel=0.36.2=pyhd3eb1b0_0
+  - x264=1!157.20191217=h7b6447c_0
+  - xz=5.2.5=h7b6447c_0
+  - yaml=0.2.5=h7b6447c_0
+  - zeromq=4.3.4=h2531618_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.4.9=haebb681_0
+  - pip:
+    - hydroerr==1.24
+    - hydrostats==0.78
+    - llvmlite==0.39.1
+    - numba==0.56.4
+    - numpy==1.22.4
+    - pandas==1.5.2
+    - properscoring==0.1
+    - pytz==2022.6
+    - torchgeometry==0.1.2
+
