@@ -90,11 +90,18 @@ def get_fixers_from_package(pkg_name):
9090 for fix_name in get_all_fix_names (pkg_name , False )]
9191
9292
93+ class FixerError (Exception ):
94+ """A fixer could not be loaded."""
95+
96+
9397class RefactoringTool (object ):
9498
9599 _default_options = {"print_function" : False }
96100
97- def __init__ (self , fixer_names , options = None , explicit = []):
101+ CLASS_PREFIX = "Fix" # The prefix for fixer classes
102+ FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
103+
104+ def __init__ (self , fixer_names , options = None , explicit = None ):
98105 """Initializer.
99106
100107 Args:
@@ -103,7 +110,7 @@ def __init__(self, fixer_names, options=None, explicit=[]):
103110 explicit: a list of fixers to run even if they are explicit.
104111 """
105112 self .fixers = fixer_names
106- self .explicit = explicit
113+ self .explicit = explicit or []
107114 self .options = self ._default_options .copy ()
108115 if options is not None :
109116 self .options .update (options )
@@ -134,29 +141,17 @@ def get_fixers(self):
134141 pre_order_fixers = []
135142 post_order_fixers = []
136143 for fix_mod_path in self .fixers :
137- try :
138- mod = __import__ (fix_mod_path , {}, {}, ["*" ])
139- except ImportError :
140- self .log_error ("Can't load transformation module %s" ,
141- fix_mod_path )
142- continue
144+ mod = __import__ (fix_mod_path , {}, {}, ["*" ])
143145 fix_name = fix_mod_path .rsplit ("." , 1 )[- 1 ]
144- if fix_name .startswith ("fix_" ):
145- fix_name = fix_name [4 :]
146+ if fix_name .startswith (self . FILE_PREFIX ):
147+ fix_name = fix_name [len ( self . FILE_PREFIX ) :]
146148 parts = fix_name .split ("_" )
147- class_name = "Fix" + "" .join ([p .title () for p in parts ])
149+ class_name = self . CLASS_PREFIX + "" .join ([p .title () for p in parts ])
148150 try :
149151 fix_class = getattr (mod , class_name )
150152 except AttributeError :
151- self .log_error ("Can't find %s.%s" ,
152- fix_name , class_name )
153- continue
154- try :
155- fixer = fix_class (self .options , self .fixer_log )
156- except Exception as err :
157- self .log_error ("Can't instantiate fixes.fix_%s.%s()" ,
158- fix_name , class_name , exc_info = True )
159- continue
153+ raise FixerError ("Can't find %s.%s" % (fix_name , class_name ))
154+ fixer = fix_class (self .options , self .fixer_log )
160155 if fixer .explicit and self .explicit is not True and \
161156 fix_mod_path not in self .explicit :
162157 self .log_message ("Skipping implicit fixer: %s" , fix_name )
@@ -168,17 +163,16 @@ def get_fixers(self):
168163 elif fixer .order == "post" :
169164 post_order_fixers .append (fixer )
170165 else :
171- raise ValueError ("Illegal fixer order: %r" % fixer .order )
166+ raise FixerError ("Illegal fixer order: %r" % fixer .order )
172167
173168 key_func = operator .attrgetter ("run_order" )
174169 pre_order_fixers .sort (key = key_func )
175170 post_order_fixers .sort (key = key_func )
176171 return (pre_order_fixers , post_order_fixers )
177172
178173 def log_error (self , msg , * args , ** kwds ):
179- """Increments error count and log a message."""
180- self .errors .append ((msg , args , kwds ))
181- self .logger .error (msg , * args , ** kwds )
174+ """Called when an error occurs."""
175+ raise
182176
183177 def log_message (self , msg , * args ):
184178 """Hook to log a message."""
@@ -191,13 +185,17 @@ def log_debug(self, msg, *args):
191185 msg = msg % args
192186 self .logger .debug (msg )
193187
188+ def print_output (self , lines ):
189+ """Called with lines of output to give to the user."""
190+ pass
191+
194192 def refactor (self , items , write = False , doctests_only = False ):
195193 """Refactor a list of files and directories."""
196194 for dir_or_file in items :
197195 if os .path .isdir (dir_or_file ):
198- self .refactor_dir (dir_or_file , write )
196+ self .refactor_dir (dir_or_file , write , doctests_only )
199197 else :
200- self .refactor_file (dir_or_file , write )
198+ self .refactor_file (dir_or_file , write , doctests_only )
201199
202200 def refactor_dir (self , dir_name , write = False , doctests_only = False ):
203201 """Descends down a directory and refactor every Python file found.
@@ -348,12 +346,11 @@ def processed_file(self, new_text, filename, old_text=None, write=False):
348346 if old_text == new_text :
349347 self .log_debug ("No changes to %s" , filename )
350348 return
351- diff_texts (old_text , new_text , filename )
352- if not write :
353- self .log_debug ("Not writing changes to %s" , filename )
354- return
349+ self .print_output (diff_texts (old_text , new_text , filename ))
355350 if write :
356351 self .write_file (new_text , filename , old_text )
352+ else :
353+ self .log_debug ("Not writing changes to %s" , filename )
357354
358355 def write_file (self , new_text , filename , old_text = None ):
359356 """Writes a string to a file.
@@ -528,10 +525,9 @@ def gen_lines(self, block, indent):
528525
529526
530527def diff_texts (a , b , filename ):
531- """Prints a unified diff of two strings."""
528+ """Return a unified diff of two strings."""
532529 a = a .splitlines ()
533530 b = b .splitlines ()
534- for line in difflib .unified_diff (a , b , filename , filename ,
535- "(original)" , "(refactored)" ,
536- lineterm = "" ):
537- print (line )
531+ return difflib .unified_diff (a , b , filename , filename ,
532+ "(original)" , "(refactored)" ,
533+ lineterm = "" )
0 commit comments