@@ -326,7 +326,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
 
         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
-
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
                 org_weight_name,
                 mapped_weight_name,
@@ -347,12 +348,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                    ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
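
A minimal standalone sketch of what the matching change does (the module names used here are hypothetical, and str.removesuffix requires Python 3.9+): the old startswith test treats a module name as a prefix, so a module such as "visual.proj" would also capture the unrelated weight "visual.proj_out.weight", while the new check_match strips the ".weight" suffix and requires an exact name match.

# Standalone sketch; "visual.proj" / "visual.proj_out" are hypothetical names.
old_match = (lambda weight_name, module_name:
             weight_name.startswith(module_name))
check_match = (lambda weight_name, module_name: weight_name.
               removesuffix(".weight") == module_name)

# The prefix test yields a false positive for the longer module name ...
assert old_match("visual.proj.weight", "visual.proj")
assert old_match("visual.proj_out.weight", "visual.proj")
# ... while the equality test only accepts the intended weight.
assert check_match("visual.proj.weight", "visual.proj")
assert not check_match("visual.proj_out.weight", "visual.proj")

As the diff suggests, the point of the equality test is that a module name which happens to be a prefix of another can no longer claim the longer module's weights when classifying them into the sharding lists.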