File renamed without changes.
149 changes: 129 additions & 20 deletions src/main.py
@@ -1,41 +1,150 @@
import contextlib
import gc
import time
from argparse import ArgumentParser, Namespace
from typing import List, Optional

from src.pipelines import get_pipeline_class
from src.utils.arguments import parse_args
from src.utils.benchmark import benchmark_end_to_end
from src.utils.input import get_dummy_batch
from src.utils.logging import configure_logging
import torch

from src.pipeline import get_pipeline_class
from src.profile import get_profiler, logger
from src.utils import (
configure_logging,
format_mib,
format_ms,
get_dummy_batch,
log_dict,
log_rank_n,
parse_config_args,
)


def get_arg_parser() -> ArgumentParser:
parser = ArgumentParser()

# Model
parser.add_argument("--model_type")
parser.add_argument("--pretrained_config")
parser.add_argument("--pretrained_model")
parser.add_argument("--tokenizer", default="gpt2")
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("config_args", nargs="*")

# Runtime
parser.add_argument("--pipeline_class", default="HF_Pipeline")
parser.add_argument("--device", default="cuda", type=torch.device)
parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
parser.add_argument("--local_rank", type=int)
parser.add_argument("--no_fast_init", dest="fast_init", action="store_false")

# Input and output
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--max_input_length", default=-1, type=int)
parser.add_argument("--max_new_tokens", default=100, type=int)

# Cleanup
parser.add_argument("--clear_every_run", action="store_true")

# Benchmark cycles
parser.add_argument("--skip", type=int, default=1)
parser.add_argument("--warmup", type=int, default=None)
parser.add_argument("--cycles", type=int, default=5)

# Profiling and logging
parser.add_argument("--max_log_outputs", default=None, type=int)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--full_trace", action="store_true")
parser.add_argument("--show_op_names", action="store_true")

return parser
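
# Example invocation (hypothetical model and values; the flags are the ones defined in get_arg_parser above):
#   python -m src.main --pipeline_class HF_Pipeline --model_type gpt2 --batch_size 8 --max_new_tokens 100 --profile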


def main(argv: Optional[List[str]] = None) -> None:
args = parse_args(argv=argv)
parser = get_arg_parser()
args = parser.parse_args(argv)
config_args = parse_config_args(args.config_args)
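# Greedy generation: do_sample=False disables sampling, so every cycle produces the same tokens for a given input.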
generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False}
inputs = get_dummy_batch(args.batch_size, args.max_input_length)
warmup = args.profile if args.warmup is None else args.warmup
max_log_outputs = args.batch_size if args.max_log_outputs is None else args.max_log_outputs

pipeline_class = get_pipeline_class(args.pipeline_class)
pipeline = pipeline_class(
model_type=args.model_type,
pretrained_model=args.pretrained_model,
pretrained_config=args.pretrained_config,
config_args=args.config_args,
config_args=config_args,
tokenizer=args.tokenizer,
device=args.device,
dtype=args.dtype,
fast_init=args.fast_init,
trust_remote_code=args.trust_remote_code,
)

benchmark_end_to_end(
pipeline=pipeline,
inputs=get_dummy_batch(args.batch_size, args.max_input_length),
generate_kwargs={"max_new_tokens": args.max_new_tokens, "do_sample": False},
profile=args.profile,
skip=args.skip,
warmup=args.profile if args.warmup is None else args.warmup,
cycles=args.cycles,
full_trace=args.full_trace,
show_op_names=args.show_op_names,
max_log_outputs=args.batch_size if args.max_log_outputs is None else args.max_log_outputs,
clear_every_run=args.clear_every_run,
)
all_metrics = []

if args.profile:
profiler = get_profiler(
skip=args.skip,
warmup=warmup,
cycles=args.cycles,
full_trace=args.full_trace,
show_op_names=args.show_op_names,
)
else:
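# Profiling disabled: use a no-op context manager so the benchmark loop below runs unchanged.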
profiler = contextlib.nullcontext()

benchmark_stats = {
"Model parameters": pipeline.get_num_parameters(),
"Batch size": len(inputs),
**generate_kwargs,
**pipeline.get_initialization_metrics(),
"Warmup cycles": args.skip + warmup,
"Benchmark cycles": args.cycles,
"Total cycles": args.skip + warmup + args.cycles,
}

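# Record baseline GPU memory and reset the peak counters before the benchmark loop.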
if pipeline.device.type == "cuda":
benchmark_stats["Initial memory used"] = format_mib(torch.cuda.memory_allocated())
benchmark_stats["Initial memory reserved"] = format_mib(torch.cuda.memory_reserved())
torch.cuda.reset_peak_memory_stats()

t0 = time.perf_counter()
with profiler as p:
for step in range(args.skip + warmup + args.cycles):
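# Only the final `cycles` iterations are recorded in all_metrics; the first skip + warmup iterations are discarded.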
if step == args.skip + warmup:
t1 = time.perf_counter()
benchmark_stats["Warmup time"] = format_ms(t1 - t0)
generated_text, metrics = pipeline(inputs, **generate_kwargs)
if args.profile:
p.step()

if step == 0:
for i, o, _ in zip(inputs, generated_text, range(max_log_outputs)):
log_rank_n(f"{'-' * 60}\nINPUT = {i}\nOUTPUT = {o}", logger.info)

if step >= args.skip + warmup:
all_metrics.append(metrics)

if args.clear_every_run:
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
if pipeline.device.type == "cuda":
benchmark_stats["Memory used"] = format_mib(torch.cuda.memory_allocated())
benchmark_stats["Memory reserved"] = format_mib(torch.cuda.memory_reserved())
benchmark_stats["Max memory used"] = format_mib(torch.cuda.max_memory_allocated())
benchmark_stats["Max memory reserved"] = format_mib(torch.cuda.max_memory_reserved())

t2 = time.perf_counter()
benchmark_stats["Benchmark time"] = format_ms(t2 - t1)
benchmark_stats["Total time"] = format_ms(t2 - t0)

if len(all_metrics) > 0:
benchmark_stats.update(pipeline.aggregate_and_format_metrics(all_metrics))

log_rank_n("*** Benchmark results:", logger.info)
log_dict(benchmark_stats, logger.info)


if __name__ == "__main__":
46 changes: 43 additions & 3 deletions src/pipelines/pipeline.py → src/pipeline.py
@@ -1,15 +1,15 @@
import contextlib
import gc
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from src.utils.fast_init import fast_init
from src.utils.logging import format_ms, log_rank_n
from src.utils.utils import parse_revision
from src.fast_init import fast_init
from src.utils import format_ms, log_rank_n, parse_revision
from transformers import (
CONFIG_MAPPING,
AutoConfig,
@@ -239,3 +239,43 @@ def aggregate_and_format_metrics(self, metrics: List[Dict[str, Any]]):

def get_initialization_metrics(self):
return {f"Initialization time ({key})": format_ms(value) for key, value in self.initialization_metrics.items()}


class HF_Pipeline(Pipeline):
pass


class DS_Pipeline(Pipeline):
def __init__(self, **kwargs):
import deepspeed

super().__init__(**kwargs)

if self.device != torch.device("cuda"):
raise ValueError(f"Deepspeed does not support device {self.device}")

if self.dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Deepspeed does not support dtype {self.dtype}")

if self.config.model_type not in ("bloom", "gpt2"):
raise ValueError(f"Deepspeed does not support model type {self.config.model_type}")

self.model = deepspeed.init_inference(
self.model,
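# mp_size is the tensor-parallel degree; WORLD_SIZE is set by the distributed launcher and defaults to 1.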
mp_size=int(os.getenv("WORLD_SIZE", "1")),
# base_dir="./",
dtype=self.dtype,
replace_with_kernel_inject=True,
)


_PIPELINE_CLASS_MAP = {
"HF_Pipeline": HF_Pipeline,
"DS_Pipeline": DS_Pipeline,
}


def get_pipeline_class(name):
if name not in _PIPELINE_CLASS_MAP:
raise NotImplementedError(f"Unsupported pipeline class: {name}")
return _PIPELINE_CLASS_MAP[name]
11 changes: 0 additions & 11 deletions src/pipelines/__init__.py

This file was deleted.

28 changes: 0 additions & 28 deletions src/pipelines/ds.py

This file was deleted.

5 changes: 0 additions & 5 deletions src/pipelines/transformers.py

This file was deleted.

62 changes: 62 additions & 0 deletions src/profile.py
@@ -0,0 +1,62 @@
import contextlib
import logging
from typing import Union

import torch

from src.utils import log_rank_n


logger = logging.getLogger(__name__)


def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1):
def trace_fn(
p: torch.profiler.profile,
):
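# key_averages() aggregates the recorded events by name across all profiled steps.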
averages = p.key_averages()
if full_trace:
# Show every GPU op.
# Keep only events with nonzero CUDA time (drops CPU-only ops) to shorten the table.
events = torch.autograd.profiler.EventList(
[evt for evt in p.profiler.function_events if evt.self_cuda_time_total > 0]
)
log_rank_n(events.table(row_limit=-1, max_src_column_width=1000), logger.info, rank)

if show_op_names:
# Show non-cropped names, in the same order as in the table.
averages_sorted = torch.autograd.profiler.EventList(
sorted(averages, key=lambda evt: evt.self_cuda_time_total, reverse=True)
)
for entry in averages_sorted:
log_rank_n(entry.key, logger.info, rank)

# Try to avoid name cropping; the name column is still hard-coded to a maximum of 55 characters
log_rank_n(
averages.table(sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=1000), logger.info, rank
)

return trace_fn


def get_profiler(
skip: int,
warmup: int,
cycles: int,
full_trace: bool = False,
show_op_names: bool = False,
) -> Union[torch.profiler.profile, contextlib.nullcontext]:
schedule = torch.profiler.schedule(
# Warmup is a must when measuring speed, since that is when all the optimizations are performed,
# e.g. on 8x A100 80GB the first pass over 100 tokens takes ~23 s, while the next one takes ~4 s.
skip_first=skip,
# Warmup for the profiler
warmup=warmup,
wait=0,
active=cycles,
)
return torch.profiler.profile(
schedule=schedule,
activities=[torch.profiler.ProfilerActivity.CUDA],
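# Only CUDA activity is recorded here; torch.profiler.ProfilerActivity.CPU could be added to also capture host-side ops.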
on_trace_ready=get_trace_fn(full_trace, show_op_names),
)