Open
Description
🐛 Bug description
The real-time progress table from `PlotLossesKeras` doesn't show up when training on data fed through Keras' `ImageDataGenerator`.
Here is my code using `ImageDataGenerator` (this does not work):
```python
from keras.preprocessing.image import ImageDataGenerator
from livelossplot import PlotLossesKeras

# `df` has an `image_path` column followed by the label columns;
# `model` is a compiled Keras model defined earlier.
datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    validation_split=0.2,
)

train_gen = datagen.flow_from_dataframe(
    dataframe=df,
    directory=None,
    x_col="image_path",
    y_col=df.columns[1:],  # label columns: everything after image_path
    subset="training",
    batch_size=32,
    shuffle=True,
    class_mode="raw",
    target_size=(224, 224),
)

valid_gen = datagen.flow_from_dataframe(
    dataframe=df,
    directory=None,
    x_col="image_path",
    y_col=df.columns[1:],
    subset="validation",
    batch_size=32,
    shuffle=True,
    class_mode="raw",
    target_size=(224, 224),
)

plotlosses = PlotLossesKeras()

history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=10,
    callbacks=[plotlosses],
    verbose=0,
)
```
And here is the working version, which loads the images into in-memory arrays (`X_train`, `Y_train`, `X_val` and `Y_val`):
```python
import numpy as np
from keras.preprocessing.image import img_to_array, load_img
from livelossplot import PlotLossesKeras
from sklearn.model_selection import train_test_split


def load_image(image_path, target_size=(224, 224)):
    img = load_img(image_path, target_size=target_size)
    img = img_to_array(img)
    img = img / 255.0  # Normalize to [0, 1]
    return img


subset_df = df.sample(n=10, random_state=42)
train_df, val_df = train_test_split(subset_df, test_size=0.2, random_state=42)

# Assumes the first column is image_path and the rest are labels
X_train = np.array([load_image(path) for path in train_df['image_path']])
Y_train = train_df[train_df.columns[1:]].values
X_val = np.array([load_image(path) for path in val_df['image_path']])
Y_val = val_df[val_df.columns[1:]].values

plotlosses = PlotLossesKeras()

history = model.fit(
    X_train,
    Y_train,
    validation_data=(X_val, Y_val),
    epochs=10,
    callbacks=[plotlosses],
    verbose=0,
)
```
(Don't worry about the odd metrics; this was just a demo run on 10 images.)
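Since the array-based version works, one possible stopgap (a sketch, not a fix for the underlying bug) is to materialize the generator's batches into arrays before calling `fit`. The helper `iterator_to_arrays` below is hypothetical; it assumes the iterator supports `len()` and integer indexing and yields `(x, y)` NumPy pairs, which Keras' `DataFrameIterator` should provide as a `keras.utils.Sequence`, and that each subset fits in memory.

```python
import numpy as np


def iterator_to_arrays(batches):
    """Hypothetical helper: concatenate every (x, y) batch of a finite,
    indexable batch iterator into two arrays.

    Indexes batches with range(len(...)) instead of `for batch in batches`,
    because a Keras image iterator loops forever when iterated directly.
    """
    xs, ys = [], []
    for i in range(len(batches)):
        x, y = batches[i]
        xs.append(x)
        ys.append(y)
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


# Assumed usage with the generators from the first snippet:
# X_train, Y_train = iterator_to_arrays(train_gen)
# X_val, Y_val = iterator_to_arrays(valid_gen)
# history = model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
#                     epochs=10, callbacks=[PlotLossesKeras()], verbose=0)
```

This gives up the memory savings of streaming from disk, so it is mainly useful for checking whether the missing table is tied to generator input specifically.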
Environment
- livelossplot version: 0.5.5
- OS: macOS Ventura
- Environment in which the error occurred: Jupyter Notebook
- Python version: Python 3.10.11