Merged
25 commits
0f82489
Added functionality for merging of common and specific templates
theory-in-progress Jun 17, 2023
b3a6a3f
Merge common and Specific code-templates
theory-in-progress Jun 20, 2023
8f67e46
Handle error when code-tag is present in a file but the file is missing in Com…
theory-in-progress Jun 20, 2023
bc688be
Update branch with changes from main
theory-in-progress Jun 21, 2023
83f4051
Merge branch 'pytorch-ignite:main' into merge-common-specific
theory-in-progress Jun 21, 2023
463d004
Code format
theory-in-progress Jun 21, 2023
b044f5e
Changed code tag to `from_template_common`
theory-in-progress Jun 21, 2023
bb04373
Refactoring the redundant/repeating code in utils.py
theory-in-progress Jun 21, 2023
5ed9d8c
Fix lint of utils.py
theory-in-progress Jun 21, 2023
c3d1eed
Refactoring the redundant/repeating code in config.yaml
theory-in-progress Jun 21, 2023
e474921
Refactoring the redundant/repeating code in main.py
theory-in-progress Jun 21, 2023
ac621bf
Modify lint options
theory-in-progress Jun 21, 2023
6d9415f
Deleting the script check_copies.py and the command in workflow
theory-in-progress Jun 22, 2023
d9888dd
Add lint in tests and modify min_lint in lint
theory-in-progress Jun 22, 2023
7703a72
Merge branch 'main' into merge-common-specific
vfdev-5 Jun 22, 2023
c5c9ae0
Modify render with replace using js on vision-classification template
theory-in-progress Jun 22, 2023
bb5b106
Merge branch 'merge-common-specific' of github.com:theory-in-progress…
theory-in-progress Jun 22, 2023
4f965b1
Install formatting tools in tests job
theory-in-progress Jun 22, 2023
da90ef3
Change code tags from `#:::- from_template_common :::#` to `#::= from…
theory-in-progress Jun 22, 2023
23e2385
Modifying trainers.py to include the if else statements for unused im…
theory-in-progress Jun 22, 2023
d738012
Removes any usort skip statements in the final code
theory-in-progress Jun 23, 2023
ff45072
Added usort skip statements for certain imports in trainers.py
theory-in-progress Jun 23, 2023
38c0f80
Modified tests, imports, workflow
theory-in-progress Jun 23, 2023
733b9ea
Added type hints for functions in template-vision-segmentation/traine…
theory-in-progress Jun 23, 2023
eb75d1d
Formatting modifications
theory-in-progress Jun 23, 2023
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
@@ -70,6 +70,7 @@ jobs:
pip uninstall -y tqdm
npm install -g pnpm
pnpm i --frozen-lockfile --color
bash scripts/run_code_style.sh install
# Show all installed dependencies
pip list
@@ -81,6 +82,7 @@ jobs:
- run: pnpm build
- run: pnpm test:ci
- run: sh ./scripts/run_tests.sh unzip
- run: pnpm lint

- name: 'Run ${{ matrix.template }} ${{ matrix.test }}'
run: sh ./scripts/run_tests.sh ${{ matrix.test }} ${{ matrix.template }}
@@ -108,5 +110,4 @@ jobs:
- run: pip install -Uq pip wheel && bash scripts/run_code_style.sh install
- run: npm install -g pnpm
- run: pnpm i --frozen-lockfile --color
- run: pnpm lint
- run: python scripts/check_copies.py
- run: pnpm min_lint
1 change: 1 addition & 0 deletions .prettierignore
@@ -1,2 +1,3 @@
dist
pnpm-lock.yaml
**/__DEV_CONFIG__.json
5 changes: 3 additions & 2 deletions package.json
@@ -2,13 +2,14 @@
"name": "@pytorch-ignite/code-generator",
"version": "0.3.0",
"scripts": {
"dev": "vite",
"dev": "vite --port 5000",
"build": "vite build",
"serve": "vite preview",
"test": "jest --color --runInBand",
"test": "rm -rf ./dist-tests && jest --color --runInBand",
"test:ci": "start-server-and-test --expect 200 serve http://127.0.0.1:5000 test",
"release": "node scripts/release.js",
"fmt": "prettier --write . && bash scripts/run_code_style.sh fmt",
"min_lint": "prettier --check . && bash scripts/run_code_style.sh min_lint",
"lint": "prettier --check . && bash scripts/run_code_style.sh lint"
},
"dependencies": {
42 changes: 0 additions & 42 deletions scripts/check_copies.py

This file was deleted.

6 changes: 5 additions & 1 deletion scripts/run_code_style.sh
@@ -3,8 +3,12 @@
set -xeu

if [ $1 == "lint" ]; then
# Check that ./dist-tests/ exists and code is unzipped
ls ./dist-tests/vision-classification-all/main.py
ufmt diff .
flake8 --select F401,F821 ./dist-tests # find unused imports (F401) and undefined names (F821)
elif [ $1 == "min_lint" ]; then
ufmt diff .
flake8 --select F401 . # find unused imports
elif [ $1 == "fmt" ]; then
ufmt format .
elif [ $1 == "install" ]; then
19 changes: 18 additions & 1 deletion src/store.js
@@ -59,6 +59,15 @@ export function saveConfig(key, value) {
}
}

// splice the common template's code into the specific file by replacing
// the `#::= from_template_common ::#` tag with the common file's contents
function mergeCode(specificFileText, commonFileText) {
const replaced = specificFileText.replace(
/#::= from_template_common ::#/g,
commonFileText
)
return replaced
}

// render the code if there are fetched files for the currently selected template
export function genCode() {
const currentFiles = files[store.config.template]
@@ -78,6 +87,7 @@ export function genCode() {
)
// trim ` #`
.replace(/\s{4}#$/gim, '')
.replace(/ # usort: skip/g, '')
}
if (isDev) {
store.code[__DEV_CONFIG_FILE__] =
@@ -98,7 +108,14 @@ export async function fetchTemplates(template) {
files[template] = {}
for (const filename of templates[template]) {
const response = await fetch(`${url}/${template}/${filename}`)
files[template][filename] = await response.text()
const text_specific = await response.text()
// Also fetch the matching file from template-common. If the file exists
// there, mergeCode substitutes its contents for the
// `#::= from_template_common ::#` tag in the specific file; if it does
// not, the fetch yields an empty string and the tag is simply removed.
const res_common = await fetch(`${url}/template-common/${filename}`)
const text_common = await res_common.text()
files[template][filename] = mergeCode(text_specific, text_common)
}

// calling genCode explicitly here
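Taken together, the store.js changes form a three-stage pipeline: mergeCode splices the common template into the specific one at fetch time, ejs renders the #::: ... :::# tags against the user's config, and genCode strips the usort markers from the final code. The sketch below illustrates that flow; the template contents, the config object, and the ejs delimiter options are assumptions inferred from the tag shape, not the project's exact setup.

// Illustrative sketch of the fetch -> merge -> render -> clean pipeline.
import ejs from 'ejs'

// pretend these two strings were fetched from the template URLs
const commonText = 'if __name__ == "__main__":\n    main()\n'
const specificText =
  '#::: if (it.save_training) { :::#\n' +
  'from ignite.handlers import Checkpoint  # usort: skip\n' +
  '#::: } :::#\n' +
  '#::= from_template_common ::#\n'

// stage 1: the same substitution mergeCode performs
const merged = specificText.replace(/#::= from_template_common ::#/g, commonText)

// stage 2: render the conditional tags against a hypothetical config;
// the delimiter options are assumed from the `#::: ... :::#` tag shape
const rendered = ejs.render(
  merged,
  { it: { save_training: true } },
  { delimiter: ':::', openDelimiter: '#', closeDelimiter: '#' }
)

// stage 3: the generated code ships without the usort markers
console.log(rendered.replace(/ # usort: skip/g, ''))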
27 changes: 4 additions & 23 deletions src/templates/template-common/main.py
@@ -1,26 +1,3 @@
ckpt_handler_train, ckpt_handler_eval = setup_handlers(
trainer, evaluator, config, to_save_train, to_save_eval
)

#::: if (it.logger) { :::#
if rank == 0:
exp_logger.close()
#::: } :::#

#::: if (it.save_training || it.save_evaluation) { :::#
# show last checkpoint names
logger.info(
"Last training checkpoint name - %s",
ckpt_handler_train.last_checkpoint,
)

logger.info(
"Last evaluation checkpoint name - %s",
ckpt_handler_eval.last_checkpoint,
)
#::: } :::#


# main entrypoint
def main():
config = setup_config()
@@ -42,3 +19,7 @@ def main():
with idist.Parallel(config.backend) as p:
p.run(run, config=config)
#::: } :::#


if __name__ == "__main__":
main()
91 changes: 24 additions & 67 deletions src/templates/template-common/utils.py
@@ -10,11 +10,34 @@
import yaml
from ignite.contrib.engines import common
from ignite.engine import Engine

#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::#
from ignite.engine.events import Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

#::: } :::#
#::: if (it.save_training || it.save_evaluation) { :::#
from ignite.handlers import (
Checkpoint,
DiskSaver,
global_step_from_engine,
) # usort: skip

#::: } else { :::#
from ignite.handlers import Checkpoint

#::: } :::#
#::: if (it.patience) { :::#
from ignite.handlers.early_stopping import EarlyStopping

#::: } :::#
#::: if (it.terminate_on_nan) { :::#
from ignite.handlers.terminate_on_nan import TerminateOnNan

#::: } :::#
#::: if (it.limit_sec) { :::#
from ignite.handlers.time_limit import TimeLimit

#::: } :::#
from ignite.utils import setup_logger


@@ -141,72 +164,6 @@ def setup_logging(config: Any) -> Logger:
return logger


#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::#


def setup_handlers(
trainer: Engine,
evaluator: Engine,
config: Any,
to_save_train: Optional[dict] = None,
to_save_eval: Optional[dict] = None,
):
"""Setup Ignite handlers."""

ckpt_handler_train = ckpt_handler_eval = None
#::: if (it.save_training || it.save_evaluation) { :::#
# checkpointing
saver = DiskSaver(config.output_dir / "checkpoints", require_empty=False)
#::: if (it.save_training) { :::#
ckpt_handler_train = Checkpoint(
to_save_train,
saver,
filename_prefix=config.filename_prefix,
n_saved=config.n_saved,
)
trainer.add_event_handler(
Events.ITERATION_COMPLETED(every=config.save_every_iters),
ckpt_handler_train,
)
#::: } :::#
#::: if (it.save_evaluation) { :::#
global_step_transform = None
if to_save_train.get("trainer", None) is not None:
global_step_transform = global_step_from_engine(to_save_train["trainer"])
ckpt_handler_eval = Checkpoint(
to_save_eval,
saver,
filename_prefix="best",
n_saved=config.n_saved,
global_step_transform=global_step_transform,
)
evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval)
#::: } :::#
#::: } :::#

#::: if (it.patience) { :::#
# early stopping

es = EarlyStopping(config.patience, score_fn, trainer)
evaluator.add_event_handler(Events.EPOCH_COMPLETED, es)
#::: } :::#

#::: if (it.terminate_on_nan) { :::#
# terminate on nan
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
#::: } :::#

#::: if (it.limit_sec) { :::#
# time limit
trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimit(config.limit_sec))
#::: } :::#
#::: if (it.save_training || it.save_evaluation) { :::#
return ckpt_handler_train, ckpt_handler_eval
#::: } :::#


#::: } :::#

#::: if (it.logger) { :::#


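The if/else tags above are what keep unused handler imports out of a generated utils.py. A minimal sketch of how that block resolves for two configs, under the same assumed ejs options as in the previous sketch:

// Illustrative: the else branch falls back to the bare Checkpoint import.
import ejs from 'ejs'

const tmpl =
  '#::: if (it.save_training || it.save_evaluation) { :::#\n' +
  'from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine\n' +
  '#::: } else { :::#\n' +
  'from ignite.handlers import Checkpoint\n' +
  '#::: } :::#\n'
const opts = { delimiter: ':::', openDelimiter: '#', closeDelimiter: '#' }

// checkpointing enabled -> the full import line survives
console.log(ejs.render(tmpl, { it: { save_training: true, save_evaluation: false } }, opts))

// checkpointing disabled -> only Checkpoint is imported
console.log(ejs.render(tmpl, { it: { save_training: false, save_evaluation: false } }, opts))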
53 changes: 1 addition & 52 deletions src/templates/template-text-classification/config.yaml
@@ -1,13 +1,4 @@
seed: 666
data_path: ./
train_batch_size: 32
eval_batch_size: 32
num_workers: 4
max_epochs: 20
train_epoch_length: 1000
eval_epoch_length: 1000
use_amp: false
debug: false
#::= from_template_common ::#
model: bert-base-uncased
model_dir: /tmp/model
tokenizer_dir: /tmp/tokenizer
@@ -18,45 +9,3 @@ weight_decay: 0.01
num_warmup_epochs: 0
max_length: 256
lr: 0.00005

#::: if (it.dist === 'spawn') { :::#
# distributed spawn
nproc_per_node: #:::= it.nproc_per_node :::#
#::: if (it.nnodes) { :::#
# distributed multi node spawn
nnodes: #:::= it.nnodes :::#
#::: if (it.nnodes > 1) { :::#
node_rank: 0
master_addr: #:::= it.master_addr :::#
master_port: #:::= it.master_port :::#
#::: } :::#
#::: } :::#
#::: } :::#

#::: if (it.filename_prefix) { :::#
filename_prefix: #:::= it.filename_prefix :::#
#::: } :::#

#::: if (it.n_saved) { :::#
n_saved: #:::= it.n_saved :::#
#::: } :::#

#::: if (it.save_every_iters) { :::#
save_every_iters: #:::= it.save_every_iters :::#
#::: } :::#

#::: if (it.patience) { :::#
patience: #:::= it.patience :::#
#::: } :::#

#::: if (it.limit_sec) { :::#
limit_sec: #:::= it.limit_sec :::#
#::: } :::#

#::: if (it.output_dir) { :::#
output_dir: #:::= it.output_dir :::#
#::: } :::#

#::: if (it.log_every_iters) { :::#
log_every_iters: #:::= it.log_every_iters :::#
#::: } :::#
26 changes: 1 addition & 25 deletions src/templates/template-text-classification/main.py
@@ -173,28 +173,4 @@ def _():
#::: } :::#


# main entrypoint
def main():
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
"nproc_per_node": config.nproc_per_node,
"nnodes": config.nnodes,
"node_rank": config.node_rank,
"master_addr": config.master_addr,
"master_port": config.master_port,
}
#::: } else if (it.nproc_per_node) { :::#
kwargs = {"nproc_per_node": config.nproc_per_node}
#::: } :::#
with idist.Parallel(config.backend, **kwargs) as p:
p.run(run, config=config)
#::: } else { :::#
with idist.Parallel(config.backend) as p:
p.run(run, config=config)
#::: } :::#


if __name__ == "__main__":
main()
#::= from_template_common ::#
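To close the loop: the same substitution handles both placements of the tag seen above, since config.yaml leads with it while main.py ends with it. A short sketch with made-up file contents:

// Illustrative: the `#::= from_template_common ::#` tag works anywhere in a file.
const merge = (specific, common) =>
  specific.replace(/#::= from_template_common ::#/g, common)

const commonCfg = 'seed: 666\nmax_epochs: 20'
const specificCfg = '#::= from_template_common ::#\nmodel: bert-base-uncased\n'
console.log(merge(specificCfg, commonCfg))
// seed: 666
// max_epochs: 20
// model: bert-base-uncased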