File tree: 1 file changed, +4 −4 lines changed

@@ -48,20 +48,20 @@ def load_pretrained_model(model_path: str,
4848 kwargs = {"device_map" : device_map }
4949
5050 if device != "cuda" :
51- kwargs ['device_map' ] = {"" : device }
51+ kwargs ['device_map' ] = {"" : device } # type: ignore
5252
5353 if load_8bit :
54- kwargs ['load_in_8bit' ] = True
54+ kwargs ['load_in_8bit' ] = True # type: ignore
5555 elif load_4bit :
56- kwargs ['load_in_4bit' ] = True
56+ kwargs ['load_in_4bit' ] = True # type: ignore
5757 kwargs ['quantization_config' ] = BitsAndBytesConfig (
5858 load_in_4bit = True ,
5959 bnb_4bit_compute_dtype = torch .float16 ,
6060 bnb_4bit_use_double_quant = True ,
6161 bnb_4bit_quant_type = 'nf4'
6262 )
6363 else :
64- kwargs ['torch_dtype' ] = torch .float16
64+ kwargs ['torch_dtype' ] = torch .float16 # type: ignore
6565
6666 if 'lita' not in model_name .lower ():
6767 warnings .warn ("this function is for loading LITA models" )
You can’t perform that action at this time.
0 commit comments