Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
812a72c
refactor
mayank31398 Oct 17, 2022
21713ac
refactor
mayank31398 Oct 17, 2022
0ea738a
refactor
mayank31398 Oct 17, 2022
6239fc6
refactor
mayank31398 Oct 17, 2022
01e9515
refactor
mayank31398 Oct 17, 2022
b5a29b8
test
mayank31398 Dec 1, 2022
e4a29b5
test
mayank31398 Dec 1, 2022
48f0aa0
test
mayank31398 Dec 1, 2022
646b63b
test
mayank31398 Dec 1, 2022
3281d16
test
mayank31398 Dec 1, 2022
2fbb6c3
test
mayank31398 Dec 1, 2022
1090704
test
mayank31398 Dec 1, 2022
ef8ec7c
test
mayank31398 Dec 1, 2022
b94ea81
fp32, bf16, int8
mayank31398 Dec 3, 2022
17534fd
fp32, bf16, int8
mayank31398 Dec 3, 2022
e7230b5
fp32, bf16, int8
mayank31398 Dec 3, 2022
38c616b
use_cache
mayank31398 Dec 3, 2022
15a2c80
use_cache
mayank31398 Dec 3, 2022
80ba9bb
gc
mayank31398 Dec 3, 2022
f28f8ac
benchmark
mayank31398 Dec 3, 2022
d04dc14
benchmark
mayank31398 Dec 4, 2022
9dc5268
benchmark
mayank31398 Dec 4, 2022
23a5eb1
fix
mayank31398 Dec 4, 2022
391e055
fix
mayank31398 Dec 4, 2022
856c77b
fix
mayank31398 Dec 4, 2022
9d99f46
fp32
mayank31398 Dec 4, 2022
dfe8cb3
bf16
mayank31398 Dec 4, 2022
7344ae0
bf16
mayank31398 Dec 4, 2022
a4c3b81
ds-inference
mayank31398 Dec 4, 2022
a0f308d
device map
mayank31398 Dec 4, 2022
0947688
device map
mayank31398 Dec 4, 2022
379bfd9
fix
mayank31398 Dec 4, 2022
6dc0c07
fp32
mayank31398 Dec 5, 2022
7dc67ea
bf16
mayank31398 Dec 5, 2022
2ac761d
int8
mayank31398 Dec 5, 2022
28e1e71
attention_type
mayank31398 Dec 5, 2022
b2c7de7
fp32
mayank31398 Dec 5, 2022
76b3b8d
bf16
mayank31398 Dec 6, 2022
c149ee9
fp32
mayank31398 Dec 6, 2022
8427b94
int8
mayank31398 Dec 6, 2022
487954f
fp16
mayank31398 Dec 6, 2022
0253839
total params
mayank31398 Dec 6, 2022
893c521
models
mayank31398 Dec 6, 2022
daea92d
Add code to vary input length (#5)
minimario Dec 6, 2022
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test
  • Loading branch information
mayank31398 committed Dec 1, 2022
commit 1090704ebccd555be42c5bcc36673d7d8e438ae0
13 changes: 12 additions & 1 deletion src/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,21 @@ def get_config_tokenizer_model_class(args: Namespace) -> Union[BloomConfig, GPT2
n_positions=args.n_positions,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
attention_type=args.attention_type,
attention_type=get_attention_type(args.attention_type),
print_details=False,
vocab_size=len(tokenizer),
)
model_class = GPT2LMHeadModel

return config, tokenizer, model_class


def get_attention_type(attention_type: int):
from transformers.models.gpt2.modeling_gpt2 import AttentionType

if attention_type == 1:
return AttentionType.MULTI_HEAD
elif attention_type == 2:
return AttentionType.MULTI_QUERY
elif attention_type == 3:
return AttentionType.MULTI_QUERY_1