Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit 6e57e54

Browse files
committed
New, restricted design that fixes circular dependency issues.
1 parent 732a7ce commit 6e57e54

File tree

7 files changed

+101
-56
lines changed

7 files changed

+101
-56
lines changed

keras_preprocessing/__init__.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,35 @@
44
from __future__ import division
55
from __future__ import print_function
66

7-
_KERAS_MODULE = None
7+
_KERAS_BACKEND = None
8+
_KERAS_UTILS = None
89

910

10-
def keras_module():
11-
global _KERAS_MODULE
12-
if _KERAS_MODULE is None:
13-
# Use `import keras` as default
14-
set_keras_module('keras')
15-
if _KERAS_MODULE == 'tensorflow.keras':
16-
# Due to TF namespace structure,
17-
# can't `__import__('tensorflow.keras')`.
18-
# Use workaround.
19-
tf = __import__('tensorflow')
20-
keras = tf.keras
21-
else:
22-
keras = __import__(_KERAS_MODULE, fromlist=['keras'])
23-
# TODO: check that the Keras version is compatible with
24-
# the current module.
25-
return keras
11+
def set_keras_submodules(backend, utils):
12+
global _KERAS_BACKEND
13+
global _KERAS_UTILS
14+
_KERAS_BACKEND = backend
15+
_KERAS_UTILS = utils
2616

2717

28-
def set_keras_module(module):
29-
global _KERAS_MODULE
30-
_KERAS_MODULE = module
18+
def get_keras_submodule(name):
19+
if name not in {'backend', 'utils'}:
20+
raise ImportError(
21+
'Can only retrieve "backend" and "utils". '
22+
'Requested: %s' % name)
23+
if _KERAS_BACKEND is None or _KERAS_UTILS is None:
24+
raise ImportError('You need to first `import keras` '
25+
'in order to use `keras_preprocessing`. '
26+
'For instance, you can do:\n\n'
27+
'```\n'
28+
'import keras\n'
29+
'from keras_preprocessing import image\n'
30+
'```\n\n'
31+
'Or, preferably, this equivalent formulation:\n\n'
32+
'```\n'
33+
'from keras import preprocessing\n'
34+
'```\n')
35+
if name == 'backend':
36+
return _KERAS_BACKEND
37+
elif name == 'utils':
38+
return _KERAS_UTILS

keras_preprocessing/image.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
import multiprocessing.pool
1616
from functools import partial
1717

18-
from . import keras_module
18+
from . import get_keras_submodule
1919

20-
keras = keras_module()
20+
backend = get_keras_submodule('backend')
21+
keras_utils = get_keras_submodule('utils')
2122

2223
try:
2324
from PIL import ImageEnhance
@@ -354,13 +355,13 @@ def array_to_img(x, data_format=None, scale=True):
354355
if pil_image is None:
355356
raise ImportError('Could not import PIL.Image. '
356357
'The use of `array_to_img` requires PIL.')
357-
x = np.asarray(x, dtype=keras.backend.floatx())
358+
x = np.asarray(x, dtype=backend.floatx())
358359
if x.ndim != 3:
359360
raise ValueError('Expected image array to have rank 3 (single image). '
360361
'Got array with shape:', x.shape)
361362

362363
if data_format is None:
363-
data_format = keras.backend.image_data_format()
364+
data_format = backend.image_data_format()
364365
if data_format not in {'channels_first', 'channels_last'}:
365366
raise ValueError('Invalid data_format:', data_format)
366367

