Skip to content

Commit 1822b22

Browse files
authored
Merge pull request #47 from ai-forever/kirillova/lita_multigpu_fix
fix: change device choice in model loader function
2 parents 4486ab7 + 49cd1e9 commit 1822b22

File tree

2 files changed

+99
-3
lines changed

2 files changed

+99
-3
lines changed

DPF/filters/videos/lita_filter.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from io import BytesIO
34
from typing import Any, Optional
45

@@ -7,17 +8,20 @@
78
from lita.constants import (
89
DEFAULT_IM_END_TOKEN,
910
DEFAULT_IM_START_TOKEN,
11+
DEFAULT_IMAGE_PATCH_TOKEN,
1012
DEFAULT_IMAGE_TOKEN,
1113
IMAGE_TOKEN_INDEX,
14+
TIME_TOKEN_TEMPLATE,
1215
)
13-
from lita.model.builder import load_pretrained_model
16+
from lita.model.language_model.lita_llama import LitaLlamaForCausalLM
1417
from lita.utils import load_video
1518
from llava.conversation import SeparatorStyle, conv_templates
1619
from llava.mm_utils import (
1720
KeywordsStoppingCriteria,
1821
get_model_name_from_path,
1922
tokenizer_image_token,
2023
)
24+
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
2125

2226
from DPF.types import ModalityToDataMapping
2327

@@ -29,6 +33,98 @@
2933
from torch.utils.data import default_collate
3034

3135

36+
def load_pretrained_model(model_path: str,
37+
model_base: str,
38+
model_name: str,
39+
load_8bit: bool = False,
40+
load_4bit: bool = False,
41+
device_map: str = "auto",
42+
device: str = "cuda"):
43+
kwargs = {"device_map": device_map}
44+
45+
if device != "cuda":
46+
kwargs['device_map'] = {"": device} # type: ignore
47+
48+
if load_8bit:
49+
kwargs['load_in_8bit'] = True # type: ignore
50+
elif load_4bit:
51+
kwargs['load_in_4bit'] = True # type: ignore
52+
kwargs['quantization_config'] = BitsAndBytesConfig(
53+
load_in_4bit=True,
54+
bnb_4bit_compute_dtype=torch.float16,
55+
bnb_4bit_use_double_quant=True,
56+
bnb_4bit_quant_type='nf4'
57+
)
58+
else:
59+
kwargs['torch_dtype'] = torch.float16 # type: ignore
60+
61+
if 'lita' not in model_name.lower():
62+
warnings.warn("this function is for loading LITA models", stacklevel=2)
63+
if 'lora' in model_name.lower():
64+
warnings.warn("lora is currently not supported for LITA", stacklevel=2)
65+
if 'mpt' in model_name.lower():
66+
warnings.warn("mpt is currently not supported for LITA", stacklevel=2)
67+
68+
if model_base is not None:
69+
print('Loading LITA from base model...')
70+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
71+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
72+
model = LitaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
73+
74+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
75+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items() if 'mm_projector' in k}
76+
model.load_state_dict(mm_projector_weights, strict=False)
77+
else:
78+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
79+
model = LitaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
80+
81+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
82+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", False)
83+
if mm_use_im_patch_token:
84+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
85+
if mm_use_im_start_end:
86+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
87+
model.resize_token_embeddings(len(tokenizer))
88+
89+
vision_tower = model.get_vision_tower()
90+
if not vision_tower.is_loaded:
91+
vision_tower.load_model()
92+
vision_tower.to(device=device, dtype=torch.float16)
93+
image_processor = vision_tower.image_processor
94+
95+
# time tokens and embeddings
96+
num_time_tokens = getattr(model.config, "num_time_tokens", 0)
97+
if num_time_tokens > 0:
98+
time_tokens = [TIME_TOKEN_TEMPLATE.format(t=x) for x in range(num_time_tokens)]
99+
num_new_tokens = tokenizer.add_tokens(time_tokens)
100+
101+
if model_base is None:
102+
assert num_new_tokens == 0, "time tokens should already be in the tokenizer for full finetune model"
103+
104+
if num_new_tokens > 0:
105+
warnings.warn("looking for weights in mm_projector.bin", stacklevel=2)
106+
assert num_new_tokens == num_time_tokens
107+
model.resize_token_embeddings(len(tokenizer))
108+
input_embeddings = model.get_input_embeddings().weight.data
109+
output_embeddings = model.get_output_embeddings().weight.data
110+
weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
111+
assert 'model.embed_tokens.weight' in weights and 'lm_head.weight' in weights
112+
113+
dtype = input_embeddings.dtype
114+
device = input_embeddings.device
115+
116+
tokenizer_time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
117+
time_token_ids = getattr(model.config, 'time_token_ids', tokenizer_time_token_ids)
118+
input_embeddings[tokenizer_time_token_ids] = weights['model.embed_tokens.weight'][time_token_ids].to(dtype).to(device)
119+
output_embeddings[tokenizer_time_token_ids] = weights['lm_head.weight'][time_token_ids].to(dtype).to(device)
120+
121+
if hasattr(model.config, "max_sequence_length"):
122+
context_len = model.config.max_sequence_length
123+
else:
124+
context_len = 2048
125+
return tokenizer, model, image_processor, context_len
126+
127+
32128
def disable_torch_init() -> None:
33129
"""
34130
Disable the redundant torch default initialization to accelerate model creation.
@@ -79,7 +175,7 @@ def __init__(
79175

80176
disable_torch_init()
81177

82-
pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit)
178+
pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device) # type: ignore
83179
self.tokenizer, self.model, self.processor, self.context_len = pretrainers
84180

85181
self.model_num_frames = self.model.config.num_frames

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ disable_error_code = ["import-not-found"]
7878

7979
[[tool.mypy.overrides]]
8080
module = "DPF.filters.videos.lita_filter"
81-
disable_error_code = ["import-not-found", "call-overload"]
81+
disable_error_code = ["import-not-found", "call-overload", "no-untyped-def"]
8282

8383
[[tool.mypy.overrides]]
8484
module = "DPF.filters.images.llava_captioning_filter"

0 commit comments

Comments
 (0)