Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit f69929f

Browse files
committed
Make sure image iterator classes subclass the tf.keras Sequence class if it is available
1 parent d7225f2 commit f69929f

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

keras_preprocessing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,4 @@ def get_keras_submodule(name):
4040
return _KERAS_UTILS
4141

4242

43-
__version__ = '1.1.0'
43+
__version__ = '1.1.1'

keras_preprocessing/image/dataframe_iterator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ class DataFrameIterator(BatchFromFilesMixin, Iterator):
8989
'binary', 'categorical', 'input', 'multi_output', 'raw', 'sparse', None
9090
}
9191

92+
def __new__(cls, *args, **kwargs):
93+
try:
94+
from tensorflow.keras.utils import Sequence as TFSequence
95+
if TFSequence not in cls.__bases__:
96+
cls.__bases__ = cls.__bases__ + (TFSequence,)
97+
except ImportError:
98+
pass
99+
return super(DataFrameIterator, cls).__new__(cls)
100+
92101
def __init__(self,
93102
dataframe,
94103
directory=None,

keras_preprocessing/image/directory_iterator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ class DirectoryIterator(BatchFromFilesMixin, Iterator):
6464
"""
6565
allowed_class_modes = {'categorical', 'binary', 'sparse', 'input', None}
6666

67+
def __new__(cls, *args, **kwargs):
68+
try:
69+
from tensorflow.keras.utils import Sequence as TFSequence
70+
if TFSequence not in cls.__bases__:
71+
cls.__bases__ = cls.__bases__ + (TFSequence,)
72+
except ImportError:
73+
pass
74+
return super(DirectoryIterator, cls).__new__(cls)
75+
6776
def __init__(self,
6877
directory,
6978
image_data_generator,

keras_preprocessing/image/numpy_array_iterator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ class NumpyArrayIterator(Iterator):
4242
dtype: Dtype to use for the generated arrays.
4343
"""
4444

45+
def __new__(cls, *args, **kwargs):
46+
try:
47+
from tensorflow.keras.utils import Sequence as TFSequence
48+
if TFSequence not in cls.__bases__:
49+
cls.__bases__ = cls.__bases__ + (TFSequence,)
50+
except ImportError:
51+
pass
52+
return super(NumpyArrayIterator, cls).__new__(cls)
53+
4554
def __init__(self,
4655
x,
4756
y,

0 commit comments

Comments
 (0)