33
44import numpy as np
55import pandas as pd
6+ import pytest
67from typing_extensions import assert_type
78
89from tests import (
@@ -44,6 +45,7 @@ def test_string_accessors_boolean_series():
4445 _check (assert_type (s .str .endswith ("e" ), "pd.Series[bool]" ))
4546 _check (assert_type (s .str .endswith (("e" , "f" )), "pd.Series[bool]" ))
4647 _check (assert_type (s .str .fullmatch ("apple" ), "pd.Series[bool]" ))
48+ _check (assert_type (s .str .fullmatch (re .compile (r"apple" )), "pd.Series[bool]" ))
4749 _check (assert_type (s .str .isalnum (), "pd.Series[bool]" ))
4850 _check (assert_type (s .str .isalpha (), "pd.Series[bool]" ))
4951 _check (assert_type (s .str .isdecimal (), "pd.Series[bool]" ))
@@ -54,6 +56,7 @@ def test_string_accessors_boolean_series():
5456 _check (assert_type (s .str .istitle (), "pd.Series[bool]" ))
5557 _check (assert_type (s .str .isupper (), "pd.Series[bool]" ))
5658 _check (assert_type (s .str .match ("pp" ), "pd.Series[bool]" ))
59+ _check (assert_type (s .str .match (re .compile (r"pp" )), "pd.Series[bool]" ))
5760
5861
5962def test_string_accessors_boolean_index ():
@@ -72,6 +75,7 @@ def test_string_accessors_boolean_index():
7275 _check (assert_type (idx .str .endswith ("e" ), np_ndarray_bool ))
7376 _check (assert_type (idx .str .endswith (("e" , "f" )), np_ndarray_bool ))
7477 _check (assert_type (idx .str .fullmatch ("apple" ), np_ndarray_bool ))
78+ _check (assert_type (idx .str .fullmatch (re .compile (r"apple" )), np_ndarray_bool ))
7579 _check (assert_type (idx .str .isalnum (), np_ndarray_bool ))
7680 _check (assert_type (idx .str .isalpha (), np_ndarray_bool ))
7781 _check (assert_type (idx .str .isdecimal (), np_ndarray_bool ))
@@ -82,6 +86,7 @@ def test_string_accessors_boolean_index():
8286 _check (assert_type (idx .str .istitle (), np_ndarray_bool ))
8387 _check (assert_type (idx .str .isupper (), np_ndarray_bool ))
8488 _check (assert_type (idx .str .match ("pp" ), np_ndarray_bool ))
89+ _check (assert_type (idx .str .match (re .compile (r"pp" )), np_ndarray_bool ))
8590
8691
8792def test_string_accessors_integer_series ():
@@ -94,6 +99,10 @@ def test_string_accessors_integer_series():
9499 _check (assert_type (s .str .count ("pp" ), "pd.Series[int]" ))
95100 _check (assert_type (s .str .len (), "pd.Series[int]" ))
96101
102+ # unlike findall, find doesn't accept a compiled pattern
103+ with pytest .raises (TypeError ):
104+ s .str .find (re .compile (r"p" )) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
105+
97106
98107def test_string_accessors_integer_index ():
99108 idx = pd .Index (DATA )
@@ -105,6 +114,10 @@ def test_string_accessors_integer_index():
105114 _check (assert_type (idx .str .count ("pp" ), "pd.Index[int]" ))
106115 _check (assert_type (idx .str .len (), "pd.Index[int]" ))
107116
117+ # unlike findall, find doesn't accept a compiled pattern
118+ with pytest .raises (TypeError ):
119+ idx .str .find (re .compile (r"p" )) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
120+
108121
109122def test_string_accessors_string_series ():
110123 s = pd .Series (DATA )
@@ -123,6 +136,9 @@ def test_string_accessors_string_series():
123136 _check (assert_type (s .str .removesuffix ("e" ), "pd.Series[str]" ))
124137 _check (assert_type (s .str .repeat (2 ), "pd.Series[str]" ))
125138 _check (assert_type (s .str .replace ("a" , "X" ), "pd.Series[str]" ))
139+ _check (
140+ assert_type (s .str .replace (re .compile (r"a" ), "X" , regex = True ), "pd.Series[str]" )
141+ )
126142 _check (assert_type (s .str .rjust (80 ), "pd.Series[str]" ))
127143 _check (assert_type (s .str .rstrip (), "pd.Series[str]" ))
128144 _check (assert_type (s .str .slice_replace (0 , 2 , "XX" ), "pd.Series[str]" ))
@@ -158,6 +174,9 @@ def test_string_accessors_string_index():
158174 _check (assert_type (idx .str .removesuffix ("e" ), "pd.Index[str]" ))
159175 _check (assert_type (idx .str .repeat (2 ), "pd.Index[str]" ))
160176 _check (assert_type (idx .str .replace ("a" , "X" ), "pd.Index[str]" ))
177+ _check (
178+ assert_type (idx .str .replace (re .compile (r"a" ), "X" , regex = True ), "pd.Index[str]" )
179+ )
161180 _check (assert_type (idx .str .rjust (80 ), "pd.Index[str]" ))
162181 _check (assert_type (idx .str .rstrip (), "pd.Index[str]" ))
163182 _check (assert_type (idx .str .slice_replace (0 , 2 , "XX" ), "pd.Index[str]" ))
@@ -190,29 +209,49 @@ def test_string_accessors_list_series():
190209 s = pd .Series (DATA )
191210 _check = functools .partial (check , klass = pd .Series , dtype = list )
192211 _check (assert_type (s .str .findall ("pp" ), "pd.Series[list[str]]" ))
212+ _check (assert_type (s .str .findall (re .compile (r"pp" )), "pd.Series[list[str]]" ))
193213 _check (assert_type (s .str .split ("a" ), "pd.Series[list[str]]" ))
214+ _check (assert_type (s .str .split (re .compile (r"a" )), "pd.Series[list[str]]" ))
194215 # GH 194
195216 _check (assert_type (s .str .split ("a" , expand = False ), "pd.Series[list[str]]" ))
196217 _check (assert_type (s .str .rsplit ("a" ), "pd.Series[list[str]]" ))
197218 _check (assert_type (s .str .rsplit ("a" , expand = False ), "pd.Series[list[str]]" ))
198219
220+ # rsplit doesn't accept compiled pattern
221+ # it doesn't raise at runtime but produces a nan
222+ bad_rsplit_result = s .str .rsplit (
223+ re .compile (r"a" ) # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
224+ )
225+ assert bad_rsplit_result .isna ().all ()
226+
199227
def test_string_accessors_list_index():
    """Exercise ``Index.str`` methods annotated as ``pd.Index[list[str]]``.

    Every ``findall``/``split``/``rsplit`` result is first pinned statically
    with ``assert_type`` and then handed to ``check`` with ``klass=pd.Index``
    and ``dtype=list`` (presumably the runtime counterpart of the static
    assertion -- see the ``check`` helper for its exact contract).
    """
    idx = pd.Index(DATA)
    expect_list_index = functools.partial(check, klass=pd.Index, dtype=list)
    compiled_pp = re.compile(r"pp")
    compiled_a = re.compile(r"a")

    # findall: plain string, then compiled pattern.
    expect_list_index(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]"))
    expect_list_index(
        assert_type(idx.str.findall(compiled_pp), "pd.Index[list[str]]")
    )

    # split: plain string, then compiled pattern.
    expect_list_index(assert_type(idx.str.split("a"), "pd.Index[list[str]]"))
    expect_list_index(assert_type(idx.str.split(compiled_a), "pd.Index[list[str]]"))
    # GH 194
    expect_list_index(
        assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]")
    )

    # rsplit: plain string only.
    expect_list_index(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]"))
    expect_list_index(
        assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]")
    )

    # A compiled pattern is not an accepted argument type for rsplit, hence
    # the ignores below. At runtime the call does not raise; instead every
    # element of the result is NaN, which the final assert verifies.
    all_nan_result = idx.str.rsplit(
        compiled_a  # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
    )
    assert all_nan_result.isna().all()
