Skip to content

Commit d54943e

Browse files
authored
TYP: patch for astral-sh/ty#1846 (#1547)
* astral-sh/ty#1846 * remove type variables * naming * #1530 * one more case
1 parent c021ca4 commit d54943e

File tree

4 files changed

+32
-63
lines changed

4 files changed

+32
-63
lines changed

pandas-stubs/core/groupby/generic.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,18 +215,17 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
215215

216216
_TT = TypeVar("_TT", bound=Literal[True, False])
217217

218-
# ty ignore needed because of https://github.com/astral-sh/ty/issues/157#issuecomment-3017337945
219-
class DFCallable1(Protocol[P]): # ty: ignore[invalid-argument-type]
218+
class DFCallable1(Protocol[P]):
220219
def __call__(
221220
self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
222221
) -> Scalar | list[Any] | dict[Hashable, Any]: ...
223222

224-
class DFCallable2(Protocol[P]): # ty: ignore[invalid-argument-type]
223+
class DFCallable2(Protocol[P]):
225224
def __call__(
226225
self, df: DataFrame, /, *args: P.args, **kwargs: P.kwargs
227226
) -> DataFrame | Series: ...
228227

229-
class DFCallable3(Protocol[P]): # ty: ignore[invalid-argument-type]
228+
class DFCallable3(Protocol[P]):
230229
def __call__(
231230
self, df: Iterable[Any], /, *args: P.args, **kwargs: P.kwargs
232231
) -> float: ...

pandas-stubs/core/groupby/groupby.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ from pandas._typing import (
7171
from pandas.plotting import PlotAccessor
7272

7373
_ResamplerGroupBy: TypeAlias = (
74-
DatetimeIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type]
75-
| PeriodIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type]
76-
| TimedeltaIndexResamplerGroupby[NDFrameT] # ty: ignore[invalid-argument-type]
74+
DatetimeIndexResamplerGroupby[NDFrameT]
75+
| PeriodIndexResamplerGroupby[NDFrameT]
76+
| TimedeltaIndexResamplerGroupby[NDFrameT]
7777
)
7878

7979
class GroupBy(BaseGroupBy[NDFrameT]):

pandas-stubs/core/reshape/pivot.pyi

