2424from pandas ._libs import NaT , iNaT , lib
2525import pandas ._libs .groupby as libgroupby
2626import pandas ._libs .reduction as libreduction
27- from pandas ._typing import F , FrameOrSeries , Label , Shape
27+ from pandas ._typing import ArrayLike , F , FrameOrSeries , Label , Shape
2828from pandas .errors import AbstractMethodError
2929from pandas .util ._decorators import cache_readonly
3030
@@ -445,6 +445,68 @@ def _get_cython_func_and_vals(
445445 raise
446446 return func , values
447447
448+ def _disallow_invalid_ops (self , values : ArrayLike , how : str ):
449+ """
450+ Check if we can do this operation with our cython functions.
451+
452+ Raises
453+ ------
454+ NotImplementedError
455+ This is either not a valid function for this dtype, or
456+ valid but not implemented in cython.
457+ """
458+ dtype = values .dtype
459+
460+ if is_categorical_dtype (dtype ) or is_sparse (dtype ):
461+ # categoricals are only 1d, so we
462+ # are not setup for dim transforming
463+ raise NotImplementedError (f"{ dtype } dtype not supported" )
464+ elif is_datetime64_any_dtype (dtype ):
465+ # we raise NotImplemented if this is an invalid operation
466+ # entirely, e.g. adding datetimes
467+ if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
468+ raise NotImplementedError (
469+ f"datetime64 type does not support { how } operations"
470+ )
471+ elif is_timedelta64_dtype (dtype ):
472+ if how in ["prod" , "cumprod" ]:
473+ raise NotImplementedError (
474+ f"timedelta64 type does not support { how } operations"
475+ )
476+
477+ def _ea_wrap_cython_operation (
478+ self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
479+ ) -> Tuple [np .ndarray , Optional [List [str ]]]:
480+ """
481+ If we have an ExtensionArray, unwrap, call _cython_operation, and
482+ re-wrap if appropriate.
483+ """
484+ # TODO: general case implementation overrideable by EAs.
485+ orig_values = values
486+
487+ if is_datetime64tz_dtype (values .dtype ) or is_period_dtype (values .dtype ):
488+ # All of the functions implemented here are ordinal, so we can
489+ # operate on the tz-naive equivalents
490+ values = values .view ("M8[ns]" )
491+ res_values , names = self ._cython_operation (
492+ kind , values , how , axis , min_count , ** kwargs
493+ )
494+ res_values = res_values .astype ("i8" , copy = False )
495+ # FIXME: this is wrong for rank, but not tested.
496+ result = type (orig_values )._simple_new (res_values , dtype = orig_values .dtype )
497+ return result , names
498+
499+ elif is_integer_dtype (values .dtype ) or is_bool_dtype (values .dtype ):
500+ # IntegerArray or BooleanArray
501+ values = ensure_int_or_float (values )
502+ res_values , names = self ._cython_operation (
503+ kind , values , how , axis , min_count , ** kwargs
504+ )
505+ result = maybe_cast_result (result = res_values , obj = orig_values , how = how )
506+ return result , names
507+
508+ raise NotImplementedError (values .dtype )
509+
448510 def _cython_operation (
449511 self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
450512 ) -> Tuple [np .ndarray , Optional [List [str ]]]:
@@ -454,8 +516,8 @@ def _cython_operation(
454516 Names is only useful when dealing with 2D results, like ohlc
455517 (see self._name_functions).
456518 """
457- assert kind in ["transform" , "aggregate" ]
458519 orig_values = values
520+ assert kind in ["transform" , "aggregate" ]
459521
460522 if values .ndim > 2 :
461523 raise NotImplementedError ("number of dimensions is currently limited to 2" )
@@ -466,30 +528,12 @@ def _cython_operation(
466528
467529 # can we do this operation with our cython functions
468530 # if not raise NotImplementedError
531+ self ._disallow_invalid_ops (values , how )
469532
470- # we raise NotImplemented if this is an invalid operation
471- # entirely, e.g. adding datetimes
472-
473- # categoricals are only 1d, so we
474- # are not setup for dim transforming
475- if is_categorical_dtype (values .dtype ) or is_sparse (values .dtype ):
476- raise NotImplementedError (f"{ values .dtype } dtype not supported" )
477- elif is_datetime64_any_dtype (values .dtype ):
478- if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
479- raise NotImplementedError (
480- f"datetime64 type does not support { how } operations"
481- )
482- elif is_timedelta64_dtype (values .dtype ):
483- if how in ["prod" , "cumprod" ]:
484- raise NotImplementedError (
485- f"timedelta64 type does not support { how } operations"
486- )
487-
488- if is_datetime64tz_dtype (values .dtype ):
489- # Cast to naive; we'll cast back at the end of the function
490- # TODO: possible need to reshape?
491- # TODO(EA2D):kludge can be avoided when 2D EA is allowed.
492- values = values .view ("M8[ns]" )
533+ if is_extension_array_dtype (values .dtype ):
534+ return self ._ea_wrap_cython_operation (
535+ kind , values , how , axis , min_count , ** kwargs
536+ )
493537
494538 is_datetimelike = needs_i8_conversion (values .dtype )
495539 is_numeric = is_numeric_dtype (values .dtype )
@@ -573,19 +617,9 @@ def _cython_operation(
573617 if swapped :
574618 result = result .swapaxes (0 , axis )
575619
576- if is_datetime64tz_dtype (orig_values .dtype ) or is_period_dtype (
577- orig_values .dtype
578- ):
579- # We need to use the constructors directly for these dtypes
580- # since numpy won't recognize them
581- # https://github.com/pandas-dev/pandas/issues/31471
582- result = type (orig_values )(result .astype (np .int64 ), dtype = orig_values .dtype )
583- elif is_datetimelike and kind == "aggregate" :
620+ if is_datetimelike and kind == "aggregate" :
584621 result = result .astype (orig_values .dtype )
585622
586- if is_extension_array_dtype (orig_values .dtype ):
587- result = maybe_cast_result (result = result , obj = orig_values , how = how )
588-
589623 return result , names
590624
591625 def _aggregate (
0 commit comments