|
13 | 13 | import warnings |
14 | 14 | import unittest |
15 | 15 | import importlib |
| 16 | +import collections |
16 | 17 |
|
17 | 18 | __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module", |
18 | 19 | "verbose", "use_resources", "max_memuse", "record_original_stdout", |
@@ -510,37 +511,57 @@ def __exit__(self, *ignore_exc): |
510 | 511 | sys.modules.update(self.original_modules) |
511 | 512 |
|
512 | 513 |
|
513 | | -class EnvironmentVarGuard(object): |
| 514 | +class EnvironmentVarGuard(collections.MutableMapping): |
514 | 515 |
|
515 | 516 | """Class to help protect the environment variable properly. Can be used as |
516 | 517 | a context manager.""" |
517 | 518 |
|
518 | 519 | def __init__(self): |
| 520 | + self._environ = os.environ |
519 | 521 | self._changed = {} |
520 | 522 |
|
521 | | - def set(self, envvar, value): |
| 523 | + def __getitem__(self, envvar): |
| 524 | + return self._environ[envvar] |
| 525 | + |
| 526 | + def __setitem__(self, envvar, value): |
522 | 527 | # Remember the initial value on the first access |
523 | 528 | if envvar not in self._changed: |
524 | | - self._changed[envvar] = os.environ.get(envvar) |
525 | | - os.environ[envvar] = value |
| 529 | + self._changed[envvar] = self._environ.get(envvar) |
| 530 | + self._environ[envvar] = value |
526 | 531 |
|
527 | | - def unset(self, envvar): |
| 532 | + def __delitem__(self, envvar): |
528 | 533 | # Remember the initial value on the first access |
529 | 534 | if envvar not in self._changed: |
530 | | - self._changed[envvar] = os.environ.get(envvar) |
531 | | - if envvar in os.environ: |
532 | | - del os.environ[envvar] |
| 535 | + self._changed[envvar] = self._environ.get(envvar) |
| 536 | + if envvar in self._environ: |
| 537 | + del self._environ[envvar] |
| 538 | + |
| 539 | + def keys(self): |
| 540 | + return self._environ.keys() |
| 541 | + |
| 542 | + def __iter__(self): |
| 543 | + return iter(self._environ) |
| 544 | + |
| 545 | + def __len__(self): |
| 546 | + return len(self._environ) |
| 547 | + |
| 548 | + def set(self, envvar, value): |
| 549 | + self[envvar] = value |
| 550 | + |
| 551 | + def unset(self, envvar): |
| 552 | + del self[envvar] |
533 | 553 |
|
534 | 554 | def __enter__(self): |
535 | 555 | return self |
536 | 556 |
|
537 | 557 | def __exit__(self, *ignore_exc): |
538 | 558 | for (k, v) in self._changed.items(): |
539 | 559 | if v is None: |
540 | | - if k in os.environ: |
541 | | - del os.environ[k] |
| 560 | + if k in self._environ: |
| 561 | + del self._environ[k] |
542 | 562 | else: |
543 | | - os.environ[k] = v |
| 563 | + self._environ[k] = v |
| 564 | + |
544 | 565 |
|
545 | 566 | class TransientResource(object): |
546 | 567 |
|
|
0 commit comments