Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f4db28b

Browse files
committed
multi-import case
1 parent 6cf9f38 commit f4db28b

2 files changed

Lines changed: 103 additions & 26 deletions

File tree

IPython/extensions/autoreload.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,18 @@ def __call__(self):
472472
class ImportFromTracker:
473473
def __init__(self, imports_froms, symbol_map):
474474
self.imports_froms = imports_froms
475-
self.symbol_map = symbol_map
475+
# symbol_map maps original_name -> list of resolved_names
476+
self.symbol_map = {}
477+
if symbol_map:
478+
for module_name, mappings in symbol_map.items():
479+
self.symbol_map[module_name] = {}
480+
for original_name, resolved_names in mappings.items():
481+
if isinstance(resolved_names, list):
482+
self.symbol_map[module_name][original_name] = resolved_names[:]
483+
else:
484+
self.symbol_map[module_name][original_name] = [resolved_names]
485+
else:
486+
self.symbol_map = symbol_map or {}
476487

477488
def add_import(self, module_name, original_name, resolved_name):
478489
"""Add an import, handling conflicts with existing imports.
@@ -484,25 +495,29 @@ def add_import(self, module_name, original_name, resolved_name):
484495
if module_name not in self.symbol_map:
485496
self.symbol_map[module_name] = {}
486497

487-
# Check if there's already a different mapping for the same resolved_name
488-
existing_mapping = None
489-
for orig_name, res_name in self.symbol_map[module_name].items():
490-
if res_name == resolved_name and orig_name != original_name:
491-
existing_mapping = orig_name
492-
break
493-
494-
# If there's a conflict, the newer import takes precedence since it just executed successfully
495-
if existing_mapping is not None:
496-
# Remove the old mapping
497-
if existing_mapping in self.imports_froms[module_name]:
498-
self.imports_froms[module_name].remove(existing_mapping)
499-
if existing_mapping in self.symbol_map[module_name]:
500-
del self.symbol_map[module_name][existing_mapping]
498+
# Check if there's already a different mapping for the same resolved_name from a different original_name
499+
# We need to remove any conflicting mappings
500+
for orig_name, res_names in list(self.symbol_map[module_name].items()):
501+
if resolved_name in res_names and orig_name != original_name:
502+
# Remove the conflicting resolved_name from the other original_name's list
503+
res_names.remove(resolved_name)
504+
if (
505+
not res_names
506+
): # If the list is now empty, remove the original_name entirely
507+
if orig_name in self.imports_froms[module_name]:
508+
self.imports_froms[module_name].remove(orig_name)
509+
del self.symbol_map[module_name][orig_name]
501510

502511
# Add the new mapping
503512
if original_name not in self.imports_froms[module_name]:
504513
self.imports_froms[module_name].append(original_name)
505-
self.symbol_map[module_name][original_name] = resolved_name
514+
515+
if original_name not in self.symbol_map[module_name]:
516+
self.symbol_map[module_name][original_name] = []
517+
518+
# Add the resolved_name if it's not already in the list
519+
if resolved_name not in self.symbol_map[module_name][original_name]:
520+
self.symbol_map[module_name][original_name].append(resolved_name)
506521

507522

508523
def append_obj(module, d, name, obj, autoload=False):
@@ -586,9 +601,14 @@ def superreload(
586601
)
587602
):
588603
continue
604+
605+
# Handle symbol mapping - now supporting multiple resolved names per original name
589606
if symbol_map and name in symbol_map.get(module.__name__, {}):
590-
name = symbol_map.get(module.__name__, {})[name]
591-
shell.user_ns[name] = new_obj
607+
resolved_names = symbol_map.get(module.__name__, {})[name]
608+
for resolved_name in resolved_names:
609+
shell.user_ns[resolved_name] = new_obj
610+
else:
611+
shell.user_ns[name] = new_obj
592612

593613
new_refs = []
594614
for old_ref in old_objects[key]:

tests/test_zzz_autoreload.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -922,15 +922,15 @@ def test_import_from_tracker_conflict_resolution(self):
922922
# Verify initial state
923923
assert mod_name in tracker.imports_froms
924924
assert "foo" in tracker.imports_froms[mod_name]
925-
assert tracker.symbol_map[mod_name]["foo"] == "bar"
925+
assert tracker.symbol_map[mod_name]["foo"] == ["bar"]
926926

927927
# Second import: from mod_name import bar (conflicts with previous "bar")
928928
tracker.add_import(mod_name, "bar", "bar")
929929

930930
# The second import should take precedence since "bar" is a valid import
931931
assert "bar" in tracker.imports_froms[mod_name]
932932
assert "foo" not in tracker.imports_froms[mod_name] # Should be removed
933-
assert tracker.symbol_map[mod_name]["bar"] == "bar"
933+
assert tracker.symbol_map[mod_name]["bar"] == ["bar"]
934934
assert "foo" not in tracker.symbol_map[mod_name] # Should be removed
935935
finally:
936936
# Clean up sys.modules
@@ -967,15 +967,15 @@ def test_import_from_tracker_reverse_conflict(self):
967967

968968
# Verify initial state
969969
assert "bar" in tracker.imports_froms[mod_name]
970-
assert tracker.symbol_map[mod_name]["bar"] == "bar"
970+
assert tracker.symbol_map[mod_name]["bar"] == ["bar"]
971971

972972
# Second import: from mod_name import foo as bar (conflicts with previous "bar")
973973
tracker.add_import(mod_name, "foo", "bar")
974974

975975
# The second import should take precedence since "foo" is a valid import
976976
assert "foo" in tracker.imports_froms[mod_name]
977977
assert "bar" not in tracker.imports_froms[mod_name] # Should be removed
978-
assert tracker.symbol_map[mod_name]["foo"] == "bar"
978+
assert tracker.symbol_map[mod_name]["foo"] == ["bar"]
979979
assert "bar" not in tracker.symbol_map[mod_name] # Should be removed
980980
finally:
981981
# Clean up sys.modules
@@ -1004,7 +1004,7 @@ def test_import_from_tracker_invalid_import(self):
10041004

10051005
# Verify initial state
10061006
assert "foo" in tracker.imports_froms[mod_name]
1007-
assert tracker.symbol_map[mod_name]["foo"] == "bar"
1007+
assert tracker.symbol_map[mod_name]["foo"] == ["bar"]
10081008

10091009
# Second import: from mod_name import foo2 as bar (conflicting import)
10101010
# In the new approach, this would only be called if the import actually succeeded
@@ -1014,7 +1014,7 @@ def test_import_from_tracker_invalid_import(self):
10141014
# The new mapping should replace the old one since it's more recent
10151015
assert "foo2" in tracker.imports_froms[mod_name]
10161016
assert "foo" not in tracker.imports_froms[mod_name] # Should be replaced
1017-
assert tracker.symbol_map[mod_name]["foo2"] == "bar"
1017+
assert tracker.symbol_map[mod_name]["foo2"] == ["bar"]
10181018
assert "foo" not in tracker.symbol_map[mod_name] # Should be replaced
10191019

10201020
def test_import_from_tracker_integration(self):
@@ -1059,6 +1059,40 @@ def test_import_from_tracker_integration(self):
10591059
# The 'bar' variable should now contain the modified 'bar', not 'foo'
10601060
assert self.shell.user_ns["bar"] == "modified_bar"
10611061

1062+
def test_autoreload3_double_import(self):
1063+
"""Test the integration of ImportFromTracker with autoreload"""
1064+
# Create a test module
1065+
mod_name, mod_fn = self.new_module(
1066+
textwrap.dedent(
1067+
"""
1068+
foo = "original_foo"
1069+
bar = "original_bar"
1070+
"""
1071+
)
1072+
)
1073+
# Enable autoreload mode 3 (complete)
1074+
self.shell.magic_autoreload("3")
1075+
1076+
# First import: from mod_name import foo as bar
1077+
# This will naturally load the module into sys.modules
1078+
self.shell.run_code(f"from {mod_name} import foo as bar")
1079+
self.shell.run_code(f"from {mod_name} import foo")
1080+
assert self.shell.user_ns["bar"] == "original_foo"
1081+
assert self.shell.user_ns["foo"] == "original_foo"
1082+
# Modify the module
1083+
self.write_file(
1084+
mod_fn,
1085+
textwrap.dedent(
1086+
"""
1087+
foo = "modified_foo"
1088+
bar = "modified_bar"
1089+
"""
1090+
),
1091+
)
1092+
self.shell.run_code("pass")
1093+
assert self.shell.user_ns["bar"] == "modified_foo"
1094+
assert self.shell.user_ns["foo"] == "modified_foo"
1095+
10621096
def test_import_from_tracker_unloaded_module(self):
10631097
"""Test that ImportFromTracker works with the post-execution approach"""
10641098
from IPython.extensions.autoreload import ImportFromTracker
@@ -1080,7 +1114,7 @@ def test_import_from_tracker_unloaded_module(self):
10801114
assert fake_mod_name in tracker.imports_froms
10811115
assert fake_mod_name in tracker.symbol_map
10821116
assert "some_attr" in tracker.imports_froms[fake_mod_name]
1083-
assert tracker.symbol_map[fake_mod_name]["some_attr"] == "some_name"
1117+
assert tracker.symbol_map[fake_mod_name]["some_attr"] == ["some_name"]
10841118

10851119
# Simulate a conflict scenario - another import succeeded
10861120
tracker.add_import(fake_mod_name, "another_attr", "some_name")
@@ -1092,7 +1126,30 @@ def test_import_from_tracker_unloaded_module(self):
10921126
assert (
10931127
"some_attr" not in tracker.imports_froms[fake_mod_name]
10941128
) # Should be replaced
1095-
assert tracker.symbol_map[fake_mod_name]["another_attr"] == "some_name"
1129+
assert tracker.symbol_map[fake_mod_name]["another_attr"] == ["some_name"]
10961130
assert (
10971131
"some_attr" not in tracker.symbol_map[fake_mod_name]
10981132
) # Should be replaced
1133+
1134+
def test_import_from_tracker_multiple_resolved_names(self):
1135+
"""Test that the same original name can map to multiple resolved names"""
1136+
from IPython.extensions.autoreload import ImportFromTracker
1137+
1138+
# Create a tracker
1139+
tracker = ImportFromTracker({}, {})
1140+
1141+
fake_mod_name = "test_module_abc"
1142+
1143+
# Simulate: from test_module_abc import foo as bar
1144+
tracker.add_import(fake_mod_name, "foo", "bar")
1145+
1146+
# Verify initial state
1147+
assert "foo" in tracker.imports_froms[fake_mod_name]
1148+
assert tracker.symbol_map[fake_mod_name]["foo"] == ["bar"]
1149+
1150+
# Simulate: from test_module_abc import foo (same original name, different resolved name)
1151+
tracker.add_import(fake_mod_name, "foo", "foo")
1152+
1153+
# Both resolved names should be tracked for the same original name
1154+
assert "foo" in tracker.imports_froms[fake_mod_name]
1155+
assert set(tracker.symbol_map[fake_mod_name]["foo"]) == {"bar", "foo"}

0 commit comments

Comments
 (0)