44import json
55from typing import Dict , List , Optional , Sequence , Set , Tuple , Type
66from typing_extensions import TypedDict
7- import re
87
98import libcst
109from libcst .codemod import CodemodContext , VisitorBasedCodemodCommand
@@ -90,37 +89,6 @@ class State:
9089}
9190
9291
93- class _SafeAction (argparse .Action ):
94- def __call__ (
95- self ,
96- parser : argparse .ArgumentParser ,
97- namespace : argparse .Namespace ,
98- values : object ,
99- option_string : Optional [str ] = ...,
100- ) -> None :
101- namespace .none_return = True
102- namespace .scalar_return = True
103- namespace .annotate_magics = True
104-
105-
106- class _AggressiveAction (_SafeAction ):
107- def __call__ (
108- self ,
109- parser : argparse .ArgumentParser ,
110- namespace : argparse .Namespace ,
111- values : object ,
112- option_string : Optional [str ] = ...,
113- ) -> None :
114- super ().__call__ (parser , namespace , values , option_string )
115- namespace .bool_param = True
116- namespace .int_param = True
117- namespace .float_param = True
118- namespace .str_param = True
119- namespace .bytes_param = True
120- namespace .annotate_imprecise_magics = True
121- namespace .guess_common_names = True
122-
123-
12492class AutotypeCommand (VisitorBasedCodemodCommand ):
12593 # Add a description so that future codemodders can see what this does.
12694 DESCRIPTION : str = "Automatically adds simple type annotations."
@@ -230,15 +198,15 @@ def add_args(arg_parser: argparse.ArgumentParser) -> None:
230198 )
231199 arg_parser .add_argument (
232200 "--safe" ,
233- action = _SafeAction ,
201+ action = "store_true" ,
202+ default = False ,
234203 help = "Apply all safe transformations" ,
235- nargs = "?" ,
236204 )
237205 arg_parser .add_argument (
238206 "--aggressive" ,
239- action = _AggressiveAction ,
207+ action = "store_true" ,
208+ default = False ,
240209 help = "Apply all transformations that do not require arguments" ,
241- nargs = "?" ,
242210 )
243211
244212 def __init__ (
@@ -262,6 +230,18 @@ def __init__(
262230 safe : bool = False ,
263231 aggressive : bool = False ,
264232 ) -> None :
233+ if safe or aggressive :
234+ none_return = True
235+ scalar_return = True
236+ annotate_magics = True
237+ if aggressive :
238+ bool_param = True
239+ int_param = True
240+ float_param = True
241+ str_param = True
242+ bytes_param = True
243+ annotate_imprecise_magics = True
244+ guess_common_names = True
265245 super ().__init__ (context )
266246 param_type_pairs = [
267247 (bool_param , bool ),
@@ -295,12 +275,16 @@ def __init__(
295275 )
296276 ] = metadata
297277 self .state = State (
298- annotate_optionals = [NamedParam .make (s ) for s in annotate_optional ]
299- if annotate_optional
300- else [],
301- annotate_named_params = [NamedParam .make (s ) for s in annotate_named_param ]
302- if annotate_named_param
303- else [],
278+ annotate_optionals = (
279+ [NamedParam .make (s ) for s in annotate_optional ]
280+ if annotate_optional
281+ else []
282+ ),
283+ annotate_named_params = (
284+ [NamedParam .make (s ) for s in annotate_named_param ]
285+ if annotate_named_param
286+ else []
287+ ),
304288 none_return = none_return ,
305289 scalar_return = scalar_return ,
306290 param_types = {typ for param , typ in param_type_pairs if param },
0 commit comments