Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions app/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,35 @@


class CodeGenerator:
def __init__(self, templates_dir: str = "./templates", target_dir: str = "./dist"):
def __init__(self, templates_dir: str = "./templates", dist_dir: str = "./dist"):
self.templates_dir = Path(templates_dir)
self.target_dir = Path(target_dir)
self.template_list = [p.stem for p in self.templates_dir.iterdir() if p.is_dir()]
self.env = Environment(loader=FileSystemLoader(self.templates_dir), trim_blocks=True, lstrip_blocks=True)
self.dist_dir = Path(dist_dir)
self.template_list = [p.stem for p in self.templates_dir.iterdir() if p.is_dir() and not p.stem.startswith("_")]
self.rendered_code = {t: {} for t in self.template_list}
self.available_archive_formats = sorted(map(lambda x: x[0], shutil.get_archive_formats()), reverse=True)

def render_template(self, template_name: str, fname: str, config: dict):
"""Renders single template file `fname` of the template `template_name`."""
# Get template
template = self.env.get_template(fname)
# Render template
code = template.render(**config)
# Store rendered code
fname = fname.replace(".jinja", "").replace(f"{template_name}/", "")
self.rendered_code[template_name][fname] = code
return fname, code

def render_templates(self, template_name: str, config: dict):
"""Renders all the templates files from template folder for the given config."""
self.rendered_code[template_name] = {}
for fname in filter(lambda t: t.startswith(template_name), self.env.list_templates(".jinja")):
fname, code = self.render_template(template_name, fname, config)
self.rendered_code[template_name] = {} # clean up the rendered code for given template
# loading the template files based on given template
env = Environment(
loader=FileSystemLoader(self.templates_dir / template_name),
trim_blocks=True,
lstrip_blocks=True,
)
for fname in env.list_templates(filter_func=lambda x: not x.startswith("_")):
code = env.get_template(fname).render(**config)
fname = fname.replace(".pyi", ".py")
self.rendered_code[template_name][fname] = code
yield fname, code

def create_target_template_dir(self, template_name: str):
self.target_template_path = Path(f"{self.target_dir}/{template_name}")
self.target_template_path.mkdir(parents=True, exist_ok=True)
def mk_dist_template_dir(self, template_name: str):
self.dist_template_dir = Path(f"{self.dist_dir}/{template_name}")
self.dist_template_dir.mkdir(parents=True, exist_ok=True)

def write_file(self, fname: str, code: str) -> None:
"""Creates `fname` with content `code` in `target_dir/template_name`."""
(self.target_template_path / fname).write_text(code)
"""Creates `fname` with content `code` in `dist_dir/template_name`."""
(self.dist_template_dir / fname).write_text(code)

def write_files(self, template_name):
"""Writes all rendered code for the specified template."""
Expand All @@ -48,12 +44,12 @@ def write_files(self, template_name):
self.write_file(fname, code)

def make_archive(self, template_name, archive_format):
"""Creates target dir with generated code, then makes the archive."""
self.create_target_template_dir(template_name)
"""Creates dist dir with generated code, then makes the archive."""
self.mk_dist_template_dir(template_name)
self.write_files(template_name)
archive_fname = shutil.make_archive(
base_name=str(self.target_template_path),
base_name=str(self.dist_template_dir),
format=archive_format,
base_dir=self.target_template_path,
base_dir=self.dist_template_dir,
)
return archive_fname
2 changes: 1 addition & 1 deletion app/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def render_code(self, fname="", code=""):

def add_sidebar(self):
def config(template_name):
return import_from_file("template_config", f"./templates/{template_name}/sidebar.py")
return import_from_file("template_config", f"./templates/{template_name}/_sidebar.py")

self.sidebar(self.codegen.template_list, config)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.black]
line-length = 120
target-version = ['py38']
include = '\.pyi?$'
include = '\.py?$'
exclude = '''

(
Expand All @@ -25,3 +25,4 @@ exclude = '''
[tool.isort]
profile = "black"
multi_line_output = 3
supported_extensions = "py"
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
88 changes: 88 additions & 0 deletions templates/gan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# GAN Template by Code-Generator

