Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 388bafe

Browse files
TYP: widen einsum subscripts param to accept complex arrays in subscript-free calling convention
The first parameter of np.einsum was typed as str | _ArrayLikeInt_co, which caused false-positive mypy errors when passing float or complex arrays using the subscript-free (interleaved operands/subscripts) calling convention. Widen to str | _ArrayLikeComplex_co so that bool, uint, int, float, and complex arrays are all accepted as the first operand. Also adds reveal tests for float64 and complex128 arrays in the subscript-free convention. Fixes: #31513
1 parent cf16ec1 commit 388bafe

2 files changed

Lines changed: 20 additions & 13 deletions

File tree

numpy/_core/einsumfunc.pyi

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ type _CastingUnsafe = Literal["unsafe"]
3434
# Something like `is_scalar = bool(__subscripts.partition("->")[-1])`
3535
@overload
3636
def einsum(
37-
subscripts: str | _ArrayLikeInt_co,
37+
subscripts: str | _ArrayLikeComplex_co,
3838
/,
3939
*operands: _ArrayLikeBool_co,
4040
out: None = None,
@@ -45,7 +45,7 @@ def einsum(
4545
) -> Any: ...
4646
@overload
4747
def einsum(
48-
subscripts: str | _ArrayLikeInt_co,
48+
subscripts: str | _ArrayLikeComplex_co,
4949
/,
5050
*operands: _ArrayLikeUInt_co,
5151
out: None = None,
@@ -56,7 +56,7 @@ def einsum(
5656
) -> Any: ...
5757
@overload
5858
def einsum(
59-
subscripts: str | _ArrayLikeInt_co,
59+
subscripts: str | _ArrayLikeComplex_co,
6060
/,
6161
*operands: _ArrayLikeInt_co,
6262
out: None = None,
@@ -67,7 +67,7 @@ def einsum(
6767
) -> Any: ...
6868
@overload
6969
def einsum(
70-
subscripts: str | _ArrayLikeInt_co,
70+
subscripts: str | _ArrayLikeComplex_co,
7171
/,
7272
*operands: _ArrayLikeFloat_co,
7373
out: None = None,
@@ -78,7 +78,7 @@ def einsum(
7878
) -> Any: ...
7979
@overload
8080
def einsum(
81-
subscripts: str | _ArrayLikeInt_co,
81+
subscripts: str | _ArrayLikeComplex_co,
8282
/,
8383
*operands: _ArrayLikeComplex_co,
8484
out: None = None,
@@ -89,7 +89,7 @@ def einsum(
8989
) -> Any: ...
9090
@overload
9191
def einsum(
92-
subscripts: str | _ArrayLikeInt_co,
92+
subscripts: str | _ArrayLikeComplex_co,
9393
/,
9494
*operands: Any,
9595
casting: _CastingUnsafe,
@@ -100,7 +100,7 @@ def einsum(
100100
) -> Any: ...
101101
@overload
102102
def einsum[OutT: NDArray[np.bool | np.number]](
103-
subscripts: str | _ArrayLikeInt_co,
103+
subscripts: str | _ArrayLikeComplex_co,
104104
/,
105105
*operands: _ArrayLikeComplex_co,
106106
out: OutT,
@@ -111,7 +111,7 @@ def einsum[OutT: NDArray[np.bool | np.number]](
111111
) -> OutT: ...
112112
@overload
113113
def einsum[OutT: NDArray[np.bool | np.number]](
114-
subscripts: str | _ArrayLikeInt_co,
114+
subscripts: str | _ArrayLikeComplex_co,
115115
/,
116116
*operands: Any,
117117
out: OutT,
@@ -123,7 +123,7 @@ def einsum[OutT: NDArray[np.bool | np.number]](
123123

124124
@overload
125125
def einsum(
126-
subscripts: str | _ArrayLikeInt_co,
126+
subscripts: str | _ArrayLikeComplex_co,
127127
/,
128128
*operands: _ArrayLikeObject_co,
129129
out: None = None,
@@ -134,7 +134,7 @@ def einsum(
134134
) -> Any: ...
135135
@overload
136136
def einsum(
137-
subscripts: str | _ArrayLikeInt_co,
137+
subscripts: str | _ArrayLikeComplex_co,
138138
/,
139139
*operands: Any,
140140
casting: _CastingUnsafe,
@@ -145,7 +145,7 @@ def einsum(
145145
) -> Any: ...
146146
@overload
147147
def einsum[OutT: NDArray[np.bool | np.number]](
148-
subscripts: str | _ArrayLikeInt_co,
148+
subscripts: str | _ArrayLikeComplex_co,
149149
/,
150150
*operands: _ArrayLikeObject_co,
151151
out: OutT,
@@ -156,7 +156,7 @@ def einsum[OutT: NDArray[np.bool | np.number]](
156156
) -> OutT: ...
157157
@overload
158158
def einsum[OutT: NDArray[np.bool | np.number]](
159-
subscripts: str | _ArrayLikeInt_co,
159+
subscripts: str | _ArrayLikeComplex_co,
160160
/,
161161
*operands: Any,
162162
out: OutT,
@@ -171,7 +171,7 @@ def einsum[OutT: NDArray[np.bool | np.number]](
171171
# NOTE: In practice the list consists of a `str` (first element)
172172
# and a variable number of integer tuples.
173173
def einsum_path(
174-
subscripts: str | _ArrayLikeInt_co,
174+
subscripts: str | _ArrayLikeComplex_co,
175175
/,
176176
*operands: _ArrayLikeComplex_co | _DTypeLikeObject,
177177
optimize: _OptimizeKind = "greedy",

numpy/typing/tests/data/reveal/einsumfunc.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,11 @@ assert_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_i), tuple[list[Any], str
3636
assert_type(np.einsum_path("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c), tuple[list[Any], str])
3737

3838
assert_type(np.einsum([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), Any)
39+
3940
assert_type(np.einsum_path([[1, 1], [1, 1]], AR_LIKE_i, AR_LIKE_i), tuple[list[Any], str])
41+
42+
# subscript-free calling convention with float and complex arrays
43+
AR_f: npt.NDArray[np.float64]
44+
AR_c: npt.NDArray[np.complex128]
45+
assert_type(np.einsum(AR_f, [0, 1], AR_f, [1, 0], [0]), Any)
46+
assert_type(np.einsum(AR_c, [0, 1], AR_c, [1, 0], [0]), Any)

0 commit comments

Comments
 (0)