Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 21d957a

Browse files
stubgen: fix handling of Protocol and add testcase (#12129)
### Description This PR fixes #12072 by correctly handling `Protocol` definitions. Previously, the `Protocol` base class was removed when generating type stubs which causes problems with other packages that want to use type definition (because they see it as a regular class, not as a `Protocol`). ## Test Plan Added a testcase to the stubgen testset. Co-authored-by: 97littleleaf11 <[email protected]>
1 parent aa0d186 commit 21d957a

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mypy/stubgen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,12 @@ def visit_class_def(self, o: ClassDef) -> None:
851851
base_types.append('metaclass=abc.ABCMeta')
852852
self.import_tracker.add_import('abc')
853853
self.import_tracker.require_name('abc')
854+
elif self.analyzed and o.info.is_protocol:
855+
type_str = 'Protocol'
856+
if o.info.type_vars:
857+
type_str += f'[{", ".join(o.info.type_vars)}]'
858+
base_types.append(type_str)
859+
self.add_typing_import('Protocol')
854860
if base_types:
855861
self.add('(%s)' % ', '.join(base_types))
856862
self.add(':\n')

test-data/unit/stubgen.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,6 +2581,30 @@ def f(x: int, y: int) -> int: ...
25812581
@t.overload
25822582
def f(x: t.Tuple[int, int]) -> int: ...
25832583

2584+
[case testProtocol_semanal]
2585+
from typing import Protocol, TypeVar
2586+
2587+
class P(Protocol):
2588+
def f(self, x: int, y: int) -> str:
2589+
...
2590+
2591+
T = TypeVar('T')
2592+
T2 = TypeVar('T2')
2593+
class PT(Protocol[T, T2]):
2594+
def f(self, x: T) -> T2:
2595+
...
2596+
2597+
[out]
2598+
from typing import Protocol, TypeVar
2599+
2600+
class P(Protocol):
2601+
def f(self, x: int, y: int) -> str: ...
2602+
T = TypeVar('T')
2603+
T2 = TypeVar('T2')
2604+
2605+
class PT(Protocol[T, T2]):
2606+
def f(self, x: T) -> T2: ...
2607+
25842608
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
25852609
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
25862610
[out]

0 commit comments

Comments
 (0)