Parse results and misc
jlamypoirier committed Feb 27, 2023
commit f29cb0e89711a54069ac15a3b88ee1f7056c5992
2 changes: 0 additions & 2 deletions scripts/run_grid.sh
@@ -8,5 +8,3 @@ do
 done
 done
 done
-
-# PYTHONPATH=. ./scripts/run_grid.sh "1 2 4" "50 100 200" "1 50 100" ./results/mqa_small python3 src/main.py --pipeline_class=HF_Pipeline --tokenizer=bigcode/santacoder --model_type=gpt_bigcode --dtype=float32 --device=cpu --max_log_outputs=1 --cycles=1 n_positions=512 n_embd=512 n_head=8 n_inner=2048 n_layer=8
125 changes: 125 additions & 0 deletions src/parse_results.py
@@ -0,0 +1,125 @@
import json
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional

from src.metrics import Metrics
from src.utils import parse_config_args


def get_arg_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument("input_dir", type=Path)
    parser.add_argument("--filter", action="append")
    parser.add_argument("--column", "--col", action="append")
    parser.add_argument("--table", action="store_true")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("-x", "--x_axis", default=Metrics.BATCH_SIZE)
    parser.add_argument("-y", "--y_axis", default=Metrics.THROUGHPUT_E2E)
    parser.add_argument("-z", "--z_axis")
    return parser


DEFAULT_COLUMNS = (
"Setting",
Metrics.BATCH_SIZE,
Metrics.INPUT_LENGTH,
Metrics.TOKENS_SAMPLE,
Metrics.THROUGHPUT_E2E,
Metrics.LATENCY_E2E,
)


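# read_data expects result files named
#   <setting>_bs_<batch_size>_seq_<input_length>_tok_<tokens_per_sample>.json
# e.g. "mqa_small_bs_32_seq_512_tok_100.json" (hypothetical name).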
def read_data(input_file: Path):
    try:
        with input_file.open("r") as f:
            data = json.load(f)
        data = {**data["config"], **data["results"]}
    except (ValueError, OSError) as e:
        raise ValueError(f"Cannot parse file {input_file} ({e})")
    try:
        setting, bs_, bs, seq_, seq, tok_, tok = input_file.stem.rsplit("_", 6)
        assert bs_ == "bs"
        assert data[Metrics.BATCH_SIZE] == int(bs)
        assert seq_ == "seq"
        assert data[Metrics.INPUT_LENGTH] == int(seq) or int(seq) < 0
        assert tok_ == "tok"
        assert data[Metrics.TOKENS_SAMPLE] == int(tok)
    except (ValueError, AssertionError) as e:
        raise ValueError(f"Cannot parse filename {input_file} ({e})")
    data["Setting"] = setting
    return data


def make_table(data, cols):
    from markdownTable import markdownTable
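    # Assumed dependency: the pre-1.0 "py-markdown-table" package, which
    # exposes markdownTable(...).getMarkdown() as used below.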

    data = [Metrics.format_metrics({col: x[col] for col in cols}) for x in data]
    return markdownTable(data).getMarkdown()


def parse_key(key: Optional[str]) -> Optional[str]:
    if key is None:
        return key
    return getattr(Metrics, key.upper(), key)


def filter_data(data, filters):
    if filters is None:
        return data
    filters = parse_config_args(filters)
    filters = {parse_key(key): value for key, value in filters.items()}
    filtered_data = []
    for x in data:
        keep = True  # keep the row only if it matches every filter
        for key, value in filters.items():
            keep = keep and x[key] == value
        if keep:
            filtered_data.append(x)
    return filtered_data
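# Filter expressions reuse the key=value syntax of parse_config_args, e.g. the
# hypothetical flags --filter "batch_size=32" --filter "setting=mqa_small".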


def plot(data, x_axis, y_axis, z_axis):
    import matplotlib.pyplot as plt

    x_axis = parse_key(x_axis)
    y_axis = parse_key(y_axis)
    z_axis = parse_key(z_axis)
    x = [d[x_axis] for d in data]
    y = [d[y_axis] for d in data]
    z = None if z_axis is None else [d[z_axis] for d in data]

    fig = plt.figure()
    ax = fig.add_subplot()

    scatter = ax.scatter(x, y, c=z)
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    if z_axis is not None:
        handles, labels = scatter.legend_elements()
        ax.legend(handles=handles, labels=labels, title=z_axis)
    fig.show()
    input("Press enter to continue")


def main(argv: Optional[List[str]] = None) -> None:
    parser = get_arg_parser()
    args = parser.parse_args(argv)
    data = [read_data(input_file) for input_file in args.input_dir.iterdir()]

    data = filter_data(data, args.filter)

    if len(data) == 0:
        raise RuntimeError("No data to show.")

    cols = DEFAULT_COLUMNS if args.column is None else [parse_key(col) for col in args.column]

    if args.table:
        print(make_table(data, cols))

    if args.plot:
        plot(data, args.x_axis, args.y_axis, args.z_axis)


if __name__ == "__main__":
    main()
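A minimal usage sketch for the new script (the directory and filter value are hypothetical):

    from src.parse_results import main

    main(["./results/mqa_small", "--table", "--filter", "batch_size=32"])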
1 change: 0 additions & 1 deletion src/pipeline.py
@@ -188,7 +188,6 @@ def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[
         output_length = output_tokens.size(1)
 
         output_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
-        print("AAA", output_text)
         t3 = time.perf_counter()
 
         metrics = {
47 changes: 26 additions & 21 deletions src/utils.py
@@ -17,31 +17,36 @@ def parse_revision(pretrained_model: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
     return pretrained_model, revision
 
 
+def parse_config_arg(config_arg: str) -> Tuple[str, Any]:
+    split_arg = [x.strip() for x in config_arg.split("=", 1)]
+    if len(split_arg) != 2:
+        raise ValueError(f"Cannot parse argument (not in 'key=value' format): {config_arg}")
+    key, value = split_arg
+    if not key.isidentifier():
+        raise ValueError(f"Invalid argument (not a python identifier): {key}")
+    if value.lower() == "true":
+        value = True
+    elif value.lower() == "false":
+        value = False
+    elif value.lower() == "none":
+        value = None
+    else:
+        try:
+            value = int(value)
+        except ValueError:
+            try:
+                value = float(value)
+            except ValueError:
+                pass
+    return key, value
+
+
 def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]:
     parsed_config_args = {}
     for config_arg in config_args:
-        split_arg = [x.strip() for x in config_arg.split("=", 1)]
-        if len(split_arg) != 2:
-            raise ValueError(f"Cannot parse argument (not in 'key=value' format): {config_arg}")
-        key, value = split_arg
-        if not key.isidentifier():
-            raise ValueError(f"Invalid argument (not a python identifier): {key}")
+        key, value = parse_config_arg(config_arg)
         if key in parsed_config_args:
            raise ValueError(f"Duplicate argument: {key}")
-        if value.lower() == "true":
-            value = True
-        elif value.lower() == "false":
-            value = False
-        elif value.lower() == "none":
-            value = None
-        else:
-            try:
-                value = int(value)
-            except ValueError:
-                try:
-                    value = float(value)
-                except ValueError:
-                    pass
         parsed_config_args[key] = value
     return parsed_config_args

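A sketch of how the newly extracted parse_config_arg helper behaves (illustrative values, assuming the src.utils import path used elsewhere in the repo):

    from src.utils import parse_config_arg

    parse_config_arg("n_layer=8")      # -> ("n_layer", 8): parsed as int
    parse_config_arg("lr=0.0003")      # -> ("lr", 0.0003): parsed as float
    parse_config_arg("dtype=float32")  # -> ("dtype", "float32"): left as str
    parse_config_arg("pad=none")       # -> ("pad", None)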
@@ -98,7 +103,7 @@ def get_dummy_batch(batch_size: int, max_input_length: int = -1) -> List[str]:
     if max_input_length == -1:
         input_sentences = copy.deepcopy(dummy_input_sentences)
     else:
-        input_sentences = batch_size * ["Hello " * max_input_length]
+        input_sentences = batch_size * [" Hello" * max_input_length]
 
     if batch_size > len(input_sentences):
         input_sentences *= math.ceil(batch_size / len(input_sentences))
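Note on the dummy-batch change: for GPT-2-style BPE tokenizers, " Hello" with a leading space typically encodes as a single token, whereas the old trailing-space form leaves a dangling space token at the end, so the leading-space version makes max_input_length track the tokenized prompt length more predictably.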