Skip to content

Commit c809703

Browse files
author
Xiaoming Zhao
committed
Align Generator selection (#3); Fix doc's typos.
1 parent 566e9c5 commit c809703

File tree

6 files changed

+52
-31
lines changed

6 files changed

+52
-31
lines changed

docs/TRAIN_EVAL.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ If everyting goes well, you should observe the following folder structure:
125125
```
126126
.
127127
+-- ckpts
128-
| +-- afhqcat.pkl # file
129-
| +-- metfaces.pkl # file
130128
| +-- stylegan2_pretrained # folder
129+
| | +-- afhqcat.pkl # file
130+
| | +-- metfaces.pkl # file
131131
| | +-- transfer-learning-source-nets # folder
132-
+-- dataset
132+
+-- runtime_dataset
133133
| +-- ffhq256x256.zip # file
134134
| +-- ffhq256_deep3dface_coeffs # folder
135135
| +-- ffhq512x512.zip # file
@@ -179,7 +179,7 @@ The command to evaluate the trained model is in [eval.sh](../gmpi/eval/eval.sh).
179179

180180
Run the following command to evalute the model:
181181
```bash
182-
bash ${GMPI_ROOT}/eval/eval.sh \
182+
bash ${GMPI_ROOT}/gmpi/eval/eval.sh \
183183
${GMPI_ROOT} \
184184
FFHQ512 \ # this can be FFHQ256, FFHQ512, FFHQ1024, AFHQCat, or MetFaces
185185
exp_id \ # this is your experiment ID

gmpi/eval/common.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,38 @@ def setup_model(opt, config, metadata, mpi_xyz_input, mpi_xyz_only_z, vis_mesh=F
2222
n_g_out_channels = 4
2323
n_g_out_planes = opt.nplanes
2424

25-
if "depth2alpha_n_z_bins" in config.GMPI.MPI and config.GMPI.MPI.depth2alpha_n_z_bins is not None:
26-
from gmpi.models.networks.networks_vanilla_depth2alpha import Generator as StyleGAN2Generator
27-
28-
if config.GMPI.TRAIN.normalized_xyz_range == "01":
29-
depth2alpha_z_range = 1.0
30-
elif config.GMPI.TRAIN.normalized_xyz_range == "-11":
31-
depth2alpha_z_range = 2.0
32-
else:
33-
raise ValueError
34-
else:
25+
if config.GMPI.TRAIN.normalized_xyz_range == "01":
3526
depth2alpha_z_range = 1.0
36-
37-
if "depth2alpha_n_z_bins" not in config.GMPI.MPI:
38-
config.defrost()
39-
config.GMPI.MPI.depth2alpha_n_z_bins = None
40-
config.freeze()
41-
42-
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc != "none":
27+
elif config.GMPI.TRAIN.normalized_xyz_range == "-11":
28+
depth2alpha_z_range = 2.0
29+
else:
30+
raise ValueError
31+
32+
if "depth2alpha_n_z_bins" not in config.GMPI.MPI:
33+
config.defrost()
34+
config.GMPI.MPI.depth2alpha_n_z_bins = None
35+
config.freeze()
36+
37+
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc != "none":
38+
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc == "depth2alpha":
39+
print("\nGenerator comes from depth2alpha\n")
40+
from gmpi.models.networks.networks_vanilla_depth2alpha import Generator as StyleGAN2Generator
41+
elif config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc == "normalize_add_z":
4342
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc_embed_func in ["learnable_param"]:
43+
print("\nGenerator comes from learnable_param\n")
44+
n_g_out_planes = config.GMPI.MPI.n_gen_planes
4445
from gmpi.models.networks.networks_pos_enc_learnable_param import Generator as StyleGAN2Generator
45-
46-
n_g_out_planes = n_g_out_planes = config.GMPI.MPI.n_gen_planes
47-
else:
46+
elif config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc_embed_func in ["modulated_lrelu"]:
47+
print("\nGenerator comes from cond_on_depth\n")
4848
from gmpi.models.networks.networks_cond_on_pos_enc import Generator as StyleGAN2Generator
49+
else:
50+
raise NotImplementedError
4951
else:
50-
from gmpi.models.networks.networks_vanilla import Generator as StyleGAN2Generator
52+
raise NotImplementedError
53+
else:
54+
print("\nGenerator comes from vanilla\n")
55+
n_g_out_planes = config.GMPI.MPI.n_gen_planes
56+
from gmpi.models.networks.networks_vanilla import Generator as StyleGAN2Generator
5157

5258
synthesis_kwargs = convert_cfg_to_dict(config.GMPI.MODEL.STYLEGAN2.synthesis_kwargs)
5359
synthesis_kwargs_D = convert_cfg_to_dict(config.GMPI.MODEL.STYLEGAN2.synthesis_kwargs)

gmpi/eval/prepare_fake_data.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,15 @@ def main(opt):
106106
# NOTE: Deep3DFaceRecon can only provide mask and depth of 224x224
107107
# https://github.com/sicxu/Deep3DFaceRecon_pytorch
108108
metadata["img_size"] = 224
109+
110+
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc == "none":
111+
# Vanilla version. The number of planes must be same in training and evaluation.
112+
n_mpi_planes = config.GMPI.MPI.n_gen_planes
113+
else:
114+
n_mpi_planes = opt.nplanes
109115

110116
mpi_renderer = MPIRenderer(
111-
n_mpi_planes=opt.nplanes, # config.GMPI.MPI.n_gen_planes,
117+
n_mpi_planes=n_mpi_planes, # config.GMPI.MPI.n_gen_planes,
112118
plane_min_d=metadata["ray_start"],
113119
plane_max_d=metadata["ray_end"],
114120
plan_spatial_enlarge_factor=config.GMPI.MPI.CAM_SETUP.spatial_enlarge_factor,

gmpi/models/networks/networks_cond_on_pos_enc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,7 @@ def __init__(self,
12341234
gen_alpha_largest_res = 256,
12351235
G_final_img_act = "none",
12361236
depth2alpha_z_range=1.0,
1237-
depth2alpha_n_z_bins=10,
1237+
depth2alpha_n_z_bins=None,
12381238
):
12391239
super().__init__()
12401240

gmpi/models/networks/networks_vanilla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def __init__(self,
604604
self.conv_clamp = synthesis_kwargs["conv_clamp"]
605605

606606
def forward(self, z, c, mpi_xyz_coords, xyz_coords_only_z, n_planes, truncation_psi=1, truncation_cutoff=None,
607-
enable_mapping_grad=True, enable_syn_feat_net_grad=True, **synthesis_kwargs):
607+
enable_mapping_grad=True, enable_syn_feat_net_grad=True, z_interpolation_ws=None, **synthesis_kwargs):
608608

609609
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
610610
img = self.synthesis(ws, **synthesis_kwargs)

gmpi/train.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,21 @@ def train(rank, world_size, config, master_port, run_dataset):
8686
# fmt: off
8787
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc != "none":
8888
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc == "depth2alpha":
89+
print("\nGenerator comes from depth2alpha\n")
8990
from gmpi.models.networks.networks_vanilla_depth2alpha import Generator as StyleGAN2Generator
90-
elif config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc_embed_func in ["learnable_param"]:
91-
from gmpi.models.networks.networks_pos_enc_learnable_param import Generator as StyleGAN2Generator
91+
elif config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc == "normalize_add_z":
92+
if config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc_embed_func in ["learnable_param"]:
93+
print("\nGenerator comes from learnable_param\n")
94+
from gmpi.models.networks.networks_pos_enc_learnable_param import Generator as StyleGAN2Generator
95+
elif config.GMPI.MODEL.STYLEGAN2.torgba_cond_on_pos_enc_embed_func in ["modulated_lrelu"]:
96+
print("\nGenerator comes from cond_on_depth\n")
97+
from gmpi.models.networks.networks_cond_on_pos_enc import Generator as StyleGAN2Generator
98+
else:
99+
raise NotImplementedError
92100
else:
93-
from gmpi.models.networks.networks_cond_on_pos_enc import Generator as StyleGAN2Generator
101+
raise NotImplementedError
94102
else:
103+
print("\nGenerator comes from vanilla\n")
95104
from gmpi.models.networks.networks_vanilla import Generator as StyleGAN2Generator
96105
# fmt: on
97106

0 commit comments

Comments
 (0)