@@ -712,6 +712,187 @@ def _extract_patches(
712712 return patches
713713
714714
715+ class ExtractPatches3D (Operation ):
716+ def __init__ (
717+ self ,
718+ size ,
719+ strides = None ,
720+ dilation_rate = 1 ,
721+ padding = "valid" ,
722+ data_format = None ,
723+ * ,
724+ name = None ,
725+ ):
726+ super ().__init__ (name = name )
727+ if isinstance (size , int ):
728+ size = (size , size , size )
729+ elif len (size ) != 3 :
730+ raise TypeError (
731+ "Invalid `size` argument. Expected an "
732+ f"int or a tuple of length 3. Received: size={ size } "
733+ )
734+ self .size = size
735+ if strides is not None :
736+ if isinstance (strides , int ):
737+ strides = (strides , strides , strides )
738+ elif len (strides ) != 3 :
739+ raise ValueError (f"Invalid `strides` argument. Got: { strides } " )
740+ else :
741+ strides = size
742+ self .strides = strides
743+ self .dilation_rate = dilation_rate
744+ self .padding = padding
745+ self .data_format = backend .standardize_data_format (data_format )
746+
747+ def call (self , volumes ):
748+ return _extract_patches_3d (
749+ volumes ,
750+ self .size ,
751+ self .strides ,
752+ self .dilation_rate ,
753+ self .padding ,
754+ self .data_format ,
755+ )
756+
757+ def compute_output_spec (self , volumes ):
758+ volumes_shape = list (volumes .shape )
759+ original_ndim = len (volumes_shape )
760+ strides = self .strides
761+ if self .data_format == "channels_last" :
762+ channels_in = volumes_shape [- 1 ]
763+ else :
764+ channels_in = volumes_shape [- 4 ]
765+ if original_ndim == 4 :
766+ volumes_shape = [1 ] + volumes_shape
767+ filters = self .size [0 ] * self .size [1 ] * self .size [2 ] * channels_in
768+ kernel_size = (self .size [0 ], self .size [1 ], self .size [2 ])
769+ out_shape = compute_conv_output_shape (
770+ volumes_shape ,
771+ filters ,
772+ kernel_size ,
773+ strides = strides ,
774+ padding = self .padding ,
775+ data_format = self .data_format ,
776+ dilation_rate = self .dilation_rate ,
777+ )
778+ if original_ndim == 4 :
779+ out_shape = out_shape [1 :]
780+ return KerasTensor (shape = out_shape , dtype = volumes .dtype )
781+
782+
783+ def _extract_patches_3d (
784+ volumes ,
785+ size ,
786+ strides = None ,
787+ dilation_rate = 1 ,
788+ padding = "valid" ,
789+ data_format = None ,
790+ ):
791+ if isinstance (size , int ):
792+ patch_d = patch_h = patch_w = size
793+ elif len (size ) == 3 :
794+ patch_d , patch_h , patch_w = size
795+ else :
796+ raise TypeError (
797+ "Invalid `size` argument. Expected an "
798+ f"int or a tuple of length 3. Received: size={ size } "
799+ )
800+ if strides is None :
801+ strides = size
802+ if isinstance (strides , int ):
803+ strides = (strides , strides , strides )
804+ if len (strides ) != 3 :
805+ raise ValueError (f"Invalid `strides` argument. Got: { strides } " )
806+ data_format = backend .standardize_data_format (data_format )
807+ if data_format == "channels_last" :
808+ channels_in = volumes .shape [- 1 ]
809+ elif data_format == "channels_first" :
810+ channels_in = volumes .shape [- 4 ]
811+ out_dim = patch_d * patch_w * patch_h * channels_in
812+ kernel = backend .numpy .eye (out_dim , dtype = volumes .dtype )
813+ kernel = backend .numpy .reshape (
814+ kernel , (patch_d , patch_h , patch_w , channels_in , out_dim )
815+ )
816+ _unbatched = False
817+ if len (volumes .shape ) == 4 :
818+ _unbatched = True
819+ volumes = backend .numpy .expand_dims (volumes , axis = 0 )
820+ patches = backend .nn .conv (
821+ inputs = volumes ,
822+ kernel = kernel ,
823+ strides = strides ,
824+ padding = padding ,
825+ data_format = data_format ,
826+ dilation_rate = dilation_rate ,
827+ )
828+ if _unbatched :
829+ patches = backend .numpy .squeeze (patches , axis = 0 )
830+ return patches
831+
832+
833+ @keras_export ("keras.ops.image.extract_patches_3d" )
834+ def extract_patches_3d (
835+ volumes ,
836+ size ,
837+ strides = None ,
838+ dilation_rate = 1 ,
839+ padding = "valid" ,
840+ data_format = None ,
841+ ):
842+ """Extracts patches from the volume(s).
843+
844+ Args:
845+ volumes: Input volume or batch of volumes. Must be 4D or 5D.
846+ size: Patch size int or tuple (patch_depth, patch_height, patch_width)
847+ strides: strides along depth, height, and width. If not specified, or
848+ if `None`, it defaults to the same value as `size`.
849+ dilation_rate: This is the input stride, specifying how far two
850+ consecutive patch samples are in the input. Note that using
851+ `dilation_rate > 1` is not supported in conjunction with
852+ `strides > 1` on the TensorFlow backend.
853+ padding: The type of padding algorithm to use: `"same"` or `"valid"`.
854+ data_format: A string specifying the data format of the input tensor.
855+ It can be either `"channels_last"` or `"channels_first"`.
856+ `"channels_last"` corresponds to inputs with shape
857+ `(batch, depth, height, width, channels)`, while `"channels_first"`
858+ corresponds to inputs with shape
859+ `(batch, channels, depth, height, width)`. If not specified,
860+ the value will default to `keras.config.image_data_format()`.
861+
862+ Returns:
863+ Extracted patches 4D (if not batched) or 5D (if batched)
864+
865+ Examples:
866+
867+ >>> import numpy as np
868+ >>> import keras
869+ >>> # Batched case
870+ >>> volumes = np.random.random(
871+ ... (2, 10, 10, 10, 3)
872+ ... ).astype("float32") # batch of 2 volumes
873+ >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
874+ >>> patches.shape
875+ (2, 3, 3, 3, 81)
876+ >>> # Unbatched case
877+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
878+ >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
879+ >>> patches.shape
880+ (3, 3, 3, 81)
881+ """
882+ if any_symbolic_tensors ((volumes ,)):
883+ return ExtractPatches3D (
884+ size = size ,
885+ strides = strides ,
886+ dilation_rate = dilation_rate ,
887+ padding = padding ,
888+ data_format = data_format ,
889+ ).symbolic_call (volumes )
890+
891+ return _extract_patches_3d (
892+ volumes , size , strides , dilation_rate , padding , data_format = data_format
893+ )
894+
895+
715896class MapCoordinates (Operation ):
716897 def __init__ (self , order , fill_mode = "constant" , fill_value = 0 , * , name = None ):
717898 super ().__init__ (name = name )
0 commit comments