1515import multiprocessing .pool
1616from functools import partial
1717
18- from . import keras_module
18+ from . import get_keras_submodule
1919
20- keras = keras_module ()
20+ backend = get_keras_submodule ('backend' )
21+ keras_utils = get_keras_submodule ('utils' )
2122
2223try :
2324 from PIL import ImageEnhance
@@ -354,13 +355,13 @@ def array_to_img(x, data_format=None, scale=True):
354355 if pil_image is None :
355356 raise ImportError ('Could not import PIL.Image. '
356357 'The use of `array_to_img` requires PIL.' )
357- x = np .asarray (x , dtype = keras . backend .floatx ())
358+ x = np .asarray (x , dtype = backend .floatx ())
358359 if x .ndim != 3 :
359360 raise ValueError ('Expected image array to have rank 3 (single image). '
360361 'Got array with shape:' , x .shape )
361362
362363 if data_format is None :
363- data_format = keras . backend .image_data_format ()
364+ data_format = backend .image_data_format ()
364365 if data_format not in {'channels_first' , 'channels_last' }:
365366 raise ValueError ('Invalid data_format:' , data_format )
366367
@@ -400,13 +401,13 @@ def img_to_array(img, data_format=None):
400401 ValueError: if invalid `img` or `data_format` is passed.
401402 """
402403 if data_format is None :
403- data_format = keras . backend .image_data_format ()
404+ data_format = backend .image_data_format ()
404405 if data_format not in {'channels_first' , 'channels_last' }:
405406 raise ValueError ('Unknown data_format: ' , data_format )
406407 # Numpy array x has format (height, width, channel)
407408 # or (channel, height, width)
408409 # but original PIL image has format (width, height, channel)
409- x = np .asarray (img , dtype = keras . backend .floatx ())
410+ x = np .asarray (img , dtype = backend .floatx ())
410411 if len (x .shape ) == 3 :
411412 if data_format == 'channels_first' :
412413 x = x .transpose (2 , 0 , 1 )
@@ -697,7 +698,7 @@ def __init__(self,
697698 data_format = None ,
698699 validation_split = 0.0 ):
699700 if data_format is None :
700- data_format = keras . backend .image_data_format ()
701+ data_format = backend .image_data_format ()
701702 self .featurewise_center = featurewise_center
702703 self .samplewise_center = samplewise_center
703704 self .featurewise_std_normalization = featurewise_std_normalization
@@ -949,7 +950,7 @@ def standardize(self, x):
949950 if self .samplewise_center :
950951 x -= np .mean (x , keepdims = True )
951952 if self .samplewise_std_normalization :
952- x /= (np .std (x , keepdims = True ) + keras . backend .epsilon ())
953+ x /= (np .std (x , keepdims = True ) + backend .epsilon ())
953954
954955 if self .featurewise_center :
955956 if self .mean is not None :
@@ -961,7 +962,7 @@ def standardize(self, x):
961962 'first by calling `.fit(numpy_data)`.' )
962963 if self .featurewise_std_normalization :
963964 if self .std is not None :
964- x /= (self .std + keras . backend .epsilon ())
965+ x /= (self .std + backend .epsilon ())
965966 else :
966967 warnings .warn ('This ImageDataGenerator specifies '
967968 '`featurewise_std_normalization`, '
@@ -1165,7 +1166,7 @@ def fit(self, x,
11651166 this is how many augmentation passes over the data to use.
11661167 seed: Int (default: None). Random seed.
11671168 """
1168- x = np .asarray (x , dtype = keras . backend .floatx ())
1169+ x = np .asarray (x , dtype = backend .floatx ())
11691170 if x .ndim != 4 :
11701171 raise ValueError ('Input to `.fit()` should have rank 4. '
11711172 'Got array with shape: ' + str (x .shape ))
@@ -1188,7 +1189,7 @@ def fit(self, x,
11881189 if augment :
11891190 ax = np .zeros (
11901191 tuple ([rounds * x .shape [0 ]] + list (x .shape )[1 :]),
1191- dtype = keras . backend .floatx ())
1192+ dtype = backend .floatx ())
11921193 for r in range (rounds ):
11931194 for i in range (x .shape [0 ]):
11941195 ax [i + r * x .shape [0 ]] = self .random_transform (x [i ])
@@ -1206,7 +1207,7 @@ def fit(self, x,
12061207 broadcast_shape = [1 , 1 , 1 ]
12071208 broadcast_shape [self .channel_axis - 1 ] = x .shape [self .channel_axis ]
12081209 self .std = np .reshape (self .std , broadcast_shape )
1209- x /= (self .std + keras . backend .epsilon ())
1210+ x /= (self .std + backend .epsilon ())
12101211
12111212 if self .zca_whitening :
12121213 flat_x = np .reshape (
@@ -1217,7 +1218,7 @@ def fit(self, x,
12171218 self .principal_components = (u * s_inv ).dot (u .T )
12181219
12191220
1220- class Iterator (keras . utils .Sequence ):
1221+ class Iterator (keras_utils .Sequence ):
12211222 """Base class for image data iterators.
12221223
12231224 Every `Iterator` must implement the `_get_batches_of_transformed_samples`
@@ -1384,8 +1385,8 @@ def __init__(self, x, y, image_data_generator,
13841385 if y is not None :
13851386 y = y [split_idx :]
13861387 if data_format is None :
1387- data_format = keras . backend .image_data_format ()
1388- self .x = np .asarray (x , dtype = keras . backend .floatx ())
1388+ data_format = backend .image_data_format ()
1389+ self .x = np .asarray (x , dtype = backend .floatx ())
13891390 self .x_misc = x_misc
13901391 if self .x .ndim != 4 :
13911392 raise ValueError ('Input data in `NumpyArrayIterator` '
@@ -1421,12 +1422,12 @@ def __init__(self, x, y, image_data_generator,
14211422
14221423 def _get_batches_of_transformed_samples (self , index_array ):
14231424 batch_x = np .zeros (tuple ([len (index_array )] + list (self .x .shape )[1 :]),
1424- dtype = keras . backend .floatx ())
1425+ dtype = backend .floatx ())
14251426 for i , j in enumerate (index_array ):
14261427 x = self .x [j ]
14271428 params = self .image_data_generator .get_random_transform (x .shape )
14281429 x = self .image_data_generator .apply_transform (
1429- x .astype (keras . backend .floatx ()), params )
1430+ x .astype (backend .floatx ()), params )
14301431 x = self .image_data_generator .standardize (x )
14311432 batch_x [i ] = x
14321433
@@ -1625,7 +1626,7 @@ def __init__(self, directory, image_data_generator,
16251626 subset = None ,
16261627 interpolation = 'nearest' ):
16271628 if data_format is None :
1628- data_format = keras . backend .image_data_format ()
1629+ data_format = backend .image_data_format ()
16291630 self .directory = directory
16301631 self .image_data_generator = image_data_generator
16311632 self .target_size = tuple (target_size )
@@ -1722,7 +1723,7 @@ def __init__(self, directory, image_data_generator,
17221723 def _get_batches_of_transformed_samples (self , index_array ):
17231724 batch_x = np .zeros (
17241725 (len (index_array ),) + self .image_shape ,
1725- dtype = keras . backend .floatx ())
1726+ dtype = backend .floatx ())
17261727 grayscale = self .color_mode == 'grayscale'
17271728 # build batch of image data
17281729 for i , j in enumerate (index_array ):
@@ -1752,11 +1753,11 @@ def _get_batches_of_transformed_samples(self, index_array):
17521753 elif self .class_mode == 'sparse' :
17531754 batch_y = self .classes [index_array ]
17541755 elif self .class_mode == 'binary' :
1755- batch_y = self .classes [index_array ].astype (keras . backend .floatx ())
1756+ batch_y = self .classes [index_array ].astype (backend .floatx ())
17561757 elif self .class_mode == 'categorical' :
17571758 batch_y = np .zeros (
17581759 (len (batch_x ), self .num_classes ),
1759- dtype = keras . backend .floatx ())
1760+ dtype = backend .floatx ())
17601761 for i , label in enumerate (self .classes [index_array ]):
17611762 batch_y [i , label ] = 1.
17621763 else :
0 commit comments