|
13 | 13 | # Local imports |
14 | 14 | from .. import fixer_base |
15 | 15 | from os.path import dirname, join, exists, pathsep |
16 | | -from ..fixer_util import FromImport, syms |
| 16 | +from ..fixer_util import FromImport, syms, token |
| 17 | + |
| 18 | + |
| 19 | +def traverse_imports(names): |
| 20 | + """ |
| 21 | + Walks over all the names imported in a dotted_as_names node. |
| 22 | + """ |
| 23 | + pending = [names] |
| 24 | + while pending: |
| 25 | + node = pending.pop() |
| 26 | + if node.type == token.NAME: |
| 27 | + yield node.value |
| 28 | + elif node.type == syms.dotted_name: |
| 29 | + yield "".join([ch.value for ch in node.children]) |
| 30 | + elif node.type == syms.dotted_as_name: |
| 31 | + pending.append(node.children[0]) |
| 32 | + elif node.type == syms.dotted_as_names: |
| 33 | + pending.extend(node.children[::-2]) |
| 34 | + else: |
| 35 | + raise AssertionError("unkown node type") |
| 36 | + |
17 | 37 |
|
18 | 38 | class FixImport(fixer_base.BaseFix): |
19 | 39 |
|
20 | 40 | PATTERN = """ |
21 | | - import_from< type='from' imp=any 'import' ['('] any [')'] > |
| 41 | + import_from< 'from' imp=any 'import' ['('] any [')'] > |
22 | 42 | | |
23 | | - import_name< type='import' imp=any > |
| 43 | + import_name< 'import' imp=any > |
24 | 44 | """ |
25 | 45 |
|
26 | 46 | def transform(self, node, results): |
27 | 47 | imp = results['imp'] |
28 | 48 |
|
29 | | - mod_name = str(imp.children[0] if imp.type == syms.dotted_as_name \ |
30 | | - else imp) |
31 | | - |
32 | | - if str(imp).startswith('.'): |
33 | | - # Already a new-style import |
34 | | - return |
35 | | - |
36 | | - if not probably_a_local_import(str(mod_name), self.filename): |
37 | | - # I guess this is a global import -- skip it! |
38 | | - return |
39 | | - |
40 | | - if results['type'].value == 'from': |
| 49 | + if node.type == syms.import_from: |
41 | 50 | # Some imps are top-level (eg: 'import ham') |
42 | 51 | # some are first level (eg: 'import ham.eggs') |
43 | 52 | # some are third level (eg: 'import ham.eggs as spam') |
44 | 53 | # Hence, the loop |
45 | 54 | while not hasattr(imp, 'value'): |
46 | 55 | imp = imp.children[0] |
47 | | - imp.value = "." + imp.value |
48 | | - node.changed() |
| 56 | + if self.probably_a_local_import(imp.value): |
| 57 | + imp.value = "." + imp.value |
| 58 | + imp.changed() |
| 59 | + return node |
49 | 60 | else: |
50 | | - new = FromImport('.', getattr(imp, 'content', None) or [imp]) |
| 61 | + have_local = False |
| 62 | + have_absolute = False |
| 63 | + for mod_name in traverse_imports(imp): |
| 64 | + if self.probably_a_local_import(mod_name): |
| 65 | + have_local = True |
| 66 | + else: |
| 67 | + have_absolute = True |
| 68 | + if have_absolute: |
| 69 | + if have_local: |
| 70 | + # We won't handle both sibling and absolute imports in the |
| 71 | + # same statement at the moment. |
| 72 | + self.warning(node, "absolute and local imports together") |
| 73 | + return |
| 74 | + |
| 75 | + new = FromImport('.', [imp]) |
51 | 76 | new.set_prefix(node.get_prefix()) |
52 | | - node = new |
53 | | - return node |
| 77 | + return new |
54 | 78 |
|
55 | | -def probably_a_local_import(imp_name, file_path): |
56 | | - # Must be stripped because the right space is included by the parser |
57 | | - imp_name = imp_name.split('.', 1)[0].strip() |
58 | | - base_path = dirname(file_path) |
59 | | - base_path = join(base_path, imp_name) |
60 | | - # If there is no __init__.py next to the file its not in a package |
61 | | - # so can't be a relative import. |
62 | | - if not exists(join(dirname(base_path), '__init__.py')): |
| 79 | + def probably_a_local_import(self, imp_name): |
| 80 | + imp_name = imp_name.split('.', 1)[0] |
| 81 | + base_path = dirname(self.filename) |
| 82 | + base_path = join(base_path, imp_name) |
| 83 | + # If there is no __init__.py next to the file its not in a package |
| 84 | + # so can't be a relative import. |
| 85 | + if not exists(join(dirname(base_path), '__init__.py')): |
| 86 | + return False |
| 87 | + for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: |
| 88 | + if exists(base_path + ext): |
| 89 | + return True |
63 | 90 | return False |
64 | | - for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: |
65 | | - if exists(base_path + ext): |
66 | | - return True |
67 | | - return False |
|
0 commit comments