Skip to content

Commit aad3b2e

Browse files
author
Rohan Jain
committed
avoid floating points for integral floor division
1 parent 37975eb commit aad3b2e

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,22 @@ def floordiv_compat(
127127
left: pa.ChunkedArray | pa.Array | pa.Scalar,
128128
right: pa.ChunkedArray | pa.Array | pa.Scalar,
129129
) -> pa.ChunkedArray:
130-
# Ensure int // int -> int mirroring Python/Numpy behavior
131-
# as pc.floor(pc.divide_checked(int, int)) -> float
132-
converted_left = cast_for_truediv(left, right)
133-
result = pc.floor(pc.divide(converted_left, right))
134-
if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
130+
divided = pc.divide(left, right)
131+
if pa.types.is_integer(divided.type):
132+
has_remainder = pc.not_equal(pc.multiply(divided, right), left)
133+
result = pc.if_else(
134+
pc.less(divided, 0),
135+
pc.if_else(has_remainder, pc.subtract(divided, 1), divided),
136+
divided,
137+
)
138+
# Ensure compatibility with older versions of pandas where
139+
# int8 // int64 returned int8 rather than int64.
135140
result = result.cast(left.type)
141+
else:
142+
assert pa.types.is_floating(divided.type) or pa.types.is_decimal(
143+
divided.type
144+
)
145+
result = pc.floor(divided)
136146
return result
137147

138148
ARROW_ARITHMETIC_FUNCS = {

pandas/tests/extension/test_arrow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,14 @@ def test_arrow_floordiv_large_values():
32463246
tm.assert_series_equal(result, expected)
32473247

32483248

3249+
def test_arrow_floordiv_large_integral_result():
3250+
# GH XXXXX
3251+
a = pd.Series([18014398509481983], dtype="int64[pyarrow]")
3252+
expected = pd.Series([18014398509481983], dtype="int64[pyarrow]")
3253+
result = a // 1
3254+
tm.assert_series_equal(result, expected)
3255+
3256+
32493257
def test_string_to_datetime_parsing_cast():
32503258
# GH 56266
32513259
string_dates = ["2020-01-01 04:30:00", "2020-01-02 00:00:00", "2020-01-03 00:00:00"]

0 commit comments

Comments
 (0)