Skip to content
Prev Previous commit
Next Next commit
fix: use pathlib / instead os.path.join
  • Loading branch information
ydcjeff committed Mar 23, 2021
commit 7663c9ddb893c6fb13ac448dff2187967006219d
12 changes: 6 additions & 6 deletions templates/gan/main.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run(

@trainer.on(Events.ITERATION_COMPLETED(every=config.log_train))
def print_logs(engine):
fname = os.path.join(config.filepath, LOGS_FNAME)
fname = config.filepath / LOGS_FNAME
columns = ["iteration",] + list(engine.state.metrics.keys())
values = [str(engine.state.iteration),] + [str(round(value, 5)) for value in engine.state.metrics.values()]

Expand All @@ -118,7 +118,7 @@ def run(
@trainer.on(Events.EPOCH_COMPLETED)
def save_fake_example(engine):
fake = netG(fixed_noise)
path = os.path.join(config.filepath, FAKE_IMG_FNAME.format(engine.state.epoch))
path = config.filepath / (FAKE_IMG_FNAME.format(engine.state.epoch))
vutils.save_image(fake.detach(), path, normalize=True)

# --------------------------------------------------
Expand All @@ -127,7 +127,7 @@ def run(
@trainer.on(Events.EPOCH_COMPLETED)
def save_real_example(engine):
img, y = engine.state.batch
path = os.path.join(config.filepath, REAL_IMG_FNAME.format(engine.state.epoch))
path = config.filepath / (REAL_IMG_FNAME.format(engine.state.epoch))
vutils.save_image(img, path, normalize=True)

# -------------------------------------------------------------
Expand Down Expand Up @@ -173,11 +173,11 @@ def run(
warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

else:
df = pd.read_csv(os.path.join(config.filepath, LOGS_FNAME), delimiter="\t", index_col="iteration")
df = pd.read_csv(config.filepath / LOGS_FNAME, delimiter="\t", index_col="iteration")
_ = df.plot(subplots=True, figsize=(20, 20))
_ = plt.xlabel("Iteration number")
fig = plt.gcf()
path = os.path.join(config.filepath, PLOT_FNAME)
path = config.filepath / PLOT_FNAME

fig.savefig(path)

Expand All @@ -199,7 +199,7 @@ def run(
# ---------------------------------------------
# Setup is done. Now let's run the training
# ---------------------------------------------
trainer.run(loader, config.max_epochs, config.epoch_length)
trainer.run(loader, max_epochs=config.max_epochs, epoch_length=config.epoch_length)


{% block main_fn %}
Expand Down