To reproduce:
```python
!pip install tensorflow-text==2.7.0

import tensorflow as tf
import tensorflow_text as text
import tensorflow_hub as hub
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy
from tensorflow.keras.optimizers import Adam

strategy = tf.distribute.MirroredStrategy()
print('Number of GPU: ' + str(strategy.num_replicas_in_sync))  # 1 or 2, shouldn't matter

NUM_CLASS = 2

with strategy.scope():
    bert_preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    bert_encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

def get_model():
    text_input = Input(shape=(), dtype=tf.string, name='text')
    preprocessed_text = bert_preprocess(text_input)
    outputs = bert_encoder(preprocessed_text)
    output_sequence = outputs['sequence_output']
    x = Dense(NUM_CLASS, activation='sigmoid')(output_sequence)
    model = Model(inputs=[text_input], outputs=[x])
    return model

# Build and compile WITHOUT the strategy scope.
optimizer = Adam()
model = get_model()
model.compile(loss=CategoricalCrossentropy(from_logits=True), optimizer=optimizer, metrics=[Accuracy()])
model.summary()  # <- look at output 1
tf.keras.utils.plot_model(model, show_shapes=True, to_file='model.png')  # <- look at figure 1

# Build and compile the same model WITH the strategy scope.
with strategy.scope():
    optimizer = Adam()
    model = get_model()
    model.compile(loss=CategoricalCrossentropy(from_logits=True), optimizer=optimizer, metrics=[Accuracy()])
    model.summary()  # <- compare with output 1, it has already lost its shape
    tf.keras.utils.plot_model(model, show_shapes=True, to_file='model_scoped.png')  # <- compare this figure too, for ease
```

Under the strategy scope, the BERT sequence output loses its static seq_length, which becomes None.
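For quick confirmation, here is a minimal sketch (not part of the original report) that prints the symbolic shape of `sequence_output` inside and outside the scope. It assumes the `bert_preprocess`, `bert_encoder`, and `strategy` objects from the repro above; 768 is the hidden size of this particular encoder.

```python
# Minimal shape check (assumption: run right after the repro snippet, reusing
# bert_preprocess, bert_encoder and strategy from it).
def report_sequence_shape(tag):
    text_input = Input(shape=(), dtype=tf.string, name='text')
    seq = bert_encoder(bert_preprocess(text_input))['sequence_output']
    print(tag, seq.shape)  # expected (None, 128, 768) outside the scope

report_sequence_shape('outside scope:')
with strategy.scope():
    report_sequence_shape('inside scope:')  # here it shows up as (None, None, 768)
```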
Model summary withOUT scope (note the 128 in the very last layer, which is the seq_length):
Model: "model_6" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== text (InputLayer) [(None,)] 0 [] keras_layer_2 (KerasLayer) {'input_mask': (Non 0 ['text[0][0]'] e, 128), 'input_word_ids': (None, 128), 'input_type_ids': (None, 128)} keras_layer_3 (KerasLayer) multiple 109482241 ['keras_layer_2[6][0]', 'keras_layer_2[6][1]', 'keras_layer_2[6][2]'] dense_6 (Dense) (None, 128, 2) 1538 ['keras_layer_3[6][14]'] ================================================================================================== Total params: 109,483,779 Trainable params: 1,538 Non-trainable params: 109,482,241 __________________________________________________________________________________________________ Model with scope:
Model: "model_7" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== text (InputLayer) [(None,)] 0 [] keras_layer_2 (KerasLayer) {'input_mask': (Non 0 ['text[0][0]'] e, 128), 'input_word_ids': (None, 128), 'input_type_ids': (None, 128)} keras_layer_3 (KerasLayer) multiple 109482241 ['keras_layer_2[7][0]', 'keras_layer_2[7][1]', 'keras_layer_2[7][2]'] dense_7 (Dense) (None, None, 2) 1538 ['keras_layer_3[7][14]'] ================================================================================================== Total params: 109,483,779 Trainable params: 1,538 Non-trainable params: 109,482,241 __________________________________________________________________________________________________ If these image helps:
Another notable thing: encoder_outputs is also missing if you look at the second KerasLayer (the third layer overall, keras_layer_3) in both models.
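To check that as well, here is a small sketch (again not from the report, reusing the objects from the repro) that prints the keys the encoder layer returns in both cases; 'encoder_outputs' would normally be expected alongside 'pooled_output' and 'sequence_output' for this encoder.

```python
# Quick inspection of the encoder's output dict (assumes bert_preprocess,
# bert_encoder and strategy from the repro snippet above exist).
def print_encoder_output_keys(tag):
    text_input = Input(shape=(), dtype=tf.string, name='text')
    outputs = bert_encoder(bert_preprocess(text_input))
    print(tag, sorted(outputs.keys()))

print_encoder_output_keys('outside scope:')
with strategy.scope():
    print_encoder_output_keys('inside scope:')
```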