Lines changed: 25 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ from pandas.core.series import Series
2222
from pandas._typing import (
2323
AnyArrayLike,
2424
ArrayLike,
25-
HashableT1,
26-
HashableT2,
27-
HashableT3,
2825
Label,
2926
Scalar,
3027
ScalarT,
@@ -33,12 +30,16 @@ from pandas._typing import (
3330
)
3431

3532
_PivotAggCallable: TypeAlias = Callable[[Series], ScalarT]
36-
3733
_PivotAggFunc: TypeAlias = (
3834
_PivotAggCallable[ScalarT]
3935
| np.ufunc
4036
| Literal["mean", "sum", "count", "min", "max", "median", "std", "var"]
4137
)
38+
_PivotAggFuncTypes: TypeAlias = (
39+
_PivotAggFunc[ScalarT]
40+
| Sequence[_PivotAggFunc[ScalarT]]
41+
| Mapping[Any, _PivotAggFunc[ScalarT]]
42+
)
4243

4344
_NonIterableHashable: TypeAlias = (
4445
str
@@ -53,32 +54,22 @@ _NonIterableHashable: TypeAlias = (
5354
| pd.Timedelta
5455
)
5556

56-
_PivotTableIndexTypes: TypeAlias = (
57-
Label | Sequence[HashableT1] | Series | Grouper | None
58-
)
57+
_PivotTableIndexTypes: TypeAlias = Label | Sequence[Hashable] | Series | Grouper | None
5958
_PivotTableColumnsTypes: TypeAlias = (
60-
Label | Sequence[HashableT2] | Series | Grouper | None
59+
Label | Sequence[Hashable] | Series | Grouper | None
6160
)
62-
_PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None
61+
_PivotTableValuesTypes: TypeAlias = Label | Sequence[Hashable] | None
6362

6463
_ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike
6564
_Values: TypeAlias = SequenceNotStr[Any] | _ExtendedAnyArrayLike
6665

6766
@overload
6867
def pivot_table(
6968
data: DataFrame,
70-
values: _PivotTableValuesTypes[
71-
Hashable # ty: ignore[invalid-type-arguments]
72-
] = None,
73-
index: _PivotTableIndexTypes[Hashable] = None, # ty: ignore[invalid-type-arguments]
74-
columns: _PivotTableColumnsTypes[
75-
Hashable # ty: ignore[invalid-type-arguments]
76-
] = None,
77-
aggfunc: (
78-
_PivotAggFunc[Scalar]
79-
| Sequence[_PivotAggFunc[Scalar]]
80-
| Mapping[Any, _PivotAggFunc[Scalar]]
81-
) = "mean",
69+
values: _PivotTableValuesTypes = None,
70+
index: _PivotTableIndexTypes = None,
71+
columns: _PivotTableColumnsTypes = None,
72+
aggfunc: _PivotAggFuncTypes[Scalar] = "mean",
8273
fill_value: Scalar | None = None,
8374
margins: bool = False,
8475
dropna: bool = True,
@@ -91,21 +82,11 @@ def pivot_table(
9182
@overload
9283
def pivot_table(
9384
data: DataFrame,
94-
values: _PivotTableValuesTypes[
95-
Hashable # ty: ignore[invalid-type-arguments]
96-
] = None,
85+
values: _PivotTableValuesTypes = None,
9786
*,
9887
index: Grouper,
99-
columns: (
100-
_PivotTableColumnsTypes[Hashable] # ty: ignore[invalid-type-arguments]
101-
| np_ndarray
102-
| Index[Any]
103-
) = None,
104-
aggfunc: (
105-
_PivotAggFunc[Scalar]
106-
| Sequence[_PivotAggFunc[Scalar]]
107-
| Mapping[Any, _PivotAggFunc[Scalar]]
108-
) = "mean",
88+
columns: _PivotTableColumnsTypes | np_ndarray | Index[Any] = None,
89+
aggfunc: _PivotAggFuncTypes[Scalar] = "mean",
10990
fill_value: Scalar | None = None,
11091
margins: bool = False,
11192
dropna: bool = True,
@@ -116,21 +97,11 @@ def pivot_table(
11697
@overload
11798
def pivot_table(
11899
data: DataFrame,
119-
values: _PivotTableValuesTypes[
120-
Hashable # ty: ignore[invalid-type-arguments]
121-
] = None,
122-
index: (
123-
_PivotTableIndexTypes[Hashable] # ty: ignore[invalid-type-arguments]
124-
| np_ndarray
125-
| Index[Any]
126-
) = None,
100+
values: _PivotTableValuesTypes = None,
101+
index: _PivotTableIndexTypes | np_ndarray | Index[Any] = None,
127102
*,
128103
columns: Grouper,
129-
aggfunc: (
130-
_PivotAggFunc[Scalar]
131-
| Sequence[_PivotAggFunc[Scalar]]
132-
| Mapping[Any, _PivotAggFunc[Scalar]]
133-
) = "mean",
104+
aggfunc: _PivotAggFuncTypes[Scalar] = "mean",
134105
fill_value: Scalar | None = None,
135106
margins: bool = False,
136107
dropna: bool = True,
@@ -141,17 +112,17 @@ def pivot_table(
141112
def pivot(
142113
data: DataFrame,
143114
*,
144-
index: _NonIterableHashable | Sequence[HashableT1] = ...,
145-
columns: _NonIterableHashable | Sequence[HashableT2] = ...,
146-
values: _NonIterableHashable | Sequence[HashableT3] = ...,
115+
index: _NonIterableHashable | Sequence[Hashable] = ...,
116+
columns: _NonIterableHashable | Sequence[Hashable] = ...,
117+
values: _NonIterableHashable | Sequence[Hashable] = ...,
147118
) -> DataFrame: ...
148119
@overload
149120
def crosstab(
150121
index: _Values | list[_Values],
151122
columns: _Values | list[_Values],
152123
values: _Values,
153-
rownames: list[HashableT1] | None = ...,
154-
colnames: list[HashableT2] | None = ...,
124+
rownames: SequenceNotStr[Hashable] | None = ...,
125+
colnames: SequenceNotStr[Hashable] | None = ...,
155126
*,
156127
aggfunc: str | np.ufunc | Callable[[Series], float],
157128
margins: bool = ...,
@@ -164,8 +135,8 @@ def crosstab(
164135
index: _Values | list[_Values],
165136
columns: _Values | list[_Values],
166137
values: None = None,
167-
rownames: list[HashableT1] | None = ...,
168-
colnames: list[HashableT2] | None = ...,
138+
rownames: SequenceNotStr[Hashable] | None = ...,
139+
colnames: SequenceNotStr[Hashable] | None = ...,
169140
aggfunc: None = None,
170141
margins: bool = ...,
171142
margins_name: str = ...,

tests/test_pandas.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,8 +1716,7 @@ def m2(x: pd.Series) -> int:
17161716
colnames: list[tuple[str]] = [("a",)]
17171717
check(
17181718
assert_type(
1719-
pd.crosstab(a, b, colnames=colnames, rownames=rownames),
1720-
pd.DataFrame,
1719+
pd.crosstab(a, b, colnames=colnames, rownames=rownames), pd.DataFrame
17211720
),
17221721
pd.DataFrame,
17231722
)

0 commit comments

Comments
 (0)