Commit 25900cc

fix: change device choice in model loader function
1 parent b0682ef commit 25900cc

File tree

1 file changed: +95 -2 lines

DPF/filters/videos/lita_filter.py

Lines changed: 95 additions & 2 deletions
@@ -9,15 +9,22 @@
     DEFAULT_IM_START_TOKEN,
     DEFAULT_IMAGE_TOKEN,
     IMAGE_TOKEN_INDEX,
+    TIME_TOKEN_TEMPLATE,
 )
-from lita.model.builder import load_pretrained_model
+from lita.model.language_model.lita_llama import LitaLlamaForCausalLM
 from lita.utils import load_video
 from llava.conversation import SeparatorStyle, conv_templates
 from llava.mm_utils import (
     KeywordsStoppingCriteria,
     get_model_name_from_path,
     tokenizer_image_token,
 )
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+)
 
 from DPF.types import ModalityToDataMapping
 
@@ -29,6 +36,92 @@
 from torch.utils.data import default_collate
 
 
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
+    kwargs = {"device_map": device_map}
+
+    if device != "cuda":
+        kwargs['device_map'] = {"": device}
+
+    if load_8bit:
+        kwargs['load_in_8bit'] = True
+    elif load_4bit:
+        kwargs['load_in_4bit'] = True
+        kwargs['quantization_config'] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+    else:
+        kwargs['torch_dtype'] = torch.float16
+
+    if 'lita' not in model_name.lower():
+        warnings.warn("this function is for loading LITA models")
+    if 'lora' in model_name.lower():
+        warnings.warn("lora is currently not supported for LITA")
+    if 'mpt' in model_name.lower():
+        warnings.warn("mpt is currently not supported for LITA")
+
+    if model_base is not None:
+        print('Loading LITA from base model...')
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        cfg_pretrained = AutoConfig.from_pretrained(model_path)
+        model = LitaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+
+        mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+        mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items() if 'mm_projector' in k}
+        model.load_state_dict(mm_projector_weights, strict=False)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        model = LitaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", False)
+    if mm_use_im_patch_token:
+        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+    if mm_use_im_start_end:
+        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+    model.resize_token_embeddings(len(tokenizer))
+
+    vision_tower = model.get_vision_tower()
+    if not vision_tower.is_loaded:
+        vision_tower.load_model()
+    vision_tower.to(device=device, dtype=torch.float16)
+    image_processor = vision_tower.image_processor
+
+    # time tokens and embeddings
+    num_time_tokens = getattr(model.config, "num_time_tokens", 0)
+    if num_time_tokens > 0:
+        time_tokens = [TIME_TOKEN_TEMPLATE.format(t=x) for x in range(num_time_tokens)]
+        num_new_tokens = tokenizer.add_tokens(time_tokens)
+
+        if model_base is None:
+            assert num_new_tokens == 0, "time tokens should already be in the tokenizer for full finetune model"
+
+        if num_new_tokens > 0:
+            warnings.warn("looking for weights in mm_projector.bin")
+            assert num_new_tokens == num_time_tokens
+            model.resize_token_embeddings(len(tokenizer))
+            input_embeddings = model.get_input_embeddings().weight.data
+            output_embeddings = model.get_output_embeddings().weight.data
+            weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+            assert 'model.embed_tokens.weight' in weights and 'lm_head.weight' in weights
+
+            dtype = input_embeddings.dtype
+            device = input_embeddings.device
+
+            tokenizer_time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
+            time_token_ids = getattr(model.config, 'time_token_ids', tokenizer_time_token_ids)
+            input_embeddings[tokenizer_time_token_ids] = weights['model.embed_tokens.weight'][time_token_ids].to(dtype).to(device)
+            output_embeddings[tokenizer_time_token_ids] = weights['lm_head.weight'][time_token_ids].to(dtype).to(device)
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+    return tokenizer, model, image_processor, context_len
+
+
 def disable_torch_init() -> None:
     """
     Disable the redundant torch default initialization to accelerate model creation.
@@ -79,7 +172,7 @@ def __init__(
 
         disable_torch_init()
 
-        pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit)
+        pretrainers = load_pretrained_model(weights_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
         self.tokenizer, self.model, self.processor, self.context_len = pretrainers
 
         self.model_num_frames = self.model.config.num_frames
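
The substance of the fix sits in the first lines of the vendored loader: device="cuda" leaves device_map="auto", while any other value pins the whole model to that device via device_map={"": device}, and the filter now forwards its own self.device instead of silently defaulting to CUDA. A minimal usage sketch under the new signature; the checkpoint path and model name below are hypothetical:

# Sketch: calling the vendored loader with an explicit device (paths hypothetical).
from DPF.filters.videos.lita_filter import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="checkpoints/lita-13b-finetune",  # hypothetical local checkpoint
    model_base=None,                             # full finetune: no base model to merge
    model_name="lita-13b-finetune",
    load_8bit=False,
    load_4bit=False,
    device="cpu",  # any value other than "cuda" becomes device_map={"": "cpu"}
)

Note that the vision tower is cast to float16 and moved to the requested device either way (vision_tower.to(device=device, dtype=torch.float16) in the diff above), so the device argument governs the placement of both the language model and the vision tower.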
