Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d073a10

Browse files
committed
Update boilerplate to include annotations from pyi files
comment about methods without type hints in boilerplate
1 parent 1ccb7f0 commit d073a10

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

tools/boilerplate.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# runtime with the proper signatures, a static pyplot.py is simpler for static
1414
# analysis tools to parse.
1515

16+
import ast
1617
from enum import Enum
1718
import inspect
1819
from inspect import Parameter
@@ -117,6 +118,17 @@ def __repr__(self):
117118
return self._repr
118119

119120

121+
class direct_repr:
122+
"""
123+
A placeholder class to destringify annotations from ast
124+
"""
125+
def __init__(self, value):
126+
self._repr = value
127+
128+
def __repr__(self):
129+
return self._repr
130+
131+
120132
def generate_function(name, called_fullname, template, **kwargs):
121133
"""
122134
Create a wrapper function *pyplot_name* calling *call_name*.
@@ -153,14 +165,17 @@ def generate_function(name, called_fullname, template, **kwargs):
153165
# redecorated with make_keyword_only by _copy_docstring_and_deprecators.
154166
if decorator and decorator.func is _api.make_keyword_only:
155167
meth = meth.__wrapped__
156-
signature = inspect.signature(meth)
168+
169+
annotated_trees = get_ast_mro_trees(class_)
170+
signature = get_matching_signature(meth, annotated_trees)
171+
157172
# Replace self argument.
158173
params = list(signature.parameters.values())[1:]
159174
signature = str(signature.replace(parameters=[
160175
param.replace(default=value_formatter(param.default))
161176
if param.default is not param.empty else param
162177
for param in params]))
163-
if len('def ' + name + signature) >= 80:
178+
if len('def ' + name + signature) >= 80 and False:
164179
# Move opening parenthesis before newline.
165180
signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)
166181
# How to call the wrapped function.
@@ -381,6 +396,73 @@ def build_pyplot(pyplot_path):
381396
pyplot.writelines(boilerplate_gen())
382397

383398

399+
### Methods for retrieving signatures from pyi stub files
400+
401+
def get_ast_tree(cls):
402+
path = Path(inspect.getfile(cls))
403+
stubpath = path.with_suffix(".pyi")
404+
path = stubpath if stubpath.exists() else path
405+
tree = ast.parse(path.read_text())
406+
for item in tree.body:
407+
if isinstance(item, ast.ClassDef) and item.name == cls.__name__:
408+
return item
409+
raise ValueError("Cannot find {cls.__name__} in ast")
410+
411+
412+
def get_ast_mro_trees(cls):
413+
return [get_ast_tree(c) for c in cls.__mro__ if c.__module__ != "builtins"]
414+
415+
416+
def get_matching_signature(method, trees):
417+
sig = inspect.signature(method)
418+
for tree in trees:
419+
for item in tree.body:
420+
if not isinstance(item, ast.FunctionDef):
421+
continue
422+
if item.name == method.__name__:
423+
return update_sig_from_node(item, sig)
424+
# The following methods are implemented outside of the mro of Axes
425+
# and thus do not get their annotated versions found with current code
426+
# stackplot
427+
# streamplot
428+
# table
429+
# tricontour
430+
# tricontourf
431+
# tripcolor
432+
# triplot
433+
434+
# import warnings
435+
# warnings.warn(f"'{method.__name__}' not found")
436+
return sig
437+
438+
439+
def update_sig_from_node(node, sig):
440+
params = dict(sig.parameters)
441+
args = node.args
442+
allargs = (
443+
args.posonlyargs
444+
+ args.args
445+
+ [args.vararg]
446+
+ args.kwonlyargs
447+
+ [args.kwarg]
448+
)
449+
for param in allargs:
450+
if param is None:
451+
continue
452+
if param.annotation is None:
453+
continue
454+
annotation = direct_repr(ast.unparse(param.annotation))
455+
params[param.arg] = params[param.arg].replace(annotation=annotation)
456+
457+
if node.returns is not None:
458+
return inspect.Signature(
459+
params.values(),
460+
return_annotation=direct_repr(ast.unparse(node.returns))
461+
)
462+
else:
463+
return inspect.Signature(params.values())
464+
465+
384466
if __name__ == '__main__':
385467
# Write the matplotlib.pyplot file.
386468
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)