I have a Transformer model. When I train with `transformer.fit()`, memory usage is fine, but when I decode messages (translating an input sentence to a target sentence) in a loop like the one below, memory blows up. `tf.keras.backend.clear_session()` doesn't help, and accuracy decreases when I use it; `gc.collect()` doesn't work either (see the sketch after the code below for roughly where I tried them).

Here is my code:
```python
import gc

import numpy as np
import tensorflow as tf
from tqdm import tqdm


def decode_sequence(input_sentence):
    """Greedily decode the target sentence for one input sentence."""
    tokenized_input_sentence = input_vectorization([input_sentence])
    decoded_sentence = START_TOKEN
    for i in tf.range(max_decoded_sentence_length):
        tokenized_target_sentence = output_vectorization([decoded_sentence])  # [:, :-1]
        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
        # Pick the most likely token at position i and append it.
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = output_index_lookup[sampled_token_index]
        decoded_sentence += sampled_token
        if sampled_token == END_TOKEN:
            break
    gc.collect()
    return decoded_sentence


def overall_accuracy(pairs):
    corrects = 0
    inputs = pairs[2739:]
    progress = tqdm(inputs)
    for i, pair in enumerate(progress):
        input_text = pair[0]
        target = pair[1]
        predicted = decode_sequence(input_text)
        # guess = '✓' if predicted == target else '✗'
        # print('Sample Number : ', i, 'Predicted : ', predicted, 'Real : ', target, guess)
        if predicted == target:
            corrects += 1
        progress.set_postfix(corrects=corrects, accuracy=corrects / (i + 1))
    return corrects / len(inputs)


print("Overall Accuracy : ", overall_accuracy(test_pairs))
```
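For reference, this is roughly where I put the cleanup calls when I tried them (a sketch, not my exact code; `test_samples` is just a placeholder name for a list of (input, target) pairs like `test_pairs` above):

```python
import gc

import tensorflow as tf

# Sketch: cleanup attempted between samples (exact placement in my real code may differ).
for input_text, target in test_samples:  # placeholder for a list of (input, target) pairs
    predicted = decode_sequence(input_text)
    tf.keras.backend.clear_session()  # memory still grows, and accuracy drops when this runs
    gc.collect()                      # no effect on the memory growth either
```

Even with both calls after every sample, memory keeps growing across iterations.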