diff --git a/python/bin/repl_stub.py b/python/bin/repl_stub.py index 86452aa869..1e21b26dc3 100644 --- a/python/bin/repl_stub.py +++ b/python/bin/repl_stub.py @@ -13,6 +13,9 @@ The logic for PYTHONSTARTUP is handled in python/private/repl_template.py. """ +# Capture the globals from PYTHONSTARTUP so we can pass them on to the console. +console_locals = globals().copy() + import code import sys @@ -26,4 +29,4 @@ sys.ps2 = "" # We set the banner to an empty string because the repl_template.py file already prints the banner. -code.interact(banner="", exitmsg=exitmsg) +code.interact(local=console_locals, banner="", exitmsg=exitmsg) diff --git a/python/private/repl_template.py b/python/private/repl_template.py index 0e058b23ae..37f4529fbe 100644 --- a/python/private/repl_template.py +++ b/python/private/repl_template.py @@ -14,6 +14,10 @@ def start_repl(): cprt = 'Type "help", "copyright", "credits" or "license" for more information.' sys.stderr.write("Python %s on %s\n%s\n" % (sys.version, sys.platform, cprt)) + # If there's a PYTHONSTARTUP script, we need to capture the new variables + # that it defines. + new_globals = {} + # Simulate Python's behavior when a valid startup script is defined by the # PYTHONSTARTUP variable. If this file path fails to load, print the error # and revert to the default behavior. @@ -27,10 +31,14 @@ def start_repl(): print(f"{type(error).__name__}: {error}") else: compiled_code = compile(source_code, filename=startup_file, mode="exec") - eval(compiled_code, {}) + eval(compiled_code, new_globals) bazel_runfiles = runfiles.Create() - runpy.run_path(bazel_runfiles.Rlocation(STUB_PATH), run_name="__main__") + runpy.run_path( + bazel_runfiles.Rlocation(STUB_PATH), + init_globals=new_globals, + run_name="__main__", + ) if __name__ == "__main__": diff --git a/tests/repl/repl_test.py b/tests/repl/repl_test.py index 51ca951110..37c9a37a0d 100644 --- a/tests/repl/repl_test.py +++ b/tests/repl/repl_test.py @@ -1,7 +1,9 @@ import os import subprocess import sys +import tempfile import unittest +from pathlib import Path from typing import Iterable from python import runfiles @@ -13,18 +15,26 @@ EXPECT_TEST_MODULE_IMPORTABLE = os.environ["EXPECT_TEST_MODULE_IMPORTABLE"] == "1" +# An arbitrary piece of code that sets some kind of variable. The variable needs to persist into the +# actual shell. +PYTHONSTARTUP_SETS_VAR = """\ +foo = 1234 +""" + + class ReplTest(unittest.TestCase): def setUp(self): self.repl = rfiles.Rlocation("rules_python/python/bin/repl") assert self.repl - def run_code_in_repl(self, lines: Iterable[str]) -> str: + def run_code_in_repl(self, lines: Iterable[str], *, env=None) -> str: """Runs the lines of code in the REPL and returns the text output.""" return subprocess.check_output( [self.repl], text=True, stderr=subprocess.STDOUT, input="\n".join(lines), + env=env, ).strip() def test_repl_version(self): @@ -69,6 +79,44 @@ def test_import_test_module_failure(self): ) self.assertIn("ModuleNotFoundError: No module named 'test_module'", result) + def test_pythonstartup_gets_executed(self): + """Validates that we can use the variables from PYTHONSTARTUP in the console itself.""" + with tempfile.TemporaryDirectory() as tempdir: + pythonstartup = Path(tempdir) / "pythonstartup.py" + pythonstartup.write_text(PYTHONSTARTUP_SETS_VAR) + + env = os.environ.copy() + env["PYTHONSTARTUP"] = str(pythonstartup) + + result = self.run_code_in_repl( + [ + "print(f'The value of foo is {foo}')", + ], + env=env, + ) + + self.assertIn("The value of foo is 1234", result) + + def test_pythonstartup_doesnt_leak(self): + """Validates that we don't accidentally leak code into the console. + + This test validates that a few of the variables we use in the template and stub are not + accessible in the REPL itself. + """ + with tempfile.TemporaryDirectory() as tempdir: + pythonstartup = Path(tempdir) / "pythonstartup.py" + pythonstartup.write_text(PYTHONSTARTUP_SETS_VAR) + + env = os.environ.copy() + env["PYTHONSTARTUP"] = str(pythonstartup) + + for var_name in ("exitmsg", "sys", "code", "bazel_runfiles", "STUB_PATH"): + with self.subTest(var_name=var_name): + result = self.run_code_in_repl([f"print({var_name})"], env=env) + self.assertIn( + f"NameError: name '{var_name}' is not defined", result + ) + if __name__ == "__main__": unittest.main()