diff --git a/Cargo.lock b/Cargo.lock index b83e681dc3..5fb335bff5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1896,6 +1896,7 @@ dependencies = [ "which", "widestring", "winapi", + "windows", "winreg", ] @@ -2725,17 +2726,30 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c4bd0a50ac6020f65184721f758dba47bb9fbc2133df715ec74a237b26794a" +dependencies = [ + "windows_aarch64_msvc 0.39.0", + "windows_i686_gnu 0.39.0", + "windows_i686_msvc 0.39.0", + "windows_x86_64_gnu 0.39.0", + "windows_x86_64_msvc 0.39.0", +] + [[package]] name = "windows-sys" version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5acdd78cb4ba54c0045ac14f62d8f94a03d10047904ae2a40afa1e99d8f70825" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.34.0", + "windows_i686_gnu 0.34.0", + "windows_i686_msvc 0.34.0", + "windows_x86_64_gnu 0.34.0", + "windows_x86_64_msvc 0.34.0", ] [[package]] @@ -2744,30 +2758,60 @@ version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17cffbe740121affb56fad0fc0e421804adf0ae00891205213b5cecd30db881d" +[[package]] +name = "windows_aarch64_msvc" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7711666096bd4096ffa835238905bb33fb87267910e154b18b44eaabb340f2" + [[package]] name = "windows_i686_gnu" version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2564fde759adb79129d9b4f54be42b32c89970c18ebf93124ca8870a498688ed" +[[package]] +name = "windows_i686_gnu" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763fc57100a5f7042e3057e7e8d9bdd7860d330070251a73d003563a3bb49e1b" + [[package]] name = "windows_i686_msvc" version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cd9d32ba70453522332c14d38814bceeb747d80b3958676007acadd7e166956" +[[package]] +name = "windows_i686_msvc" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bc7cbfe58828921e10a9f446fcaaf649204dcfe6c1ddd712c5eebae6bda1106" + [[package]] name = "windows_x86_64_gnu" version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfce6deae227ee8d356d19effc141a509cc503dfd1f850622ec4b0f84428e1f4" +[[package]] +name = "windows_x86_64_gnu" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6868c165637d653ae1e8dc4d82c25d4f97dd6605eaa8d784b5c6e0ab2a252b65" + [[package]] name = "windows_x86_64_msvc" version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d19538ccc21819d01deaf88d6a17eae6596a12e9aafdbb97916fb49896d89de9" +[[package]] +name = "windows_x86_64_msvc" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e4d40883ae9cae962787ca76ba76390ffa29214667a111db9e0a1ad8377e809" + [[package]] name = "winreg" version = "0.10.1" diff --git a/Lib/fnmatch.py b/Lib/fnmatch.py index af0dbcd092..fee59bf73f 100644 --- a/Lib/fnmatch.py +++ b/Lib/fnmatch.py @@ -9,16 +9,19 @@ The function translate(PATTERN) returns a regular expression corresponding to PATTERN. (It does not compile it.) """ -try: - import os -except ImportError: - import _dummy_os as os +import os import posixpath import re import functools __all__ = ["filter", "fnmatch", "fnmatchcase", "translate"] +# Build a thread-safe incrementing counter to help create unique regexp group +# names across calls. +from itertools import count +_nextgroupnum = count().__next__ +del count + def fnmatch(name, pat): """Test whether FILENAME matches PATTERN. @@ -49,7 +52,7 @@ def _compile_pattern(pat): return re.compile(res).match def filter(names, pat): - """Return the subset of the list NAMES that match PAT.""" + """Construct a list from those elements of the iterable NAMES that match PAT.""" result = [] pat = os.path.normcase(pat) match = _compile_pattern(pat) @@ -80,15 +83,19 @@ def translate(pat): There is no way to quote meta-characters. """ + STAR = object() + res = [] + add = res.append i, n = 0, len(pat) - res = '' while i < n: c = pat[i] i = i+1 if c == '*': - res = res + '.*' + # compress consecutive `*` into one + if (not res) or res[-1] is not STAR: + add(STAR) elif c == '?': - res = res + '.' + add('.') elif c == '[': j = i if j < n and pat[j] == '!': @@ -98,10 +105,10 @@ def translate(pat): while j < n and pat[j] != ']': j = j+1 if j >= n: - res = res + '\\[' + add('\\[') else: stuff = pat[i:j] - if '--' not in stuff: + if '-' not in stuff: stuff = stuff.replace('\\', r'\\') else: chunks = [] @@ -113,7 +120,16 @@ def translate(pat): chunks.append(pat[i:k]) i = k+1 k = k+3 - chunks.append(pat[i:j]) + chunk = pat[i:j] + if chunk: + chunks.append(chunk) + else: + chunks[-1] += '-' + # Remove empty ranges -- invalid in RE. + for k in range(len(chunks)-1, 0, -1): + if chunks[k-1][-1] > chunks[k][0]: + chunks[k-1] = chunks[k-1][:-1] + chunks[k][1:] + del chunks[k] # Escape backslashes and hyphens for set difference (--). # Hyphens that create ranges shouldn't be escaped. stuff = '-'.join(s.replace('\\', r'\\').replace('-', r'\-') @@ -121,11 +137,63 @@ def translate(pat): # Escape set operations (&&, ~~ and ||). stuff = re.sub(r'([&~|])', r'\\\1', stuff) i = j+1 - if stuff[0] == '!': - stuff = '^' + stuff[1:] - elif stuff[0] in ('^', '['): - stuff = '\\' + stuff - res = '%s[%s]' % (res, stuff) + if not stuff: + # Empty range: never match. + add('(?!)') + elif stuff == '!': + # Negated empty range: match any character. + add('.') + else: + if stuff[0] == '!': + stuff = '^' + stuff[1:] + elif stuff[0] in ('^', '['): + stuff = '\\' + stuff + add(f'[{stuff}]') + else: + add(re.escape(c)) + assert i == n + + # Deal with STARs. + inp = res + res = [] + add = res.append + i, n = 0, len(inp) + # Fixed pieces at the start? + while i < n and inp[i] is not STAR: + add(inp[i]) + i += 1 + # Now deal with STAR fixed STAR fixed ... + # For an interior `STAR fixed` pairing, we want to do a minimal + # .*? match followed by `fixed`, with no possibility of backtracking. + # We can't spell that directly, but can trick it into working by matching + # .*?fixed + # in a lookahead assertion, save the matched part in a group, then + # consume that group via a backreference. If the overall match fails, + # the lookahead assertion won't try alternatives. So the translation is: + # (?=(?P.*?fixed))(?P=name) + # Group names are created as needed: g0, g1, g2, ... + # The numbers are obtained from _nextgroupnum() to ensure they're unique + # across calls and across threads. This is because people rely on the + # undocumented ability to join multiple translate() results together via + # "|" to build large regexps matching "one of many" shell patterns. + while i < n: + assert inp[i] is STAR + i += 1 + if i == n: + add(".*") + break + assert inp[i] is not STAR + fixed = [] + while i < n and inp[i] is not STAR: + fixed.append(inp[i]) + i += 1 + fixed = "".join(fixed) + if i == n: + add(".*") + add(fixed) else: - res = res + re.escape(c) - return r'(?s:%s)\Z' % res + groupnum = _nextgroupnum() + add(f"(?=(?P.*?{fixed}))(?P=g{groupnum})") + assert i == n + res = "".join(res) + return fr'(?s:{res})\Z' diff --git a/Lib/ntpath.py b/Lib/ntpath.py index 6f771773a7..97edfa52aa 100644 --- a/Lib/ntpath.py +++ b/Lib/ntpath.py @@ -23,6 +23,7 @@ import genericpath from genericpath import * + __all__ = ["normcase","isabs","join","splitdrive","split","splitext", "basename","dirname","commonprefix","getsize","getmtime", "getatime","getctime", "islink","exists","lexists","isdir","isfile", @@ -41,14 +42,39 @@ def _get_bothseps(path): # Other normalizations (such as optimizing '../' away) are not done # (this is done by normpath). -def normcase(s): - """Normalize case of pathname. - - Makes all characters lowercase and all slashes into backslashes.""" - s = os.fspath(s) - if isinstance(s, bytes): - return s.replace(b'/', b'\\').lower() - else: +try: + from _winapi import ( + LCMapStringEx as _LCMapStringEx, + LOCALE_NAME_INVARIANT as _LOCALE_NAME_INVARIANT, + LCMAP_LOWERCASE as _LCMAP_LOWERCASE) + + def normcase(s): + """Normalize case of pathname. + + Makes all characters lowercase and all slashes into backslashes. + """ + s = os.fspath(s) + if not s: + return s + if isinstance(s, bytes): + encoding = sys.getfilesystemencoding() + s = s.decode(encoding, 'surrogateescape').replace('/', '\\') + s = _LCMapStringEx(_LOCALE_NAME_INVARIANT, + _LCMAP_LOWERCASE, s) + return s.encode(encoding, 'surrogateescape') + else: + return _LCMapStringEx(_LOCALE_NAME_INVARIANT, + _LCMAP_LOWERCASE, + s.replace('/', '\\')) +except ImportError: + def normcase(s): + """Normalize case of pathname. + + Makes all characters lowercase and all slashes into backslashes. + """ + s = os.fspath(s) + if isinstance(s, bytes): + return os.fsencode(os.fsdecode(s).replace('/', '\\').lower()) return s.replace('/', '\\').lower() @@ -312,12 +338,25 @@ def expanduser(path): drive = '' userhome = join(drive, os.environ['HOMEPATH']) + if i != 1: #~user + target_user = path[1:i] + if isinstance(target_user, bytes): + target_user = os.fsdecode(target_user) + current_user = os.environ.get('USERNAME') + + if target_user != current_user: + # Try to guess user home directory. By default all user + # profile directories are located in the same place and are + # named by corresponding usernames. If userhome isn't a + # normal profile directory, this guess is likely wrong, + # so we bail out. + if current_user != basename(userhome): + return path + userhome = join(dirname(userhome), target_user) + if isinstance(path, bytes): userhome = os.fsencode(userhome) - if i != 1: #~user - userhome = join(dirname(userhome), path[1:i]) - return userhome + path[i:] @@ -622,7 +661,7 @@ def _getfinalpathname_nonstrict(path): tail = join(name, tail) if tail else name return tail - def realpath(path): + def realpath(path, *, strict=False): path = normpath(path) if isinstance(path, bytes): prefix = b'\\\\?\\' @@ -647,6 +686,8 @@ def realpath(path): path = _getfinalpathname(path) initial_winerror = 0 except OSError as ex: + if strict: + raise initial_winerror = ex.winerror path = _getfinalpathname_nonstrict(path) # The path returned by _getfinalpathname will always start with \\?\ - diff --git a/Lib/test/test_fnmatch.py b/Lib/test/test_fnmatch.py index 55f9f0d3a5..04365341d8 100644 --- a/Lib/test/test_fnmatch.py +++ b/Lib/test/test_fnmatch.py @@ -2,6 +2,7 @@ import unittest import os +import string import warnings from fnmatch import fnmatch, fnmatchcase, translate, filter @@ -45,6 +46,13 @@ def test_fnmatch(self): check('\nfoo', 'foo*', False) check('\n', '*') + def test_slow_fnmatch(self): + check = self.check_match + check('a' * 50, '*a*a*a*a*a*a*a*a*a*a') + # The next "takes forever" if the regexp translation is + # straightforward. See bpo-40480. + check('a' * 50 + 'b', '*a*a*a*a*a*a*a*a*a*a', False) + def test_mix_bytes_str(self): self.assertRaises(TypeError, fnmatch, 'test', b'*') self.assertRaises(TypeError, fnmatch, b'test', '*') @@ -68,6 +76,11 @@ def test_bytes(self): self.check_match(b'test\xff', b'te*\xff') self.check_match(b'foo\nbar', b'foo*') + # TODO: RUSTPYTHON + import sys + if sys.platform == 'win32': + test_bytes = unittest.expectedFailure(test_bytes) + def test_case(self): ignorecase = os.path.normcase('ABC') == os.path.normcase('abc') check = self.check_match @@ -84,6 +97,119 @@ def test_sep(self): check('usr/bin', 'usr\\bin', normsep) check('usr\\bin', 'usr\\bin') + def test_char_set(self): + ignorecase = os.path.normcase('ABC') == os.path.normcase('abc') + check = self.check_match + tescases = string.ascii_lowercase + string.digits + string.punctuation + for c in tescases: + check(c, '[az]', c in 'az') + check(c, '[!az]', c not in 'az') + # Case insensitive. + for c in tescases: + check(c, '[AZ]', (c in 'az') and ignorecase) + check(c, '[!AZ]', (c not in 'az') or not ignorecase) + for c in string.ascii_uppercase: + check(c, '[az]', (c in 'AZ') and ignorecase) + check(c, '[!az]', (c not in 'AZ') or not ignorecase) + # Repeated same character. + for c in tescases: + check(c, '[aa]', c == 'a') + # Special cases. + for c in tescases: + check(c, '[^az]', c in '^az') + check(c, '[[az]', c in '[az') + check(c, r'[!]]', c != ']') + check('[', '[') + check('[]', '[]') + check('[!', '[!') + check('[!]', '[!]') + + def test_range(self): + ignorecase = os.path.normcase('ABC') == os.path.normcase('abc') + normsep = os.path.normcase('\\') == os.path.normcase('/') + check = self.check_match + tescases = string.ascii_lowercase + string.digits + string.punctuation + for c in tescases: + check(c, '[b-d]', c in 'bcd') + check(c, '[!b-d]', c not in 'bcd') + check(c, '[b-dx-z]', c in 'bcdxyz') + check(c, '[!b-dx-z]', c not in 'bcdxyz') + # Case insensitive. + for c in tescases: + check(c, '[B-D]', (c in 'bcd') and ignorecase) + check(c, '[!B-D]', (c not in 'bcd') or not ignorecase) + for c in string.ascii_uppercase: + check(c, '[b-d]', (c in 'BCD') and ignorecase) + check(c, '[!b-d]', (c not in 'BCD') or not ignorecase) + # Upper bound == lower bound. + for c in tescases: + check(c, '[b-b]', c == 'b') + # Special cases. + for c in tescases: + check(c, '[!-#]', c not in '-#') + check(c, '[!--.]', c not in '-.') + check(c, '[^-`]', c in '^_`') + if not (normsep and c == '/'): + check(c, '[[-^]', c in r'[\]^') + check(c, r'[\-^]', c in r'\]^') + check(c, '[b-]', c in '-b') + check(c, '[!b-]', c not in '-b') + check(c, '[-b]', c in '-b') + check(c, '[!-b]', c not in '-b') + check(c, '[-]', c in '-') + check(c, '[!-]', c not in '-') + # Upper bound is less that lower bound: error in RE. + for c in tescases: + check(c, '[d-b]', False) + check(c, '[!d-b]', True) + check(c, '[d-bx-z]', c in 'xyz') + check(c, '[!d-bx-z]', c not in 'xyz') + check(c, '[d-b^-`]', c in '^_`') + if not (normsep and c == '/'): + check(c, '[d-b[-^]', c in r'[\]^') + + def test_sep_in_char_set(self): + normsep = os.path.normcase('\\') == os.path.normcase('/') + check = self.check_match + check('/', r'[/]') + check('\\', r'[\]') + check('/', r'[\]', normsep) + check('\\', r'[/]', normsep) + check('[/]', r'[/]', False) + check(r'[\\]', r'[/]', False) + check('\\', r'[\t]') + check('/', r'[\t]', normsep) + check('t', r'[\t]') + check('\t', r'[\t]', False) + + def test_sep_in_range(self): + normsep = os.path.normcase('\\') == os.path.normcase('/') + check = self.check_match + check('a/b', 'a[.-0]b', not normsep) + check('a\\b', 'a[.-0]b', False) + check('a\\b', 'a[Z-^]b', not normsep) + check('a/b', 'a[Z-^]b', False) + + check('a/b', 'a[/-0]b', not normsep) + check(r'a\b', 'a[/-0]b', False) + check('a[/-0]b', 'a[/-0]b', False) + check(r'a[\-0]b', 'a[/-0]b', False) + + check('a/b', 'a[.-/]b') + check(r'a\b', 'a[.-/]b', normsep) + check('a[.-/]b', 'a[.-/]b', False) + check(r'a[.-\]b', 'a[.-/]b', False) + + check(r'a\b', r'a[\-^]b') + check('a/b', r'a[\-^]b', normsep) + check(r'a[\-^]b', r'a[\-^]b', False) + check('a[/-^]b', r'a[\-^]b', False) + + check(r'a\b', r'a[Z-\]b', not normsep) + check('a/b', r'a[Z-\]b', False) + check(r'a[Z-\]b', r'a[Z-\]b', False) + check('a[Z-/]b', r'a[Z-\]b', False) + def test_warnings(self): with warnings.catch_warnings(): warnings.simplefilter('error', Warning) @@ -99,6 +225,7 @@ def test_warnings(self): class TranslateTestCase(unittest.TestCase): def test_translate(self): + import re self.assertEqual(translate('*'), r'(?s:.*)\Z') self.assertEqual(translate('?'), r'(?s:.)\Z') self.assertEqual(translate('a?b*'), r'(?s:a.b.*)\Z') @@ -107,7 +234,34 @@ def test_translate(self): self.assertEqual(translate('[!x]'), r'(?s:[^x])\Z') self.assertEqual(translate('[^x]'), r'(?s:[\^x])\Z') self.assertEqual(translate('[x'), r'(?s:\[x)\Z') - + # from the docs + self.assertEqual(translate('*.txt'), r'(?s:.*\.txt)\Z') + # squash consecutive stars + self.assertEqual(translate('*********'), r'(?s:.*)\Z') + self.assertEqual(translate('A*********'), r'(?s:A.*)\Z') + self.assertEqual(translate('*********A'), r'(?s:.*A)\Z') + self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\Z') + # fancy translation to prevent exponential-time match failure + t = translate('**a*a****a') + digits = re.findall(r'\d+', t) + self.assertEqual(len(digits), 4) + self.assertEqual(digits[0], digits[1]) + self.assertEqual(digits[2], digits[3]) + g1 = f"g{digits[0]}" # e.g., group name "g4" + g2 = f"g{digits[2]}" # e.g., group name "g5" + self.assertEqual(t, + fr'(?s:(?=(?P<{g1}>.*?a))(?P={g1})(?=(?P<{g2}>.*?a))(?P={g2}).*a)\Z') + # and try pasting multiple translate results - it's an undocumented + # feature that this works; all the pain of generating unique group + # names across calls exists to support this + r1 = translate('**a**a**a*') + r2 = translate('**b**b**b*') + r3 = translate('*c*c*c*') + fatre = "|".join([r1, r2, r3]) + self.assertTrue(re.match(fatre, 'abaccad')) + self.assertTrue(re.match(fatre, 'abxbcab')) + self.assertTrue(re.match(fatre, 'cbabcaxc')) + self.assertFalse(re.match(fatre, 'dabccbad')) class FilterTestCase(unittest.TestCase): diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py index 474a8e8b17..19ffdf5702 100644 --- a/Lib/test/test_ntpath.py +++ b/Lib/test/test_ntpath.py @@ -280,10 +280,6 @@ def test_realpath_strict(self): self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN, strict=True) self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN + "2", strict=True) - # TODO: RUSTPYTHON, TypeError: got an unexpected keyword argument 'strict' - if sys.platform == "win32": - test_realpath_strict = unittest.expectedFailure(test_realpath_strict) - @os_helper.skip_unless_symlink @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname') def test_realpath_relative(self): @@ -354,7 +350,6 @@ def test_realpath_broken_symlinks(self): self.assertPathEqual(ntpath.realpath(b"broken5"), os.fsencode(ABSTFN + r"\missing")) - @unittest.skip("TODO: RUSTPYTHON, leaves behind TESTFN") @os_helper.skip_unless_symlink @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname') def test_realpath_symlink_loops(self): @@ -446,10 +441,6 @@ def test_realpath_symlink_loops_strict(self): self.assertRaises(OSError, ntpath.realpath, ntpath.basename(ABSTFN), strict=True) - # TODO: RUSTPYTHON, FileExistsError: [Errno 183] Cannot create a file when that file already exists. (os error 183): 'None' -> 'None' - if sys.platform == "win32": - test_realpath_symlink_loops_strict = unittest.expectedFailure(test_realpath_symlink_loops_strict) - @os_helper.skip_unless_symlink @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname') def test_realpath_symlink_prefix(self): @@ -517,7 +508,6 @@ def test_realpath_cwd(self): with os_helper.change_cwd(test_dir_short): self.assertPathEqual(test_file_long, ntpath.realpath("file.txt")) - @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, ValueError: illegal environment variable name") def test_expandvars(self): with os_helper.EnvironmentVarGuard() as env: env.clear() @@ -544,7 +534,10 @@ def test_expandvars(self): tester('ntpath.expandvars("\'%foo%\'%bar")', "\'%foo%\'%bar") tester('ntpath.expandvars("bar\'%foo%")', "bar\'%foo%") - @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, ValueError: illegal environment variable name") + # TODO: RUSTPYTHON, ValueError: illegal environment variable name + if sys.platform == 'win32': + test_expandvars = unittest.expectedFailure(test_expandvars) + @unittest.skipUnless(os_helper.FS_NONASCII, 'need os_helper.FS_NONASCII') def test_expandvars_nonascii(self): def check(value, expected): @@ -565,8 +558,10 @@ def check(value, expected): check('%spam%bar', '%sbar' % nonascii) check('%{}%bar'.format(nonascii), 'ham%sbar' % nonascii) - # TODO: RUSTPYTHON - @unittest.expectedFailure + # TODO: RUSTPYTHON, ValueError: illegal environment variable name + if sys.platform == 'win32': + test_expandvars_nonascii = unittest.expectedFailure(test_expandvars_nonascii) + def test_expanduser(self): tester('ntpath.expanduser("test")', 'test') @@ -614,7 +609,9 @@ def test_expanduser(self): tester('ntpath.expanduser("~test")', '~test') tester('ntpath.expanduser("~")', 'C:\\Users\\eric') - + # TODO: RUSTPYTHON + if sys.platform == 'win32': + test_expanduser = unittest.expectedFailure(test_expanduser) @unittest.skipUnless(nt, "abspath requires 'nt' module") def test_abspath(self): @@ -809,14 +806,6 @@ class NtCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = ntpath attributes = ['relpath'] - @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, ValueError: illegal environment variable name") - def test_expandvars(self): # TODO: RUSTPYTHON, remove when this passes - super().test_expandvars() # TODO: RUSTPYTHON, remove when this passes - - @unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, ValueError: illegal environment variable name") - def test_expandvars_nonascii(self): # TODO: RUSTPYTHON, remove when this passes - super().test_expandvars_nonascii() # TODO: RUSTPYTHON, remove when this passes - # TODO: RUSTPYTHON if sys.platform == "linux": @unittest.expectedFailure @@ -825,26 +814,34 @@ def test_nonascii_abspath(self): # TODO: RUSTPYTHON if sys.platform == "win32": + # TODO: RUSTPYTHON, ValueError: illegal environment variable name + @unittest.expectedFailure + def test_expandvars(self): # TODO: RUSTPYTHON; remove when done + super().test_expandvars() + + # TODO: RUSTPYTHON, ValueError: illegal environment variable name @unittest.expectedFailure - def test_samefile(self): + def test_expandvars_nonascii(self): # TODO: RUSTPYTHON; remove when done + super().test_expandvars_nonascii() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_samefile(self): # TODO: RUSTPYTHON; remove when done super().test_samefile() - # TODO: RUSTPYTHON - if sys.platform == "win32": + # TODO: RUSTPYTHON @unittest.expectedFailure - def test_samefile_on_link(self): + def test_samefile_on_link(self): # TODO: RUSTPYTHON; remove when done super().test_samefile_on_link() - # TODO: RUSTPYTHON - if sys.platform == "win32": + # TODO: RUSTPYTHON @unittest.expectedFailure - def test_samestat(self): + def test_samestat(self): # TODO: RUSTPYTHON; remove when done super().test_samestat() - # TODO: RUSTPYTHON - if sys.platform == "win32": + # TODO: RUSTPYTHON @unittest.expectedFailure - def test_samestat_on_link(self): + def test_samestat_on_link(self): # TODO: RUSTPYTHON; remove when done super().test_samestat_on_link() diff --git a/vm/Cargo.toml b/vm/Cargo.toml index eba1c24628..f0d6d53676 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -112,6 +112,12 @@ winreg = "0.10.1" schannel = "0.1.19" widestring = "0.5.1" +[target.'cfg(windows)'.dependencies.windows] +version = "0.39" +features = [ + "Win32_UI_Shell", +] + [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 178f3e8a40..bce24f318d 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -21,7 +21,10 @@ pub(crate) mod module { }, PyResult, TryFromObject, VirtualMachine, }; - use std::{env, fs, io}; + use std::{ + env, fs, io, + os::windows::ffi::{OsStrExt, OsStringExt}, + }; use crate::builtins::PyDictRef; #[cfg(target_env = "msvc")] @@ -292,6 +295,44 @@ pub(crate) mod module { path.mode.process_path(buffer.to_os_string(), vm) } + #[pyfunction] + fn _path_splitroot(path: PyPathLike, vm: &VirtualMachine) -> PyResult<(String, String)> { + let orig: Vec<_> = path.path.into_os_string().encode_wide().collect(); + if orig.is_empty() { + return Ok(("".to_owned(), "".to_owned())); + } + let backslashed: Vec<_> = orig + .iter() + .copied() + .map(|c| if c == b'/' as u16 { b'\\' as u16 } else { c }) + .chain(std::iter::once(0)) // null-terminated + .collect(); + + fn from_utf16(wstr: &[u16], vm: &VirtualMachine) -> PyResult { + String::from_utf16(wstr).map_err(|e| vm.new_unicode_decode_error(e.to_string())) + } + + let wbuf = windows::core::PCWSTR::from_raw(backslashed.as_ptr()); + let (root, path) = match unsafe { windows::Win32::UI::Shell::PathCchSkipRoot(wbuf) } { + Ok(end) => { + assert!(!end.is_null()); + let len: usize = unsafe { end.as_ptr().offset_from(wbuf.as_ptr()) } + .try_into() + .expect("len must be non-negative"); + assert!( + len < backslashed.len(), // backslashed is null-terminated + "path: {:?} {} < {}", + std::path::PathBuf::from(std::ffi::OsString::from_wide(&backslashed)), + len, + backslashed.len() + ); + (from_utf16(&orig[..len], vm)?, from_utf16(&orig[len..], vm)?) + } + Err(_) => ("".to_owned(), from_utf16(&orig, vm)?), + }; + Ok((root, path)) + } + #[pyfunction] fn _getdiskusage(path: PyPathLike, vm: &VirtualMachine) -> PyResult<(u64, u64)> { use um::fileapi::GetDiskFreeSpaceExW;