Skip to content

Commit 18e0364

Browse files
Support for extracting volume patches (#21759)
* support for volume patching * correct `call` * fix pydoc * fix `test_extract_volume_patches_basic` casting residual * fix `test_extract_volume_patches_same_padding` casting * fix `test_extract_volume_patches_overlapping` casting * add extra testing * Update keras/src/ops/image.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix dimension ordering * Update keras/src/ops/image.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix dimensional ordering + add test * fix docstring * docstring corrections * add validation checks for operation * comment * set default data format in tests, add a few parametrized tests to validate channels_first * add back test for channels first/last in dilation * fix dilation test * Update keras/src/ops/image.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * delete redundant prop set * Update keras/src/ops/image_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update keras/src/ops/image_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * rename per francois feedback * fix pydoc to reflect method * fix name of operation + test --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent dc5e42c commit 18e0364

File tree

4 files changed

+402
-0
lines changed

4 files changed

+402
-0
lines changed

keras/api/_tf_keras/keras/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.ops.image import crop_images as crop_images
99
from keras.src.ops.image import elastic_transform as elastic_transform
1010
from keras.src.ops.image import extract_patches as extract_patches
11+
from keras.src.ops.image import extract_patches_3d as extract_patches_3d
1112
from keras.src.ops.image import gaussian_blur as gaussian_blur
1213
from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
1314
from keras.src.ops.image import map_coordinates as map_coordinates

keras/api/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.ops.image import crop_images as crop_images
99
from keras.src.ops.image import elastic_transform as elastic_transform
1010
from keras.src.ops.image import extract_patches as extract_patches
11+
from keras.src.ops.image import extract_patches_3d as extract_patches_3d
1112
from keras.src.ops.image import gaussian_blur as gaussian_blur
1213
from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
1314
from keras.src.ops.image import map_coordinates as map_coordinates

keras/src/ops/image.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,187 @@ def _extract_patches(
712712
return patches
713713

714714

715+
class ExtractPatches3D(Operation):
    """Operation wrapper around `_extract_patches_3d`.

    The constructor normalizes `size` and `strides` to 3-element
    sequences (`strides` falls back to `size` when it is `None`) and
    standardizes `data_format`. `call` forwards the stored configuration
    to `_extract_patches_3d`, and `compute_output_spec` derives the
    symbolic output shape with `compute_conv_output_shape`.
    """

    def __init__(
        self,
        size,
        strides=None,
        dilation_rate=1,
        padding="valid",
        data_format=None,
        *,
        name=None,
    ):
        super().__init__(name=name)

        # A scalar `size` is broadcast to all three patch dimensions.
        if isinstance(size, int):
            size = (size,) * 3
        elif len(size) != 3:
            raise TypeError(
                "Invalid `size` argument. Expected an "
                f"int or a tuple of length 3. Received: size={size}"
            )
        self.size = size

        # Missing strides mean "step by one full patch".
        if strides is None:
            strides = size
        elif isinstance(strides, int):
            strides = (strides,) * 3
        elif len(strides) != 3:
            raise ValueError(f"Invalid `strides` argument. Got: {strides}")
        self.strides = strides

        self.dilation_rate = dilation_rate
        self.padding = padding
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, volumes):
        """Extract patches from `volumes` using the stored configuration."""
        return _extract_patches_3d(
            volumes,
            self.size,
            self.strides,
            self.dilation_rate,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, volumes):
        """Return a `KerasTensor` describing the patches for `volumes`.

        A 4-dimensional input is given a temporary leading dimension of 1
        for the shape computation, which is stripped again from the
        result.
        """
        input_shape = list(volumes.shape)
        is_unbatched = len(input_shape) == 4

        # The channel axis is read before any batch dimension is added;
        # the negative index addresses it for both 4D and 5D inputs.
        channel_axis = -1 if self.data_format == "channels_last" else -4
        channels_in = input_shape[channel_axis]

        if is_unbatched:
            input_shape = [1, *input_shape]

        patch_dims = (self.size[0], self.size[1], self.size[2])
        # One output channel per (patch position, input channel) pair.
        num_filters = (
            patch_dims[0] * patch_dims[1] * patch_dims[2] * channels_in
        )
        output_shape = compute_conv_output_shape(
            input_shape,
            num_filters,
            patch_dims,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if is_unbatched:
            output_shape = output_shape[1:]
        return KerasTensor(shape=output_shape, dtype=volumes.dtype)
781+
782+
783+
def _extract_patches_3d(
784+
volumes,
785+
size,
786+
strides=None,
787+
dilation_rate=1,
788+
padding="valid",
789+
data_format=None,
790+
):
791+
if isinstance(size, int):
792+
patch_d = patch_h = patch_w = size
793+
elif len(size) == 3:
794+
patch_d, patch_h, patch_w = size
795+
else:
796+
raise TypeError(
797+
"Invalid `size` argument. Expected an "
798+
f"int or a tuple of length 3. Received: size={size}"
799+
)
800+
if strides is None:
801+
strides = size
802+
if isinstance(strides, int):
803+
strides = (strides, strides, strides)
804+
if len(strides) != 3:
805+
raise ValueError(f"Invalid `strides` argument. Got: {strides}")
806+
data_format = backend.standardize_data_format(data_format)
807+
if data_format == "channels_last":
808+
channels_in = volumes.shape[-1]
809+
elif data_format == "channels_first":
810+
channels_in = volumes.shape[-4]
811+
out_dim = patch_d * patch_w * patch_h * channels_in
812+
kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
813+
kernel = backend.numpy.reshape(
814+
kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
815+
)
816+
_unbatched = False
817+
if len(volumes.shape) == 4:
818+
_unbatched = True
819+
volumes = backend.numpy.expand_dims(volumes, axis=0)
820+
patches = backend.nn.conv(
821+
inputs=volumes,
822+
kernel=kernel,
823+
strides=strides,
824+
padding=padding,
825+
data_format=data_format,
826+
dilation_rate=dilation_rate,
827+
)
828+
if _unbatched:
829+
patches = backend.numpy.squeeze(patches, axis=0)
830+
return patches
831+
832+
833+
@keras_export("keras.ops.image.extract_patches_3d")
def extract_patches_3d(
    volumes,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extracts patches from the volume(s).

    Args:
        volumes: Input volume or batch of volumes. Must be 4D or 5D.
        size: Patch size, given as an int or as a tuple
            `(patch_depth, patch_height, patch_width)`.
        strides: Strides along depth, height, and width. If not
            specified, or if `None`, it defaults to the same value as
            `size`.
        dilation_rate: The input stride, i.e. how far apart two
            consecutive patch samples are in the input. Note that using
            `dilation_rate > 1` is not supported in conjunction with
            `strides > 1` on the TensorFlow backend.
        padding: The type of padding algorithm to use: `"same"` or
            `"valid"`.
        data_format: A string specifying the data format of the input
            tensor, either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, depth, height, width, channels)`, while
            `"channels_first"` corresponds to inputs with shape
            `(batch, channels, depth, height, width)`. If not specified,
            the value will default to `keras.config.image_data_format()`.

    Returns:
        Extracted patches: 4D if the input is unbatched, 5D if it is
        batched.

    Examples:

    >>> import numpy as np
    >>> import keras
    >>> # Batched case: a batch of 2 volumes
    >>> volumes = np.random.random(
    ...     (2, 10, 10, 10, 3)
    ... ).astype("float32")
    >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
    >>> patches.shape
    (2, 3, 3, 3, 81)
    >>> # Unbatched case: a single volume
    >>> volume = np.random.random(
    ...     (10, 10, 10, 3)
    ... ).astype("float32")
    >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
    >>> patches.shape
    (3, 3, 3, 81)
    """
    # Concrete inputs are computed directly by the private helper.
    if not any_symbolic_tensors((volumes,)):
        return _extract_patches_3d(
            volumes,
            size,
            strides,
            dilation_rate,
            padding,
            data_format=data_format,
        )

    # Symbolic inputs go through the Operation wrapper.
    op = ExtractPatches3D(
        size=size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
    )
    return op.symbolic_call(volumes)
894+
895+
715896
class MapCoordinates(Operation):
716897
def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
717898
super().__init__(name=name)

0 commit comments

Comments
 (0)