fix
jlamypoirier committed Feb 15, 2023
commit 109d396271f07a5e341e790055f2ee9eb3e122ed
4 changes: 2 additions & 2 deletions src/pipelines/pipeline.py
@@ -116,12 +116,12 @@ def _save_pretrained(self, pretrained_model: str):
     def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
         t0 = time.perf_counter()
         log_rank_n(f"*** Loading model from {pretrained_model}", logger.info)
+        kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype}
         with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
             model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model,
                 config=self.config,
-                load_in_8bit=self.is_int8,
-                device_map="auto" if self.is_int8 else None,
+                **kwargs,
             )
         t1 = time.perf_counter()
         self.initialization_metrics["load pretrained model"] = t1 - t0
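For context, a minimal standalone sketch of the loading logic this commit introduces, outside the pipeline class. It assumes an is_int8 flag and a target dtype analogous to the pipeline's self.is_int8 and self.dtype; the model name is only a placeholder.

import torch
from transformers import AutoModelForCausalLM

is_int8 = False        # stand-in for the pipeline's self.is_int8 flag
dtype = torch.float16  # stand-in for the pipeline's self.dtype

# int8 quantization (via bitsandbytes) is requested together with an automatic
# device map; otherwise the model is loaded directly in the requested dtype.
kwargs = {"load_in_8bit": True, "device_map": "auto"} if is_int8 else {"torch_dtype": dtype}

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", **kwargs)

Building a single kwargs dict keeps the two loading modes mutually exclusive: torch_dtype is only passed in the non-quantized path, and the int8 path carries both of the options it needs.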