@@ -400,13 +401,13 @@ def img_to_array(img, data_format=None):
400401
ValueError: if invalid `img` or `data_format` is passed.
401402
"""
402403
if data_format is None:
403-
data_format = keras.backend.image_data_format()
404+
data_format = backend.image_data_format()
404405
if data_format not in {'channels_first', 'channels_last'}:
405406
raise ValueError('Unknown data_format: ', data_format)
406407
# Numpy array x has format (height, width, channel)
407408
# or (channel, height, width)
408409
# but original PIL image has format (width, height, channel)
409-
x = np.asarray(img, dtype=keras.backend.floatx())
410+
x = np.asarray(img, dtype=backend.floatx())
410411
if len(x.shape) == 3:
411412
if data_format == 'channels_first':
412413
x = x.transpose(2, 0, 1)
@@ -697,7 +698,7 @@ def __init__(self,
697698
data_format=None,
698699
validation_split=0.0):
699700
if data_format is None:
700-
data_format = keras.backend.image_data_format()
701+
data_format = backend.image_data_format()
701702
self.featurewise_center = featurewise_center
702703
self.samplewise_center = samplewise_center
703704
self.featurewise_std_normalization = featurewise_std_normalization
@@ -949,7 +950,7 @@ def standardize(self, x):
949950
if self.samplewise_center:
950951
x -= np.mean(x, keepdims=True)
951952
if self.samplewise_std_normalization:
952-
x /= (np.std(x, keepdims=True) + keras.backend.epsilon())
953+
x /= (np.std(x, keepdims=True) + backend.epsilon())
953954

954955
if self.featurewise_center:
955956
if self.mean is not None:
@@ -961,7 +962,7 @@ def standardize(self, x):
961962
'first by calling `.fit(numpy_data)`.')
962963
if self.featurewise_std_normalization:
963964
if self.std is not None:
964-
x /= (self.std + keras.backend.epsilon())
965+
x /= (self.std + backend.epsilon())
965966
else:
966967
warnings.warn('This ImageDataGenerator specifies '
967968
'`featurewise_std_normalization`, '
@@ -1165,7 +1166,7 @@ def fit(self, x,
11651166
this is how many augmentation passes over the data to use.
11661167
seed: Int (default: None). Random seed.
11671168
"""
1168-
x = np.asarray(x, dtype=keras.backend.floatx())
1169+
x = np.asarray(x, dtype=backend.floatx())
11691170
if x.ndim != 4:
11701171
raise ValueError('Input to `.fit()` should have rank 4. '
11711172
'Got array with shape: ' + str(x.shape))
@@ -1188,7 +1189,7 @@ def fit(self, x,
11881189
if augment:
11891190
ax = np.zeros(
11901191
tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
1191-
dtype=keras.backend.floatx())
1192+
dtype=backend.floatx())
11921193
for r in range(rounds):
11931194
for i in range(x.shape[0]):
11941195
ax[i + r * x.shape[0]] = self.random_transform(x[i])
@@ -1206,7 +1207,7 @@ def fit(self, x,
12061207
broadcast_shape = [1, 1, 1]
12071208
broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
12081209
self.std = np.reshape(self.std, broadcast_shape)
1209-
x /= (self.std + keras.backend.epsilon())
1210+
x /= (self.std + backend.epsilon())
12101211

12111212
if self.zca_whitening:
12121213
flat_x = np.reshape(
@@ -1217,7 +1218,7 @@ def fit(self, x,
12171218
self.principal_components = (u * s_inv).dot(u.T)
12181219

12191220

1220-
class Iterator(keras.utils.Sequence):
1221+
class Iterator(keras_utils.Sequence):
12211222
"""Base class for image data iterators.
12221223
12231224
Every `Iterator` must implement the `_get_batches_of_transformed_samples`
@@ -1384,8 +1385,8 @@ def __init__(self, x, y, image_data_generator,
13841385
if y is not None:
13851386
y = y[split_idx:]
13861387
if data_format is None:
1387-
data_format = keras.backend.image_data_format()
1388-
self.x = np.asarray(x, dtype=keras.backend.floatx())
1388+
data_format = backend.image_data_format()
1389+
self.x = np.asarray(x, dtype=backend.floatx())
13891390
self.x_misc = x_misc
13901391
if self.x.ndim != 4:
13911392
raise ValueError('Input data in `NumpyArrayIterator` '
@@ -1421,12 +1422,12 @@ def __init__(self, x, y, image_data_generator,
14211422

14221423
def _get_batches_of_transformed_samples(self, index_array):
14231424
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
1424-
dtype=keras.backend.floatx())
1425+
dtype=backend.floatx())
14251426
for i, j in enumerate(index_array):
14261427
x = self.x[j]
14271428
params = self.image_data_generator.get_random_transform(x.shape)
14281429
x = self.image_data_generator.apply_transform(
1429-
x.astype(keras.backend.floatx()), params)
1430+
x.astype(backend.floatx()), params)
14301431
x = self.image_data_generator.standardize(x)
14311432
batch_x[i] = x
14321433

@@ -1625,7 +1626,7 @@ def __init__(self, directory, image_data_generator,
16251626
subset=None,
16261627
interpolation='nearest'):
16271628
if data_format is None:
1628-
data_format = keras.backend.image_data_format()
1629+
data_format = backend.image_data_format()
16291630
self.directory = directory
16301631
self.image_data_generator = image_data_generator
16311632
self.target_size = tuple(target_size)
@@ -1722,7 +1723,7 @@ def __init__(self, directory, image_data_generator,
17221723
def _get_batches_of_transformed_samples(self, index_array):
17231724
batch_x = np.zeros(
17241725
(len(index_array),) + self.image_shape,
1725-
dtype=keras.backend.floatx())
1726+
dtype=backend.floatx())
17261727
grayscale = self.color_mode == 'grayscale'
17271728
# build batch of image data
17281729
for i, j in enumerate(index_array):
@@ -1752,11 +1753,11 @@ def _get_batches_of_transformed_samples(self, index_array):
17521753
elif self.class_mode == 'sparse':
17531754
batch_y = self.classes[index_array]
17541755
elif self.class_mode == 'binary':
1755-
batch_y = self.classes[index_array].astype(keras.backend.floatx())
1756+
batch_y = self.classes[index_array].astype(backend.floatx())
17561757
elif self.class_mode == 'categorical':
17571758
batch_y = np.zeros(
17581759
(len(batch_x), self.num_classes),
1759-
dtype=keras.backend.floatx())
1760+
dtype=backend.floatx())
17601761
for i, label in enumerate(self.classes[index_array]):
17611762
batch_y[i, label] = 1.
17621763
else:

keras_preprocessing/sequence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import random
1010
from six.moves import range
1111

12-
from . import keras_module
12+
from . import get_keras_submodule
1313

14-
keras = keras_module()
14+
keras_utils = get_keras_submodule('utils')
1515

1616

1717
def pad_sequences(sequences, maxlen=None, dtype='int32',
@@ -250,7 +250,7 @@ def _remove_long_seq(maxlen, seq, label):
250250
return new_seq, new_label
251251

252252

253-
class TimeseriesGenerator(keras.utils.Sequence):
253+
class TimeseriesGenerator(keras_utils.Sequence):
254254
"""Utility class for generating batches of temporal data.
255255
256256
This class takes in a sequence of data-points gathered at

tests/image_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
import tempfile
66
import shutil
77

8+
import keras
9+
10+
# TODO: remove the 3 lines below once the Keras release
11+
# is configured to use keras_preprocessing
12+
import keras_preprocessing
13+
keras_preprocessing.set_keras_submodules(
14+
backend=keras.backend, utils=keras.utils)
15+
16+
# This enables this import
817
from keras_preprocessing import image
918

1019

tests/integration_test.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,31 @@
55
from importlib import reload
66

77

8-
def test_dynamic_backend_setting():
8+
def test_that_internal_imports_are_not_overriden():
9+
# Test that changing the keras module after importing
10+
# Keras does not override keras.preprocessing's keras module
911
import keras_preprocessing
1012
reload(keras_preprocessing)
11-
assert keras_preprocessing._KERAS_MODULE is None
12-
from tensorflow import keras as keras_ref
13-
keras_preprocessing.set_keras_module('tensorflow.keras')
14-
from keras_preprocessing import image
15-
assert image.keras_module() is keras_ref
16-
17-
import keras as keras_ref
18-
keras_preprocessing.set_keras_module('keras')
19-
assert image.keras_module() is keras_ref
20-
reload(image)
21-
assert image.keras_module() is keras_ref
13+
assert keras_preprocessing._KERAS_BACKEND is None
14+
15+
import keras
16+
if not hasattr(keras.preprocessing.image, 'image'):
17+
return # Old Keras, don't run.
18+
19+
import tensorflow as tf
20+
keras_preprocessing.set_keras_submodules(backend=tf.keras.backend,
21+
utils=tf.keras.utils)
22+
assert keras.preprocessing.image.image.backend is keras.backend
23+
24+
# Now test the reverse order
25+
del keras
26+
reload(keras_preprocessing)
27+
assert keras_preprocessing._KERAS_BACKEND is None
28+
29+
keras_preprocessing.set_keras_submodules(backend=tf.keras.backend,
30+
utils=tf.keras.utils)
31+
import keras
32+
assert keras.preprocessing.image.image.backend is keras.backend
2233

2334

2435
if __name__ == '__main__':

tests/sequence_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from numpy.testing import assert_allclose
55
from numpy.testing import assert_raises
66

7+
import keras
8+
9+
# TODO: remove the 3 lines below once the Keras release
10+
# is configured to use keras_preprocessing
11+
import keras_preprocessing
12+
keras_preprocessing.set_keras_submodules(
13+
backend=keras.backend, utils=keras.utils)
14+
715
from keras_preprocessing import sequence
816

917

tests/text_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
import numpy as np
33
import pytest
44

5+
import keras
6+
7+
# TODO: remove the 3 lines below once the Keras release
8+
# is configured to use keras_preprocessing
9+
import keras_preprocessing
10+
keras_preprocessing.set_keras_submodules(
11+
backend=keras.backend, utils=keras.utils)
12+
513
from keras_preprocessing import text
614

715

0 commit comments

Comments
 (0)