@@ -498,6 +498,14 @@ def remove_all_adapters(self):
         self._active_adapters.clear()
 
     def _create_lora_modules(self):
+
+        def _parent_module(module_name: str) -> str:
+            # module name is a dot separated name.
+            # for example:
+            #  - given an input 'x.y.z' return 'x.y'
+            #  - given an input 'x' return ''
+            return module_name.rpartition('.')[0]
+
         for module_name, module in self.model.named_modules(
                 remove_duplicate=False):
             if isinstance(module, PPMissingLayer):
@@ -529,10 +537,17 @@ def _create_lora_modules(self):
                     new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
+                logits_processor_module_name = 'logits_processor'
+                parent_module = _parent_module(module_name)
+                if parent_module:
+                    logits_processor_module_name = (
+                        f"{parent_module}.{logits_processor_module_name}")
+
                 logits_processor_module = self.model.get_submodule(
-                    "logits_processor")
+                    logits_processor_module_name)
+
                 new_module = replace_submodule(
-                    self.model, "logits_processor",
+                    self.model, logits_processor_module_name,
                     from_layer_logits_processor(logits_processor_module,
                                                 module, self.lora_slots,
                                                 self.lora_config,
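
A minimal sketch of the path derivation this hunk performs, shown outside the manager class for clarity. The module names used in the asserts ('lm_head', 'model.lm_head') are hypothetical examples, not taken from any specific model definition; only `_parent_module` mirrors the helper added in the diff.

```python
def _parent_module(module_name: str) -> str:
    # 'x.y.z' -> 'x.y', 'x' -> ''
    return module_name.rpartition('.')[0]

def logits_processor_path(lm_head_name: str) -> str:
    # Place 'logits_processor' next to the lm_head module instead of
    # assuming it always lives at the top level of the model.
    name = 'logits_processor'
    parent = _parent_module(lm_head_name)
    return f"{parent}.{name}" if parent else name

assert logits_processor_path('lm_head') == 'logits_processor'
assert logits_processor_path('model.lm_head') == 'model.logits_processor'
```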