Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pandas-stubs/core/groupby/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from typing import (
Generic,
Literal,
NamedTuple,
Protocol,
TypeVar,
final,
overload,
Expand Down Expand Up @@ -208,26 +209,35 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):

_TT = TypeVar("_TT", bound=Literal[True, False])

class DFCallable1(Protocol):
def __call__(self, df: DataFrame, /, *args, **kwargs) -> Scalar | list | dict: ...

class DFCallable2(Protocol):
def __call__(self, df: DataFrame, /, *args, **kwargs) -> DataFrame | Series: ...

class DFCallable3(Protocol):
def __call__(self, df: Iterable, /, *args, **kwargs) -> float: ...

class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
@overload # type: ignore[override]
def apply(
self,
func: Callable[[DataFrame], Scalar | list | dict],
func: DFCallable1,
*args,
**kwargs,
) -> Series: ...
@overload
def apply(
self,
func: Callable[[DataFrame], Series | DataFrame],
func: DFCallable2,
*args,
**kwargs,
) -> DataFrame: ...
@overload
def apply( # pyright: ignore[reportOverlappingOverload]
self,
func: Callable[[Iterable], float],
func: DFCallable3,
*args,
**kwargs,
) -> DataFrame: ...
Expand Down
20 changes: 20 additions & 0 deletions tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,23 @@ def test_dataframe_value_counts() -> None:
Series,
np.int64,
)


def test_dataframe_apply_kwargs() -> None:
# GH 1266
df = DataFrame({"group": ["A", "A", "B", "B", "C"], "value": [10, 15, 10, 25, 30]})

def add_constant_to_mean(group: DataFrame, constant: int) -> DataFrame:
mean_val = group["value"].mean()
group["adjusted"] = mean_val + constant
return group

check(
assert_type(
df.groupby("group", group_keys=False)[["group", "value"]].apply(
add_constant_to_mean, constant=5
),
DataFrame,
),
DataFrame,
)