246+
210247
211248def test_string_accessors_expanding_series ():
212249 s = pd .Series (["a1" , "b2" , "c3" ])
213250 _check = functools .partial (check , klass = pd .DataFrame )
214251 _check (assert_type (s .str .extract (r"([ab])?(\d)" ), pd .DataFrame ))
252+ _check (assert_type (s .str .extract (re .compile (r"([ab])?(\d)" )), pd .DataFrame ))
215253 _check (assert_type (s .str .extractall (r"([ab])?(\d)" ), pd .DataFrame ))
254+ _check (assert_type (s .str .extractall (re .compile (r"([ab])?(\d)" )), pd .DataFrame ))
216255 _check (assert_type (s .str .get_dummies (), pd .DataFrame ))
217256 _check (assert_type (s .str .partition ("p" ), pd .DataFrame ))
218257 _check (assert_type (s .str .rpartition ("p" ), pd .DataFrame ))
@@ -231,7 +270,15 @@ def test_string_accessors_expanding_index():
231270
232271 # These ones are the odd ones out?
233272 check (assert_type (idx .str .extractall (r"([ab])?(\d)" ), pd .DataFrame ), pd .DataFrame )
273+ check (
274+ assert_type (idx .str .extractall (re .compile (r"([ab])?(\d)" )), pd .DataFrame ),
275+ pd .DataFrame ,
276+ )
234277 check (assert_type (idx .str .extract (r"([ab])?(\d)" ), pd .DataFrame ), pd .DataFrame )
278+ check (
279+ assert_type (idx .str .extract (re .compile (r"([ab])?(\d)" )), pd .DataFrame ),
280+ pd .DataFrame ,
281+ )
235282
236283
237284def test_series_overloads_partition ():
0 commit comments