@@ -20,12 +20,14 @@ from cpython.object cimport (
2020import_datetime()
2121
2222import numpy as np
23+
2324cimport numpy as cnp
2425
2526cnp.import_array()
2627from numpy cimport (
2728 int64_t,
2829 ndarray,
30+ uint8_t,
2931)
3032
3133from pandas._libs.tslibs.util cimport get_c_string_buf_and_size
@@ -370,3 +372,81 @@ cpdef ndarray astype_overflowsafe(
370372 cnp.PyArray_MultiIter_NEXT(mi)
371373
372374 return iresult.view(dtype)
375+
376+
377+ # TODO: try to upstream this fix to numpy
378+ def compare_mismatched_resolutions (ndarray left , ndarray right , op ):
379+ """
380+ Overflow-safe comparison of timedelta64/datetime64 with mismatched resolutions.
381+
382+ >>> left = np.array([500], dtype="M8[Y]")
383+ >>> right = np.array([0], dtype="M8[ns]")
384+ >>> left < right # <- wrong!
385+ array([ True])
386+ """
387+
388+ if left.dtype.kind != right.dtype.kind or left.dtype.kind not in [" m" , " M" ]:
389+ raise ValueError (" left and right must both be timedelta64 or both datetime64" )
390+
391+ cdef:
392+ int op_code = op_to_op_code(op)
393+ NPY_DATETIMEUNIT left_unit = get_unit_from_dtype(left.dtype)
394+ NPY_DATETIMEUNIT right_unit = get_unit_from_dtype(right.dtype)
395+
396+ # equiv: result = np.empty((<object>left).shape, dtype="bool")
397+ ndarray result = cnp.PyArray_EMPTY(
398+ left.ndim, left.shape, cnp.NPY_BOOL, 0
399+ )
400+
401+ ndarray lvalues = left.view(" i8" )
402+ ndarray rvalues = right.view(" i8" )
403+
404+ cnp.broadcast mi = cnp.PyArray_MultiIterNew3(result, lvalues, rvalues)
405+ int64_t lval, rval
406+ bint res_value
407+
408+ Py_ssize_t i, N = left.size
409+ npy_datetimestruct ldts, rdts
410+
411+
412+ for i in range (N):
413+ # Analogous to: lval = lvalues[i]
414+ lval = (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 1 ))[0 ]
415+
416+ # Analogous to: rval = rvalues[i]
417+ rval = (< int64_t* > cnp.PyArray_MultiIter_DATA(mi, 2 ))[0 ]
418+
419+ if lval == NPY_DATETIME_NAT or rval == NPY_DATETIME_NAT:
420+ res_value = op_code == Py_NE
421+
422+ else :
423+ pandas_datetime_to_datetimestruct(lval, left_unit, & ldts)
424+ pandas_datetime_to_datetimestruct(rval, right_unit, & rdts)
425+
426+ res_value = cmp_dtstructs(& ldts, & rdts, op_code)
427+
428+ # Analogous to: result[i] = res_value
429+ (< uint8_t* > cnp.PyArray_MultiIter_DATA(mi, 0 ))[0 ] = res_value
430+
431+ cnp.PyArray_MultiIter_NEXT(mi)
432+
433+ return result
434+
435+
436+ import operator
437+
438+
439+ cdef int op_to_op_code(op):
440+ # TODO: should exist somewhere?
441+ if op is operator.eq:
442+ return Py_EQ
443+ if op is operator.ne:
444+ return Py_NE
445+ if op is operator.le:
446+ return Py_LE
447+ if op is operator.lt:
448+ return Py_LT
449+ if op is operator.ge:
450+ return Py_GE
451+ if op is operator.gt:
452+ return Py_GT
0 commit comments