Support pretrained models
jlamypoirier committed Feb 8, 2023
commit 7b074a845062d024d6c21e37fa76af5a6fb8e549
6 changes: 6 additions & 0 deletions Dockerfile
@@ -9,6 +9,12 @@ ENV PYTHONPATH=/app
RUN useradd -m -u $USER -s /bin/bash $USERNAME \
    && chown $USERNAME /app

# git-lfs is needed to interact with the huggingface hub
RUN apt-get update \
    && apt-get install -y git-lfs \
    && rm -rf /var/lib/apt/lists/* \
    && git lfs install

COPY --chown=$USERNAME ./requirements.txt ./
COPY --chown=$USERNAME transformers/ ./transformers

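The git-lfs dependency matters because Hub model repos store their weight files as LFS objects. As a hedged illustration (not part of the diff), the git-backed `Repository` workflow from `huggingface_hub` is one path that fails inside the image without git-lfs; the repo id and local path below are placeholders:

```python
# Minimal sketch, assuming huggingface_hub is installed in the image.
# Repository uses git + git-lfs under the hood and errors out if git-lfs is missing.
from huggingface_hub import Repository

repo = Repository(local_dir="/tmp/gpt2", clone_from="gpt2")  # weights arrive via git-lfs
print(repo.local_dir)
```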
1 change: 0 additions & 1 deletion src/pipelines/ds.py
@@ -18,7 +18,6 @@ def __init__(self, args: Namespace) -> None:
        check_unused(args, {"dtype": torch.float16})
        super().__init__(args)


        self.model = deepspeed.init_inference(
            self.model,
            mp_size=int(os.getenv("WORLD_SIZE", "1")),
113 changes: 69 additions & 44 deletions src/pipelines/pipeline.py
@@ -1,3 +1,4 @@
import contextlib
import gc
import logging
import time
@@ -10,7 +11,14 @@
from src.utils.arguments import check_unused
from src.utils.fast_init import fast_init
from src.utils.logging import format_ms, log_rank_n
from transformers import AutoTokenizer, BloomForCausalLM, GPT2LMHeadModel, PretrainedConfig, PreTrainedModel, GPTBigCodeLMHeadModel
from transformers import (
    AutoTokenizer,
    BloomForCausalLM,
    GPT2LMHeadModel,
    GPTBigCodeLMHeadModel,
    PretrainedConfig,
    PreTrainedModel,
)


logger = logging.getLogger(__name__)
@@ -32,73 +40,90 @@

class Pipeline:
    def __init__(self, args: Namespace) -> None:
        self.args = args
        log_rank_n("*** Setting up tokenizer", logger.info)
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.device = args.device

        model_class, config = self.get_config(args)
        is_int8 = args.dtype == torch.int8
        if is_int8:
        self.is_int8 = args.dtype == torch.int8
        if self.is_int8:
            check_unused(args, {"device": torch.device("cuda")}, enforce=True)
        torch_dtype = torch.float16 if is_int8 else args.dtype
        self.torch_dtype = torch.float16 if self.is_int8 else args.dtype

        self.model_class, self.config = self._get_config()

        pretrained_model = args.pretrained_model
        if pretrained_model is None:
            self.model = self._create_model(self.config)
            if self.is_int8:
                log_rank_n("*** Saving model", logger.info)
                self.model.save_pretrained("tmp")
                del self.model
                gc.collect()
                pretrained_model = "tmp"

        if pretrained_model is not None:
            self.model = self._load_pretrained(self.config, pretrained_model)

        self.model.eval()

    def _create_model(self, config):
        log_rank_n("*** Creating model", logger.info)
        with fast_init(self.device):
            self.model = model_class._from_config(config=config, torch_dtype=torch_dtype)
        with fast_init(self.device) if self.args.fast_init else contextlib.nullcontext():
            model = self.model_class._from_config(config=config, torch_dtype=self.torch_dtype)

        log_rank_n("*** Moving to device", logger.info)
        self.model.to(self.device)
        model.to(self.device)
        log_rank_n("*** Initializing weights", logger.info)
        # Initialization is ~1000x faster on GPU.
        self.model.init_weights()

        # Int8 can only be obtained by reloading a pretrained model
        if is_int8:
            log_rank_n("*** Saving model", logger.info)
            self.model.save_pretrained("tmp")
            self.model = None
            gc.collect()
            torch.cuda.empty_cache()
            log_rank_n("*** Reloading model in int8", logger.info)
            with fast_init(self.device):
                self.model = model_class.from_pretrained(
                    "tmp",
                    load_in_8bit=True,
                    device_map="auto",
                )

        self.model.eval()

    def get_config(self, args) -> Tuple[Type[PreTrainedModel], PretrainedConfig]:
        model.init_weights()
        return model

    def _load_pretrained(self, config, pretrained_model):
        with fast_init(self.device) if self.args.fast_init else contextlib.nullcontext():
            model = self.model_class.from_pretrained(
                pretrained_model,
                config=config,
                # Quantize only when int8 was requested; otherwise keep the target dtype.
                load_in_8bit=self.is_int8,
                device_map="auto" if self.is_int8 else None,
                torch_dtype=None if self.is_int8 else self.torch_dtype,
            )
        if not self.is_int8:
            model.to(self.device)
        return model

    def _get_config(self) -> Tuple[Type[PreTrainedModel], PretrainedConfig]:
        config_args = {
            "activation_function": args.activation_function,
            "n_head": args.n_head,
            "n_layer": args.n_layer,
            "activation_function": self.args.activation_function,
            "n_head": self.args.n_head,
            "n_layer": self.args.n_layer,
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "vocab_size": len(self.tokenizer),
            "use_cache": True,
        }
        if args.model_class.lower() == "bloom":
            check_unused(args, {"attention_type": 1, "n_positions": None})
        if self.args.model_class.lower() == "bloom":
            check_unused(self.args, {"attention_type": None, "n_positions": None})
            config_args["attention_softmax_in_fp32"] = True
            config_args["hidden_size"] = args.hidden_size
            config_args["hidden_size"] = self.args.hidden_size
            model_class = BloomForCausalLM
        elif args.model_class.lower() == "gpt2":
            check_unused(args, {"attention_type": 1})
            config_args["n_embd"] = args.hidden_size
            config_args["n_positions"] = args.n_positions
        elif self.args.model_class.lower() == "gpt2":
            check_unused(self.args, {"attention_type": None})
            config_args["n_embd"] = self.args.hidden_size
            config_args["n_positions"] = self.args.n_positions
            model_class = GPT2LMHeadModel
        elif args.model_class.lower() == "gpt_bigcode":
            #config_args["attention_type"] = args.attention_type
            config_args["n_embd"] = args.hidden_size
            config_args["n_positions"] = args.n_positions
        elif self.args.model_class.lower() == "gpt_bigcode":
            config_args["attention_type"] = self.args.attention_type
            config_args["n_embd"] = self.args.hidden_size
            config_args["n_positions"] = self.args.n_positions
            model_class = GPTBigCodeLMHeadModel
        else:
            raise NotImplementedError()
        # Use defaults or pretrained config for missing arguments.
        config_args = {key: value for key, value in config_args.items() if value is not None}
        if self.args.pretrained_model is None:
            config = model_class.config_class(**config_args)
        else:
            config = model_class.config_class.from_pretrained(self.args.pretrained_model, **config_args)

        return model_class, model_class.config_class(**config_args)
        return model_class, config

def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[str, Any]]:
t0 = time.perf_counter()
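For reviewers, a hedged sketch of how the new pretrained path is driven end to end; this driver is not part of the commit, and the `Namespace` fields simply mirror what `Pipeline.__init__` and `_get_config` read. Note that the tokenizer gains a `[PAD]` token, so `vocab_size` can differ from a stock checkpoint's and loading may fail on the embedding size unless the checkpoint is resized:

```python
# Hypothetical driver, not part of the commit.
from argparse import Namespace

import torch

from src.pipelines.pipeline import Pipeline

args = Namespace(
    model_class="gpt2",
    pretrained_model="gpt2",  # Hub id or local path; None builds a model from scratch
    dtype=torch.float16,
    device=torch.device("cuda"),
    fast_init=True,
    # None entries are filtered out of config_args and fall back to the
    # pretrained config (or the config_class defaults).
    activation_function=None,
    n_head=None,
    n_layer=None,
    hidden_size=None,
    n_positions=None,
    attention_type=None,
)
pipeline = Pipeline(args)  # routes through _load_pretrained()
texts, metrics = pipeline(["def hello_world():"], max_new_tokens=32)
```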
2 changes: 2 additions & 0 deletions src/utils/arguments.py
@@ -10,6 +10,7 @@ def get_arg_parser() -> ArgumentParser:

    # Model
    parser.add_argument("--model_class", default="GPT2", type=str)
    parser.add_argument("--pretrained_model")
    parser.add_argument("--hidden_size", type=int)
    parser.add_argument("--attention_type", type=int)
    parser.add_argument("--n_positions", type=int)
@@ -22,6 +23,7 @@ def get_arg_parser() -> ArgumentParser:
    parser.add_argument("--device", default="cuda", type=torch.device)
    parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--no_fast_init", dest="fast_init", action="store_false")

    # Input and output
    parser.add_argument("--batch_size", default=1, type=int)
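A quick hedged check of the two new arguments; `get_arg_parser` is the function shown above and the values are illustrative:

```python
# Hypothetical: exercise the new flags directly, bypassing the shell.
from src.utils.arguments import get_arg_parser

args = get_arg_parser().parse_args(["--pretrained_model", "gpt2", "--no_fast_init"])
assert args.pretrained_model == "gpt2"
assert args.fast_init is False  # store_false flag; defaults to True
```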