This template is ported from [PyTorch-Ignite DCGAN example](https://github.com/pytorch/ignite/tree/master/examples/gan).

After downloading the archive, install the requirements with:

```sh
pip install -r requirements.txt -U --progress-bar off
```

The requirements are:

- Pandas
- PyTorch
- Matplotlib
- Torchvision
- PyTorch-Ignite

After installing the requirements, run the training with:

```sh
python main.py --verbose
```

The following options are available to configure (`python main.py -h`):

```sh
usage: main.py [-h] [--batch_size BATCH_SIZE] [--data_path DATA_PATH]
[--filepath FILEPATH] [--num_workers NUM_WORKERS]
[--max_epochs MAX_EPOCHS] [--epoch_length EPOCH_LENGTH]
[--lr LR] [--log_train LOG_TRAIN] [--seed SEED] [--verbose]
[--nproc_per_node NPROC_PER_NODE] [--nnodes NNODES]
[--node_rank NODE_RANK] [--master_addr MASTER_ADDR]
[--master_port MASTER_PORT] [--n_saved N_SAVED]
[--dataset {cifar10,lsun,imagenet,folder,lfw,fake,mnist}]
[--z_dim Z_DIM] [--alpha ALPHA] [--g_filters G_FILTERS]
[--d_filters D_FILTERS] [--beta_1 BETA_1] [--saved_G SAVED_G]
[--saved_D SAVED_D]

optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
will be equally divided by number of GPUs if in
distributed (4)
--data_path DATA_PATH
datasets path (./)
--filepath FILEPATH logging file path (./logs)
--num_workers NUM_WORKERS
num_workers for DataLoader (2)
--max_epochs MAX_EPOCHS
max_epochs of ignite.Engine.run() for training (2)
--epoch_length EPOCH_LENGTH
epoch_length of ignite.Engine.run() for training
(None)
--lr LR learning rate used by torch.optim.* (0.001)
--log_train LOG_TRAIN
logging interval of training iteration (50)
--seed SEED used in ignite.utils.manual_seed() (666)
--verbose use logging.INFO in ignite.utils.setup_logger
--nproc_per_node NPROC_PER_NODE
number of processes to launch on each node, for GPU
training this is recommended to be set to the number
of GPUs in your system so that each process can be
bound to a single GPU (1)
--nnodes NNODES number of nodes to use for distributed training (1)
--node_rank NODE_RANK
rank of the node for multi-node distributed training
(None)
--master_addr MASTER_ADDR
master node TCP/IP address for torch native backends
(None)
--master_port MASTER_PORT
master node port for torch native backends None
--n_saved N_SAVED number of best models to store (2)
--dataset {cifar10,lsun,imagenet,folder,lfw,fake,mnist}
dataset to use (cifar10)
--z_dim Z_DIM size of the latent z vector (100)
--alpha ALPHA running average decay factor (0.98)
--g_filters G_FILTERS
number of filters in the second-to-last generator
deconv layer (64)
--d_filters D_FILTERS
number of filters in first discriminator conv layer
(64)
--beta_1 BETA_1 beta_1 for Adam optimizer (0.5)
--saved_G SAVED_G path to saved generator (None)
--saved_D SAVED_D path to saved discriminator (None)
```
File renamed without changes.
95 changes: 0 additions & 95 deletions templates/gan/config.toml

This file was deleted.

File renamed without changes.
File renamed without changes.
17 changes: 7 additions & 10 deletions templates/gan/main.py.jinja → templates/gan/main.pyi
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
from argparse import ArgumentParser
from pathlib import Path
import logging
import os
from typing import Any
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Any

import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.optim as optim

import ignite.distributed as idist
from datasets import get_datasets
from fn import update
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from ignite.utils import manual_seed
from models import Discriminator, Generator
from torchvision import utils as vutils

from models import Generator, Discriminator
from datasets import get_datasets
from fn import update
from utils import get_default_parser


PRINT_FREQ = 100
FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions templates/gan/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torch
torchvision
pytorch-ignite
matplotlib
pandas
1 change: 1 addition & 0 deletions templates/gan/utils.py.jinja → templates/gan/utils.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{% block imports %}
from argparse import ArgumentParser

{% endblock %}

{% block get_default_parser %}
Expand Down
Loading