@@ -11,6 +11,7 @@ from typing import (
1111 Generic ,
1212 Literal ,
1313 NamedTuple ,
14+ Protocol ,
1415 TypeVar ,
1516 final ,
1617 overload ,
@@ -208,28 +209,45 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
208209
209210_TT = TypeVar ("_TT" , bound = Literal [True , False ])
210211
212+ # ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
213+ class DFCallable1 (Protocol [P ]): # ty: ignore[invalid-argument-type]
214+ def __call__ (
215+ self , df : DataFrame , / , * args : P .args , ** kwargs : P .kwargs
216+ ) -> Scalar | list | dict : ...
217+
218+ class DFCallable2 (Protocol [P ]): # ty: ignore[invalid-argument-type]
219+ def __call__ (
220+ self , df : DataFrame , / , * args : P .args , ** kwargs : P .kwargs
221+ ) -> DataFrame | Series : ...
222+
223+ class DFCallable3 (Protocol [P ]): # ty: ignore[invalid-argument-type]
224+ def __call__ (self , df : Iterable , / , * args : P .args , ** kwargs : P .kwargs ) -> float : ...
225+
211226class DataFrameGroupBy (GroupBy [DataFrame ], Generic [ByT , _TT ]):
212227 # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
213228 @overload # type: ignore[override]
214229 def apply (
215230 self ,
216- func : Callable [[DataFrame ], Scalar | list | dict ],
217- * args ,
218- ** kwargs ,
231+ func : DFCallable1 [P ],
232+ / ,
233+ * args : P .args ,
234+ ** kwargs : P .kwargs ,
219235 ) -> Series : ...
220236 @overload
221237 def apply (
222238 self ,
223- func : Callable [[DataFrame ], Series | DataFrame ],
224- * args ,
225- ** kwargs ,
239+ func : DFCallable2 [P ],
240+ / ,
241+ * args : P .args ,
242+ ** kwargs : P .kwargs ,
226243 ) -> DataFrame : ...
227244 @overload
228- def apply ( # pyright: ignore[reportOverlappingOverload]
245+ def apply (
229246 self ,
230- func : Callable [[Iterable ], float ],
231- * args ,
232- ** kwargs ,
247+ func : DFCallable3 [P ],
248+ / ,
249+ * args : P .args ,
250+ ** kwargs : P .kwargs ,
233251 ) -> DataFrame : ...
234252 # error: overload 1 overlaps overload 2 because of different return types
235253 @overload
0 commit comments