@@ -190,67 +190,62 @@ def from_config(cls, config):
 
 ### Cerebros model:
 
+from transformers import AutoTokenizer
+import tensorflow as tf
+
 class NewTokenizerLayer(tf.keras.layers.Layer):
-    """
-    A Keras layer that tokenizes input text using a specified tokenizer.
-    """
     def __init__(self, max_seq_length, tokenizer_checkpoint, **kwargs):
-        """
-        Initializes the NewTokenizerLayer.
-        Args:
-        - max_seq_length (int): The maximum sequence length for tokenization.
-        - tokenizer_checkpoint (str): The checkpoint for the tokenizer to use.
-        - **kwargs: Additional keyword arguments for the layer.
-        """
-        super(NewTokenizerLayer, self).__init__(**kwargs)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+        super().__init__(**kwargs)
         self.max_seq_length = max_seq_length
+        self.tokenizer_checkpoint = tokenizer_checkpoint
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+        # Ensure the tokenizer has a padding token
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
     def call(self, inputs):
-        """
-        Tokenizes the input text.
-        Args:
-        - inputs: The input text to tokenize.
-        Returns:
-        - The tokenized input IDs.
-        """
-        # Check if inputs is a tensor
-        # if isinstance(inputs, tf.Tensor):
-        #     # Convert tensor to a list of strings
-        #     inputs = inputs.numpy().astype("U").tolist()
-
-        # inputs = [x.decode('utf-8') for x in inputs]
-        # inputs = tf.strings.unicode_encode(inputs, 'UTF-8')
+        def tokenize_py_fn(inputs):
+            # Convert TensorFlow byte strings to Python strings
+            texts = [text.decode('utf-8') for text in inputs.numpy()]
+
+            # Tokenize with the Hugging Face tokenizer
+            tokenized = self.tokenizer(
+                texts,
+                max_length=self.max_seq_length,
+                padding='max_length',
+                truncation=True,
+                return_tensors='tf'
+            )
+            return tokenized['input_ids'].numpy()
+
+        # Wrap the eager Python function in a TensorFlow op
+        input_ids = tf.py_function(
+            tokenize_py_fn,
+            [inputs],
+            Tout=tf.int32
+        )
+
+        # Restore the static shape for downstream layers
+        input_ids.set_shape([None, self.max_seq_length])
 
-        tokenized = self.tokenizer(inputs.numpy().astype("U").tolist(),
-                                   max_length=self.max_seq_length,
-                                   padding='max_length',
-                                   truncation=True,
-                                   return_tensors='tf',
-                                   return_overflowing_tokens=False)
-        # Return the tokenized input IDs
-        return tokenized['input_ids']
+        return input_ids
+
     def get_config(self):
-        """
-        Returns the configuration for the layer.
-        Returns:
-        - A dictionary containing the layer's configuration.
-        """
-        config = super(NewTokenizerLayer, self).get_config()
+        config = super().get_config()
         config.update({
             'max_seq_length': self.max_seq_length,
-            'tokenizer_checkpoint': self.tokenizer.name_or_path
+            'tokenizer_checkpoint': self.tokenizer_checkpoint
         })
         return config
+
     @classmethod
     def from_config(cls, config):
-        """
-        Creates a new instance of the layer from a configuration.
-        Args:
-        - config: The configuration dictionary.
-        Returns:
-        - A new instance of the layer.
-        """
-        return cls(max_seq_length=config['max_seq_length'], tokenizer_checkpoint=config['tokenizer_checkpoint'])
+        return cls(
+            max_seq_length=config['max_seq_length'],
+            tokenizer_checkpoint=config['tokenizer_checkpoint']
+        )
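A minimal usage sketch, not part of the commit above: it wires the layer into a small Keras model over string inputs and round-trips it through its config. The checkpoint name, vocab size (30522, the distilbert-base-uncased vocabulary), and the toy classification head are illustrative assumptions, not values pinned by the source.

    import tensorflow as tf

    # Hypothetical values for illustration; the commit does not fix these.
    MAX_SEQ_LENGTH = 128
    CHECKPOINT = 'distilbert-base-uncased'

    # String input of unspecified batch size; the layer emits int32 token IDs
    # of shape (batch, MAX_SEQ_LENGTH) via tf.py_function.
    text_in = tf.keras.Input(shape=(), dtype=tf.string)
    token_ids = NewTokenizerLayer(MAX_SEQ_LENGTH, CHECKPOINT)(text_in)

    # Toy head for the sketch: embed the IDs, pool, and classify.
    x = tf.keras.layers.Embedding(input_dim=30522, output_dim=64)(token_ids)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    out = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.Model(text_in, out)

    # Config round-trip now works because get_config() stores the checkpoint
    # string rather than reading it back off the tokenizer object.
    layer = NewTokenizerLayer(MAX_SEQ_LENGTH, CHECKPOINT)
    restored = NewTokenizerLayer.from_config(layer.get_config())
    assert restored.max_seq_length == layer.max_seq_length

Storing `tokenizer_checkpoint` on the instance also avoids depending on `tokenizer.name_or_path`, which may not match the string originally passed to `from_pretrained`.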