diff --git a/Lib/cgitb.py b/Lib/cgitb.py index 8ce0e833a9..f6b97f25c5 100644 --- a/Lib/cgitb.py +++ b/Lib/cgitb.py @@ -74,7 +74,7 @@ def lookup(name, frame, locals): return 'global', frame.f_globals[name] if '__builtins__' in frame.f_globals: builtins = frame.f_globals['__builtins__'] - if type(builtins) is type({}): + if isinstance(builtins, dict): if name in builtins: return 'builtin', builtins[name] else: diff --git a/Lib/colorsys.py b/Lib/colorsys.py index 9bdc83e377..bc897bd0f9 100644 --- a/Lib/colorsys.py +++ b/Lib/colorsys.py @@ -83,7 +83,7 @@ def rgb_to_hls(r, g, b): if l <= 0.5: s = rangec / sumc else: - s = rangec / (2.0-sumc) + s = rangec / (2.0-maxc-minc) # Not always 2.0-sumc: gh-106498. rc = (maxc-r) / rangec gc = (maxc-g) / rangec bc = (maxc-b) / rangec diff --git a/Lib/csv.py b/Lib/csv.py index 2f38bb1a19..77f30c8d2b 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -4,17 +4,22 @@ """ import re -from _csv import Error, writer, reader, \ +import types +from _csv import Error, __version__, writer, reader, register_dialect, \ + unregister_dialect, get_dialect, list_dialects, \ + field_size_limit, \ QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \ + QUOTE_STRINGS, QUOTE_NOTNULL, \ __doc__ +from _csv import Dialect as _Dialect -from collections import OrderedDict from io import StringIO __all__ = ["QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE", + "QUOTE_STRINGS", "QUOTE_NOTNULL", "Error", "Dialect", "__doc__", "excel", "excel_tab", "field_size_limit", "reader", "writer", - "Sniffer", + "register_dialect", "get_dialect", "list_dialects", "Sniffer", "unregister_dialect", "__version__", "DictReader", "DictWriter", "unix_dialect"] @@ -57,10 +62,12 @@ class excel(Dialect): skipinitialspace = False lineterminator = '\r\n' quoting = QUOTE_MINIMAL +register_dialect("excel", excel) class excel_tab(excel): """Describe the usual properties of Excel-generated TAB-delimited files.""" delimiter = '\t' +register_dialect("excel-tab", excel_tab) class unix_dialect(Dialect): """Describe the usual properties of Unix-generated CSV files.""" @@ -70,11 +77,14 @@ class unix_dialect(Dialect): skipinitialspace = False lineterminator = '\n' quoting = QUOTE_ALL +register_dialect("unix", unix_dialect) class DictReader: def __init__(self, f, fieldnames=None, restkey=None, restval=None, dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self._fieldnames = fieldnames # list of keys for the dict self.restkey = restkey # key to catch long rows self.restval = restval # default value for short rows @@ -111,7 +121,7 @@ def __next__(self): # values while row == []: row = next(self.reader) - d = OrderedDict(zip(self.fieldnames, row)) + d = dict(zip(self.fieldnames, row)) lf = len(self.fieldnames) lr = len(row) if lf < lr: @@ -121,13 +131,18 @@ def __next__(self): d[key] = self.restval return d + __class_getitem__ = classmethod(types.GenericAlias) + class DictWriter: def __init__(self, f, fieldnames, restval="", extrasaction="raise", dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self.fieldnames = fieldnames # list of keys for the dict self.restval = restval # for writing short dicts - if extrasaction.lower() not in ("raise", "ignore"): + extrasaction = extrasaction.lower() + if extrasaction not in ("raise", "ignore"): raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'" % extrasaction) self.extrasaction = extrasaction @@ -135,7 +150,7 @@ def __init__(self, f, fieldnames, restval="", extrasaction="raise", def writeheader(self): header = dict(zip(self.fieldnames, self.fieldnames)) - self.writerow(header) + return self.writerow(header) def _dict_to_list(self, rowdict): if self.extrasaction == "raise": @@ -151,11 +166,8 @@ def writerow(self, rowdict): def writerows(self, rowdicts): return self.writer.writerows(map(self._dict_to_list, rowdicts)) -# Guard Sniffer's type checking against builds that exclude complex() -try: - complex -except NameError: - complex = float + __class_getitem__ = classmethod(types.GenericAlias) + class Sniffer: ''' @@ -404,14 +416,10 @@ def has_header(self, sample): continue # skip rows that have irregular number of columns for col in list(columnTypes.keys()): - - for thisType in [int, float, complex]: - try: - thisType(row[col]) - break - except (ValueError, OverflowError): - pass - else: + thisType = complex + try: + thisType(row[col]) + except (ValueError, OverflowError): # fallback to length of string thisType = len(row[col]) @@ -427,7 +435,7 @@ def has_header(self, sample): # on whether it's a header hasHeader = 0 for col, colType in columnTypes.items(): - if type(colType) == type(0): # it's a length + if isinstance(colType, int): # it's a length if len(header[col]) != colType: hasHeader += 1 else: diff --git a/Lib/ftplib.py b/Lib/ftplib.py index 7c5a50715f..a56e0c3085 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -434,10 +434,7 @@ def retrbinary(self, cmd, callback, blocksize=8192, rest=None): """ self.voidcmd('TYPE I') with self.transfercmd(cmd, rest) as conn: - while 1: - data = conn.recv(blocksize) - if not data: - break + while data := conn.recv(blocksize): callback(data) # shutdown ssl layer if _SSLSocket is not None and isinstance(conn, _SSLSocket): @@ -496,10 +493,7 @@ def storbinary(self, cmd, fp, blocksize=8192, callback=None, rest=None): """ self.voidcmd('TYPE I') with self.transfercmd(cmd, rest) as conn: - while 1: - buf = fp.read(blocksize) - if not buf: - break + while buf := fp.read(blocksize): conn.sendall(buf) if callback: callback(buf) @@ -561,7 +555,7 @@ def dir(self, *args): LIST command. (This *should* only be used for a pathname.)''' cmd = 'LIST' func = None - if args[-1:] and type(args[-1]) != type(''): + if args[-1:] and not isinstance(args[-1], str): args, func = args[:-1], args[-1] for arg in args: if arg: @@ -713,28 +707,12 @@ class FTP_TLS(FTP): '221 Goodbye.' >>> ''' - ssl_version = ssl.PROTOCOL_TLS_CLIENT def __init__(self, host='', user='', passwd='', acct='', - keyfile=None, certfile=None, context=None, - timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None, *, - encoding='utf-8'): - if context is not None and keyfile is not None: - raise ValueError("context and keyfile arguments are mutually " - "exclusive") - if context is not None and certfile is not None: - raise ValueError("context and certfile arguments are mutually " - "exclusive") - if keyfile is not None or certfile is not None: - import warnings - warnings.warn("keyfile and certfile are deprecated, use a " - "custom context instead", DeprecationWarning, 2) - self.keyfile = keyfile - self.certfile = certfile + *, context=None, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None, encoding='utf-8'): if context is None: - context = ssl._create_stdlib_context(self.ssl_version, - certfile=certfile, - keyfile=keyfile) + context = ssl._create_stdlib_context() self.context = context self._prot_p = False super().__init__(host, user, passwd, acct, @@ -749,7 +727,7 @@ def auth(self): '''Set up secure control connection by using TLS/SSL.''' if isinstance(self.sock, ssl.SSLSocket): raise ValueError("Already using TLS") - if self.ssl_version >= ssl.PROTOCOL_TLS: + if self.context.protocol >= ssl.PROTOCOL_TLS: resp = self.voidcmd('AUTH TLS') else: resp = self.voidcmd('AUTH SSL') diff --git a/Lib/pprint.py b/Lib/pprint.py index 575688d8eb..34ed12637e 100644 --- a/Lib/pprint.py +++ b/Lib/pprint.py @@ -637,19 +637,6 @@ def _recursion(object): % (type(object).__name__, id(object))) -def _perfcheck(object=None): - import time - if object is None: - object = [("string", (1, 2), [3, 4], {5: 6, 7: 8})] * 100000 - p = PrettyPrinter() - t1 = time.perf_counter() - p._safe_repr(object, {}, None, 0, True) - t2 = time.perf_counter() - p.pformat(object) - t3 = time.perf_counter() - print("_safe_repr:", t2 - t1) - print("pformat:", t3 - t2) - def _wrap_bytes_repr(object, width, allowance): current = b'' last = len(object) // 4 * 4 @@ -666,6 +653,3 @@ def _wrap_bytes_repr(object, width, allowance): current = candidate if current: yield repr(current) - -if __name__ == "__main__": - _perfcheck() diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 02f060ba2c..6644a3cd5c 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -278,13 +278,11 @@ def test_invalid_utf8_arg(self): code = 'import sys, os; s=os.fsencode(sys.argv[1]); print(ascii(s))' # TODO: RUSTPYTHON - @unittest.expectedFailure def run_default(arg): cmd = [sys.executable, '-c', code, arg] return subprocess.run(cmd, stdout=subprocess.PIPE, text=True) # TODO: RUSTPYTHON - @unittest.expectedFailure def run_c_locale(arg): cmd = [sys.executable, '-c', code, arg] env = dict(os.environ) @@ -293,7 +291,6 @@ def run_c_locale(arg): text=True, env=env) # TODO: RUSTPYTHON - @unittest.expectedFailure def run_utf8_mode(arg): cmd = [sys.executable, '-X', 'utf8', '-c', code, arg] return subprocess.run(cmd, stdout=subprocess.PIPE, text=True) @@ -411,7 +408,8 @@ def test_empty_PYTHONPATH_issue16309(self): path = ":".join(sys.path) path = path.encode("ascii", "backslashreplace") sys.stdout.buffer.write(path)""" - rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="") + # TODO: RUSTPYTHON we must unset RUSTPYTHONPATH as well + rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="", RUSTPYTHONPATH="") rc2, out2, err2 = assert_python_ok('-c', code, __isolated=False) # regarding to Posix specification, outputs should be equal # for empty and unset PYTHONPATH diff --git a/Lib/test/test_colorsys.py b/Lib/test/test_colorsys.py index a24e3adcb4..74d76294b0 100644 --- a/Lib/test/test_colorsys.py +++ b/Lib/test/test_colorsys.py @@ -69,6 +69,16 @@ def test_hls_values(self): self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) self.assertTripleEqual(rgb, colorsys.hls_to_rgb(*hls)) + def test_hls_nearwhite(self): # gh-106498 + values = ( + # rgb, hls: these do not work in reverse + ((0.9999999999999999, 1, 1), (0.5, 1.0, 1.0)), + ((1, 0.9999999999999999, 0.9999999999999999), (0.0, 1.0, 1.0)), + ) + for rgb, hls in values: + self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) + self.assertTripleEqual((1.0, 1.0, 1.0), colorsys.hls_to_rgb(*hls)) + def test_yiq_roundtrip(self): for r in frange(0.0, 1.0, 0.2): for g in frange(0.0, 1.0, 0.2): diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py new file mode 100644 index 0000000000..2646be086c --- /dev/null +++ b/Lib/test/test_csv.py @@ -0,0 +1,1481 @@ +# Copyright (C) 2001,2002 Python Software Foundation +# csv package unit tests + +import copy +import sys +import unittest +from io import StringIO +from tempfile import TemporaryFile +import csv +import gc +import pickle +from test import support +from test.support import warnings_helper, import_helper, check_disallow_instantiation +from itertools import permutations +from textwrap import dedent +from collections import OrderedDict + + +class BadIterable: + def __iter__(self): + raise OSError + + +class Test_Csv(unittest.TestCase): + """ + Test the underlying C csv parser in ways that are not appropriate + from the high level interface. Further tests of this nature are done + in TestDialectRegistry. + """ + def _test_arg_valid(self, ctor, arg): + self.assertRaises(TypeError, ctor) + self.assertRaises(TypeError, ctor, None) + self.assertRaises(TypeError, ctor, arg, bad_attr = 0) + self.assertRaises(TypeError, ctor, arg, delimiter = 0) + self.assertRaises(TypeError, ctor, arg, delimiter = 'XX') + self.assertRaises(csv.Error, ctor, arg, 'foo') + self.assertRaises(TypeError, ctor, arg, delimiter=None) + self.assertRaises(TypeError, ctor, arg, delimiter=1) + self.assertRaises(TypeError, ctor, arg, quotechar=1) + self.assertRaises(TypeError, ctor, arg, lineterminator=None) + self.assertRaises(TypeError, ctor, arg, lineterminator=1) + self.assertRaises(TypeError, ctor, arg, quoting=None) + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_ALL, quotechar='') + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_ALL, quotechar=None) + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_NONE, quotechar='') + + def test_reader_arg_valid(self): + self._test_arg_valid(csv.reader, []) + self.assertRaises(OSError, csv.reader, BadIterable()) + + def test_writer_arg_valid(self): + self._test_arg_valid(csv.writer, StringIO()) + class BadWriter: + @property + def write(self): + raise OSError + self.assertRaises(OSError, csv.writer, BadWriter()) + + def _test_default_attrs(self, ctor, *args): + obj = ctor(*args) + # Check defaults + self.assertEqual(obj.dialect.delimiter, ',') + self.assertIs(obj.dialect.doublequote, True) + self.assertEqual(obj.dialect.escapechar, None) + self.assertEqual(obj.dialect.lineterminator, "\r\n") + self.assertEqual(obj.dialect.quotechar, '"') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_MINIMAL) + self.assertIs(obj.dialect.skipinitialspace, False) + self.assertIs(obj.dialect.strict, False) + # Try deleting or changing attributes (they are read-only) + self.assertRaises(AttributeError, delattr, obj.dialect, 'delimiter') + self.assertRaises(AttributeError, setattr, obj.dialect, 'delimiter', ':') + self.assertRaises(AttributeError, delattr, obj.dialect, 'quoting') + self.assertRaises(AttributeError, setattr, obj.dialect, + 'quoting', None) + + def test_reader_attrs(self): + self._test_default_attrs(csv.reader, []) + + def test_writer_attrs(self): + self._test_default_attrs(csv.writer, StringIO()) + + def _test_kw_attrs(self, ctor, *args): + # Now try with alternate options + kwargs = dict(delimiter=':', doublequote=False, escapechar='\\', + lineterminator='\r', quotechar='*', + quoting=csv.QUOTE_NONE, skipinitialspace=True, + strict=True) + obj = ctor(*args, **kwargs) + self.assertEqual(obj.dialect.delimiter, ':') + self.assertIs(obj.dialect.doublequote, False) + self.assertEqual(obj.dialect.escapechar, '\\') + self.assertEqual(obj.dialect.lineterminator, "\r") + self.assertEqual(obj.dialect.quotechar, '*') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_NONE) + self.assertIs(obj.dialect.skipinitialspace, True) + self.assertIs(obj.dialect.strict, True) + + def test_reader_kw_attrs(self): + self._test_kw_attrs(csv.reader, []) + + def test_writer_kw_attrs(self): + self._test_kw_attrs(csv.writer, StringIO()) + + def _test_dialect_attrs(self, ctor, *args): + # Now try with dialect-derived options + class dialect: + delimiter='-' + doublequote=False + escapechar='^' + lineterminator='$' + quotechar='#' + quoting=csv.QUOTE_ALL + skipinitialspace=True + strict=False + args = args + (dialect,) + obj = ctor(*args) + self.assertEqual(obj.dialect.delimiter, '-') + self.assertIs(obj.dialect.doublequote, False) + self.assertEqual(obj.dialect.escapechar, '^') + self.assertEqual(obj.dialect.lineterminator, "$") + self.assertEqual(obj.dialect.quotechar, '#') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_ALL) + self.assertIs(obj.dialect.skipinitialspace, True) + self.assertIs(obj.dialect.strict, False) + + def test_reader_dialect_attrs(self): + self._test_dialect_attrs(csv.reader, []) + + def test_writer_dialect_attrs(self): + self._test_dialect_attrs(csv.writer, StringIO()) + + + def _write_test(self, fields, expect, **kwargs): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, **kwargs) + writer.writerow(fields) + fileobj.seek(0) + self.assertEqual(fileobj.read(), + expect + writer.dialect.lineterminator) + + def _write_error_test(self, exc, fields, **kwargs): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, **kwargs) + with self.assertRaises(exc): + writer.writerow(fields) + fileobj.seek(0) + self.assertEqual(fileobj.read(), '') + + # TODO: RUSTPYTHON ''\r\n to ""\r\n unsupported + @unittest.expectedFailure + def test_write_arg_valid(self): + self._write_error_test(csv.Error, None) + self._write_test((), '') + self._write_test([None], '""') + self._write_error_test(csv.Error, [None], quoting = csv.QUOTE_NONE) + # Check that exceptions are passed up the chain + self._write_error_test(OSError, BadIterable()) + class BadList: + def __len__(self): + return 10 + def __getitem__(self, i): + if i > 2: + raise OSError + self._write_error_test(OSError, BadList()) + class BadItem: + def __str__(self): + raise OSError + self._write_error_test(OSError, [BadItem()]) + + def test_write_bigfield(self): + # This exercises the buffer realloc functionality + bigstring = 'X' * 50000 + self._write_test([bigstring,bigstring], '%s,%s' % \ + (bigstring, bigstring)) + + # TODO: RUSTPYTHON quoting style check is unsupported + @unittest.expectedFailure + def test_write_quoting(self): + self._write_test(['a',1,'p,q'], 'a,1,"p,q"') + self._write_error_test(csv.Error, ['a',1,'p,q'], + quoting = csv.QUOTE_NONE) + self._write_test(['a',1,'p,q'], 'a,1,"p,q"', + quoting = csv.QUOTE_MINIMAL) + self._write_test(['a',1,'p,q'], '"a",1,"p,q"', + quoting = csv.QUOTE_NONNUMERIC) + self._write_test(['a',1,'p,q'], '"a","1","p,q"', + quoting = csv.QUOTE_ALL) + self._write_test(['a\nb',1], '"a\nb","1"', + quoting = csv.QUOTE_ALL) + self._write_test(['a','',None,1], '"a","",,1', + quoting = csv.QUOTE_STRINGS) + self._write_test(['a','',None,1], '"a","",,"1"', + quoting = csv.QUOTE_NOTNULL) + + # TODO: RUSTPYTHON doublequote check is unsupported + @unittest.expectedFailure + def test_write_escape(self): + self._write_test(['a',1,'p,q'], 'a,1,"p,q"', + escapechar='\\') + self._write_error_test(csv.Error, ['a',1,'p,"q"'], + escapechar=None, doublequote=False) + self._write_test(['a',1,'p,"q"'], 'a,1,"p,\\"q\\""', + escapechar='\\', doublequote = False) + self._write_test(['"'], '""""', + escapechar='\\', quoting = csv.QUOTE_MINIMAL) + self._write_test(['"'], '\\"', + escapechar='\\', quoting = csv.QUOTE_MINIMAL, + doublequote = False) + self._write_test(['"'], '\\"', + escapechar='\\', quoting = csv.QUOTE_NONE) + self._write_test(['a',1,'p,q'], 'a,1,p\\,q', + escapechar='\\', quoting = csv.QUOTE_NONE) + self._write_test(['\\', 'a'], '\\\\,a', + escapechar='\\', quoting=csv.QUOTE_NONE) + self._write_test(['\\', 'a'], '\\\\,a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['\\', 'a'], '"\\\\","a"', + escapechar='\\', quoting=csv.QUOTE_ALL) + self._write_test(['\\ ', 'a'], '\\\\ ,a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['\\,', 'a'], '\\\\\\,,a', + escapechar='\\', quoting=csv.QUOTE_NONE) + self._write_test([',\\', 'a'], '",\\\\",a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['C\\', '6', '7', 'X"'], 'C\\\\,6,7,"X"""', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + + # TODO: RUSTPYTHON lineterminator double char unsupported + @unittest.expectedFailure + def test_write_lineterminator(self): + for lineterminator in '\r\n', '\n', '\r', '!@#', '\0': + with self.subTest(lineterminator=lineterminator): + with StringIO() as sio: + writer = csv.writer(sio, lineterminator=lineterminator) + writer.writerow(['a', 'b']) + writer.writerow([1, 2]) + self.assertEqual(sio.getvalue(), + f'a,b{lineterminator}' + f'1,2{lineterminator}') + + # TODO: RUSTPYTHON ''\r\n to ""\r\n unspported + @unittest.expectedFailure + def test_write_iterable(self): + self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') + self._write_test(iter(['a', 1, None]), 'a,1,') + self._write_test(iter([]), '') + self._write_test(iter([None]), '""') + self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE) + self._write_test(iter([None, None]), ',') + + def test_writerows(self): + class BrokenFile: + def write(self, buf): + raise OSError + writer = csv.writer(BrokenFile()) + self.assertRaises(OSError, writer.writerows, [['a']]) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + self.assertRaises(TypeError, writer.writerows, None) + writer.writerows([['a', 'b'], ['c', 'd']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a,b\r\nc,d\r\n") + + def test_writerows_with_none(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([['a', None], [None, 'd']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a,\r\n,d\r\n") + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([[None], ['a']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), '""\r\na\r\n') + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([['a'], [None]]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), 'a\r\n""\r\n') + + def test_writerows_errors(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + self.assertRaises(TypeError, writer.writerows, None) + self.assertRaises(OSError, writer.writerows, BadIterable()) + + @support.cpython_only + @support.requires_legacy_unicode_capi() + @warnings_helper.ignore_warnings(category=DeprecationWarning) + def test_writerows_legacy_strings(self): + import _testcapi + c = _testcapi.unicode_legacy_string('a') + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([[c]]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a\r\n") + + def _read_test(self, input, expect, **kwargs): + reader = csv.reader(input, **kwargs) + result = list(reader) + self.assertEqual(result, expect) + + # TODO RUSTPYTHON strict mode is unsupported + @unittest.expectedFailure + def test_read_oddinputs(self): + self._read_test([], []) + self._read_test([''], [[]]) + self.assertRaises(csv.Error, self._read_test, + ['"ab"c'], None, strict = 1) + self._read_test(['"ab"c'], [['abc']], doublequote = 0) + + self.assertRaises(csv.Error, self._read_test, + [b'abc'], None) + + def test_read_eol(self): + self._read_test(['a,b'], [['a','b']]) + self._read_test(['a,b\n'], [['a','b']]) + self._read_test(['a,b\r\n'], [['a','b']]) + self._read_test(['a,b\r'], [['a','b']]) + self.assertRaises(csv.Error, self._read_test, ['a,b\rc,d'], []) + self.assertRaises(csv.Error, self._read_test, ['a,b\nc,d'], []) + self.assertRaises(csv.Error, self._read_test, ['a,b\r\nc,d'], []) + + # TODO RUSTPYTHON double quote umimplement + @unittest.expectedFailure + def test_read_eof(self): + self._read_test(['a,"'], [['a', '']]) + self._read_test(['"a'], [['a']]) + self._read_test(['^'], [['\n']], escapechar='^') + self.assertRaises(csv.Error, self._read_test, ['a,"'], [], strict=True) + self.assertRaises(csv.Error, self._read_test, ['"a'], [], strict=True) + self.assertRaises(csv.Error, self._read_test, + ['^'], [], escapechar='^', strict=True) + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_read_nul(self): + self._read_test(['\0'], [['\0']]) + self._read_test(['a,\0b,c'], [['a', '\0b', 'c']]) + self._read_test(['a,b\0,c'], [['a', 'b\0', 'c']]) + self._read_test(['a,b\\\0,c'], [['a', 'b\0', 'c']], escapechar='\\') + self._read_test(['a,"\0b",c'], [['a', '\0b', 'c']]) + + def test_read_delimiter(self): + self._read_test(['a,b,c'], [['a', 'b', 'c']]) + self._read_test(['a;b;c'], [['a', 'b', 'c']], delimiter=';') + self._read_test(['a\0b\0c'], [['a', 'b', 'c']], delimiter='\0') + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_read_escape(self): + self._read_test(['a,\\b,c'], [['a', 'b', 'c']], escapechar='\\') + self._read_test(['a,b\\,c'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b\\,c"'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b,\\c"'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b,c\\""'], [['a', 'b,c"']], escapechar='\\') + self._read_test(['a,"b,c"\\'], [['a', 'b,c\\']], escapechar='\\') + self._read_test(['a,^b,c'], [['a', 'b', 'c']], escapechar='^') + self._read_test(['a,\0b,c'], [['a', 'b', 'c']], escapechar='\0') + self._read_test(['a,\\b,c'], [['a', '\\b', 'c']], escapechar=None) + self._read_test(['a,\\b,c'], [['a', '\\b', 'c']]) + + # TODO RUSTPYTHON escapechar unsupported + @unittest.expectedFailure + def test_read_quoting(self): + self._read_test(['1,",3,",5'], [['1', ',3,', '5']]) + self._read_test(['1,",3,",5'], [['1', '"', '3', '"', '5']], + quotechar=None, escapechar='\\') + self._read_test(['1,",3,",5'], [['1', '"', '3', '"', '5']], + quoting=csv.QUOTE_NONE, escapechar='\\') + # will this fail where locale uses comma for decimals? + self._read_test([',3,"5",7.3, 9'], [['', 3, '5', 7.3, 9]], + quoting=csv.QUOTE_NONNUMERIC) + self._read_test(['"a\nb", 7'], [['a\nb', ' 7']]) + self.assertRaises(ValueError, self._read_test, + ['abc,3'], [[]], + quoting=csv.QUOTE_NONNUMERIC) + self._read_test(['1,@,3,@,5'], [['1', ',3,', '5']], quotechar='@') + self._read_test(['1,\0,3,\0,5'], [['1', ',3,', '5']], quotechar='\0') + + def test_read_skipinitialspace(self): + self._read_test(['no space, space, spaces,\ttab'], + [['no space', 'space', 'spaces', '\ttab']], + skipinitialspace=True) + + def test_read_bigfield(self): + # This exercises the buffer realloc functionality and field size + # limits. + limit = csv.field_size_limit() + try: + size = 50000 + bigstring = 'X' * size + bigline = '%s,%s' % (bigstring, bigstring) + self._read_test([bigline], [[bigstring, bigstring]]) + csv.field_size_limit(size) + self._read_test([bigline], [[bigstring, bigstring]]) + self.assertEqual(csv.field_size_limit(), size) + csv.field_size_limit(size-1) + self.assertRaises(csv.Error, self._read_test, [bigline], []) + self.assertRaises(TypeError, csv.field_size_limit, None) + self.assertRaises(TypeError, csv.field_size_limit, 1, None) + finally: + csv.field_size_limit(limit) + + def test_read_linenum(self): + r = csv.reader(['line,1', 'line,2', 'line,3']) + self.assertEqual(r.line_num, 0) + next(r) + self.assertEqual(r.line_num, 1) + next(r) + self.assertEqual(r.line_num, 2) + next(r) + self.assertEqual(r.line_num, 3) + self.assertRaises(StopIteration, next, r) + self.assertEqual(r.line_num, 3) + + # TODO: RUSTPYTHON only '\r\n' unsupported + @unittest.expectedFailure + def test_roundtrip_quoteed_newlines(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj)): + self.assertEqual(row, rows[i]) + + # TODO: RUSTPYTHON only '\r\n' unsupported + @unittest.expectedFailure + def test_roundtrip_escaped_unquoted_newlines(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")): + self.assertEqual(row,rows[i]) + +class TestDialectRegistry(unittest.TestCase): + def test_registry_badargs(self): + self.assertRaises(TypeError, csv.list_dialects, None) + self.assertRaises(TypeError, csv.get_dialect) + self.assertRaises(csv.Error, csv.get_dialect, None) + self.assertRaises(csv.Error, csv.get_dialect, "nonesuch") + self.assertRaises(TypeError, csv.unregister_dialect) + self.assertRaises(csv.Error, csv.unregister_dialect, None) + self.assertRaises(csv.Error, csv.unregister_dialect, "nonesuch") + self.assertRaises(TypeError, csv.register_dialect, None) + self.assertRaises(TypeError, csv.register_dialect, None, None) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", 0, 0) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", + badargument=None) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", + quoting=None) + self.assertRaises(TypeError, csv.register_dialect, []) + + def test_registry(self): + class myexceltsv(csv.excel): + delimiter = "\t" + name = "myexceltsv" + expected_dialects = csv.list_dialects() + [name] + expected_dialects.sort() + csv.register_dialect(name, myexceltsv) + self.addCleanup(csv.unregister_dialect, name) + self.assertEqual(csv.get_dialect(name).delimiter, '\t') + got_dialects = sorted(csv.list_dialects()) + self.assertEqual(expected_dialects, got_dialects) + + def test_register_kwargs(self): + name = 'fedcba' + csv.register_dialect(name, delimiter=';') + self.addCleanup(csv.unregister_dialect, name) + self.assertEqual(csv.get_dialect(name).delimiter, ';') + self.assertEqual([['X', 'Y', 'Z']], list(csv.reader(['X;Y;Z'], name))) + + def test_register_kwargs_override(self): + class mydialect(csv.Dialect): + delimiter = "\t" + quotechar = '"' + doublequote = True + skipinitialspace = False + lineterminator = '\r\n' + quoting = csv.QUOTE_MINIMAL + + name = 'test_dialect' + csv.register_dialect(name, mydialect, + delimiter=';', + quotechar="'", + doublequote=False, + skipinitialspace=True, + lineterminator='\n', + quoting=csv.QUOTE_ALL) + self.addCleanup(csv.unregister_dialect, name) + + # Ensure that kwargs do override attributes of a dialect class: + dialect = csv.get_dialect(name) + self.assertEqual(dialect.delimiter, ';') + self.assertEqual(dialect.quotechar, "'") + self.assertEqual(dialect.doublequote, False) + self.assertEqual(dialect.skipinitialspace, True) + self.assertEqual(dialect.lineterminator, '\n') + self.assertEqual(dialect.quoting, csv.QUOTE_ALL) + + def test_incomplete_dialect(self): + class myexceltsv(csv.Dialect): + delimiter = "\t" + self.assertRaises(csv.Error, myexceltsv) + + def test_space_dialect(self): + class space(csv.excel): + delimiter = " " + quoting = csv.QUOTE_NONE + escapechar = "\\" + + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("abc def\nc1ccccc1 benzene\n") + fileobj.seek(0) + reader = csv.reader(fileobj, dialect=space()) + self.assertEqual(next(reader), ["abc", "def"]) + self.assertEqual(next(reader), ["c1ccccc1", "benzene"]) + + def compare_dialect_123(self, expected, *writeargs, **kwwriteargs): + + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + + writer = csv.writer(fileobj, *writeargs, **kwwriteargs) + writer.writerow([1,2,3]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_dialect_apply(self): + class testA(csv.excel): + delimiter = "\t" + class testB(csv.excel): + delimiter = ":" + class testC(csv.excel): + delimiter = "|" + class testUni(csv.excel): + delimiter = "\u039B" + + class unspecified(): + # A class to pass as dialect but with no dialect attributes. + pass + + csv.register_dialect('testC', testC) + try: + self.compare_dialect_123("1,2,3\r\n") + self.compare_dialect_123("1,2,3\r\n", dialect=None) + self.compare_dialect_123("1,2,3\r\n", dialect=unspecified) + self.compare_dialect_123("1\t2\t3\r\n", testA) + self.compare_dialect_123("1:2:3\r\n", dialect=testB()) + self.compare_dialect_123("1|2|3\r\n", dialect='testC') + self.compare_dialect_123("1;2;3\r\n", dialect=testA, + delimiter=';') + self.compare_dialect_123("1\u039B2\u039B3\r\n", + dialect=testUni) + + finally: + csv.unregister_dialect('testC') + + def test_bad_dialect(self): + # Unknown parameter + self.assertRaises(TypeError, csv.reader, [], bad_attr = 0) + # Bad values + self.assertRaises(TypeError, csv.reader, [], delimiter = None) + self.assertRaises(TypeError, csv.reader, [], quoting = -1) + self.assertRaises(TypeError, csv.reader, [], quoting = 100) + + def test_copy(self): + for name in csv.list_dialects(): + dialect = csv.get_dialect(name) + self.assertRaises(TypeError, copy.copy, dialect) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_pickle(self): + for name in csv.list_dialects(): + dialect = csv.get_dialect(name) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + self.assertRaises(TypeError, pickle.dumps, dialect, proto) + +class TestCsvBase(unittest.TestCase): + def readerAssertEqual(self, input, expected_result): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + fileobj.write(input) + fileobj.seek(0) + reader = csv.reader(fileobj, dialect = self.dialect) + fields = list(reader) + self.assertEqual(fields, expected_result) + + def writerAssertEqual(self, input, expected_result): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect = self.dialect) + writer.writerows(input) + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected_result) + +class TestDialectExcel(TestCsvBase): + dialect = 'excel' + + def test_single(self): + self.readerAssertEqual('abc', [['abc']]) + + def test_simple(self): + self.readerAssertEqual('1,2,3,4,5', [['1','2','3','4','5']]) + + def test_blankline(self): + self.readerAssertEqual('', []) + + def test_empty_fields(self): + self.readerAssertEqual(',', [['', '']]) + + def test_singlequoted(self): + self.readerAssertEqual('""', [['']]) + + def test_singlequoted_left_empty(self): + self.readerAssertEqual('"",', [['','']]) + + def test_singlequoted_right_empty(self): + self.readerAssertEqual(',""', [['','']]) + + def test_single_quoted_quote(self): + self.readerAssertEqual('""""', [['"']]) + + def test_quoted_quotes(self): + self.readerAssertEqual('""""""', [['""']]) + + def test_inline_quote(self): + self.readerAssertEqual('a""b', [['a""b']]) + + def test_inline_quotes(self): + self.readerAssertEqual('a"b"c', [['a"b"c']]) + + def test_quotes_and_more(self): + # Excel would never write a field containing '"a"b', but when + # reading one, it will return 'ab'. + self.readerAssertEqual('"a"b', [['ab']]) + + def test_lone_quote(self): + self.readerAssertEqual('a"b', [['a"b']]) + + def test_quote_and_quote(self): + # Excel would never write a field containing '"a" "b"', but when + # reading one, it will return 'a "b"'. + self.readerAssertEqual('"a" "b"', [['a "b"']]) + + def test_space_and_quote(self): + self.readerAssertEqual(' "a"', [[' "a"']]) + + def test_quoted(self): + self.readerAssertEqual('1,2,3,"I think, therefore I am",5,6', + [['1', '2', '3', + 'I think, therefore I am', + '5', '6']]) + + def test_quoted_quote(self): + self.readerAssertEqual('1,2,3,"""I see,"" said the blind man","as he picked up his hammer and saw"', + [['1', '2', '3', + '"I see," said the blind man', + 'as he picked up his hammer and saw']]) + + # Rustpython TODO + @unittest.expectedFailure + def test_quoted_nl(self): + input = '''\ +1,2,3,"""I see,"" +said the blind man","as he picked up his +hammer and saw" +9,8,7,6''' + self.readerAssertEqual(input, + [['1', '2', '3', + '"I see,"\nsaid the blind man', + 'as he picked up his\nhammer and saw'], + ['9','8','7','6']]) + + def test_dubious_quote(self): + self.readerAssertEqual('12,12,1",', [['12', '12', '1"', '']]) + + def test_null(self): + self.writerAssertEqual([], '') + + def test_single_writer(self): + self.writerAssertEqual([['abc']], 'abc\r\n') + + def test_simple_writer(self): + self.writerAssertEqual([[1, 2, 'abc', 3, 4]], '1,2,abc,3,4\r\n') + + def test_quotes(self): + self.writerAssertEqual([[1, 2, 'a"bc"', 3, 4]], '1,2,"a""bc""",3,4\r\n') + + def test_quote_fieldsep(self): + self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') + + def test_newlines(self): + self.writerAssertEqual([[1, 2, 'a\nbc', 3, 4]], '1,2,"a\nbc",3,4\r\n') + +class EscapedExcel(csv.excel): + quoting = csv.QUOTE_NONE + escapechar = '\\' + +class TestEscapedExcel(TestCsvBase): + dialect = EscapedExcel() + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_escape_fieldsep(self): + self.writerAssertEqual([['abc,def']], 'abc\\,def\r\n') + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_read_escape_fieldsep(self): + self.readerAssertEqual('abc\\,def\r\n', [['abc,def']]) + +class TestDialectUnix(TestCsvBase): + dialect = 'unix' + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_simple_writer(self): + self.writerAssertEqual([[1, 'abc def', 'abc']], '"1","abc def","abc"\n') + + def test_simple_reader(self): + self.readerAssertEqual('"1","abc def","abc"\n', [['1', 'abc def', 'abc']]) + +class QuotedEscapedExcel(csv.excel): + quoting = csv.QUOTE_NONNUMERIC + escapechar = '\\' + +class TestQuotedEscapedExcel(TestCsvBase): + dialect = QuotedEscapedExcel() + + def test_write_escape_fieldsep(self): + self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_read_escape_fieldsep(self): + self.readerAssertEqual('"abc\\,def"\r\n', [['abc,def']]) + +class TestDictFields(unittest.TestCase): + ### "long" means the row is longer than the number of fieldnames + ### "short" means there are fewer elements in the row than fieldnames + def test_writeheader_return_value(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + writeheader_return_value = writer.writeheader() + self.assertEqual(writeheader_return_value, 10) + + def test_write_simple_dict(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + writer.writeheader() + fileobj.seek(0) + self.assertEqual(fileobj.readline(), "f1,f2,f3\r\n") + writer.writerow({"f1": 10, "f3": "abc"}) + fileobj.seek(0) + fileobj.readline() # header + self.assertEqual(fileobj.read(), "10,,abc\r\n") + + def test_write_multiple_dict_rows(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, fieldnames=["f1", "f2", "f3"]) + writer.writeheader() + self.assertEqual(fileobj.getvalue(), "f1,f2,f3\r\n") + writer.writerows([{"f1": 1, "f2": "abc", "f3": "f"}, + {"f1": 2, "f2": 5, "f3": "xyz"}]) + self.assertEqual(fileobj.getvalue(), + "f1,f2,f3\r\n1,abc,f\r\n2,5,xyz\r\n") + + def test_write_no_fields(self): + fileobj = StringIO() + self.assertRaises(TypeError, csv.DictWriter, fileobj) + + def test_write_fields_not_in_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + # Of special note is the non-string key (issue 19449) + with self.assertRaises(ValueError) as cx: + writer.writerow({"f4": 10, "f2": "spam", 1: "abc"}) + exception = str(cx.exception) + self.assertIn("fieldnames", exception) + self.assertIn("'f4'", exception) + self.assertNotIn("'f2'", exception) + self.assertIn("1", exception) + + def test_typo_in_extrasaction_raises_error(self): + fileobj = StringIO() + self.assertRaises(ValueError, csv.DictWriter, fileobj, ['f1', 'f2'], + extrasaction="raised") + + def test_write_field_not_in_field_names_raise(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="raise") + dictrow = {'f0': 0, 'f1': 1, 'f2': 2, 'f3': 3} + self.assertRaises(ValueError, csv.DictWriter.writerow, writer, dictrow) + + # see bpo-44512 (differently cased 'raise' should not result in 'ignore') + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="RAISE") + self.assertRaises(ValueError, csv.DictWriter.writerow, writer, dictrow) + + def test_write_field_not_in_field_names_ignore(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="ignore") + dictrow = {'f0': 0, 'f1': 1, 'f2': 2, 'f3': 3} + csv.DictWriter.writerow(writer, dictrow) + self.assertEqual(fileobj.getvalue(), "1,2\r\n") + + # bpo-44512 + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="IGNORE") + csv.DictWriter.writerow(writer, dictrow) + + def test_dict_reader_fieldnames_accepts_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, iter(fieldnames)) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, fieldnames) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_rejects_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, iter(fieldnames)) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, fieldnames) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_is_optional(self): + f = StringIO() + reader = csv.DictReader(f, fieldnames=None) + + def test_read_dict_fields(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2", "f3"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_dict_no_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + + # Two test cases to make sure existing ways of implicitly setting + # fieldnames continue to work. Both arise from discussion in issue3436. + def test_read_dict_fieldnames_from_file(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=next(csv.reader(fileobj))) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_dict_fieldnames_chain(self): + import itertools + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj) + first = next(reader) + for row in itertools.chain([first], reader): + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(row, {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_long(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + None: ["abc", "4", "5", "6"]}) + + def test_read_long_with_rest(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2"], restkey="_rest") + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + "_rest": ["abc", "4", "5", "6"]}) + + def test_read_long_with_rest_no_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, restkey="_rest") + self.assertEqual(reader.fieldnames, ["f1", "f2"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + "_rest": ["abc", "4", "5", "6"]}) + + def test_read_short(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames="1 2 3 4 5 6".split(), + restval="DEFAULT") + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": 'DEFAULT', "5": 'DEFAULT', + "6": 'DEFAULT'}) + + def test_read_multi(self): + sample = [ + '2147483648,43.0e12,17,abc,def\r\n', + '147483648,43.0e2,17,abc,def\r\n', + '47483648,43.0,170,abc,def\r\n' + ] + + reader = csv.DictReader(sample, + fieldnames="i1 float i2 s1 s2".split()) + self.assertEqual(next(reader), {"i1": '2147483648', + "float": '43.0e12', + "i2": '17', + "s1": 'abc', + "s2": 'def'}) + + # TODO RUSTPYTHON + @unittest.expectedFailure + def test_read_with_blanks(self): + reader = csv.DictReader(["1,2,abc,4,5,6\r\n","\r\n", + "1,2,abc,4,5,6\r\n"], + fieldnames="1 2 3 4 5 6".split()) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + + def test_read_semi_sep(self): + reader = csv.DictReader(["1;2;abc;4;5;6\r\n"], + fieldnames="1 2 3 4 5 6".split(), + delimiter=';') + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + +class TestArrayWrites(unittest.TestCase): + def test_int_write(self): + import array + contents = [(20-i) for i in range(20)] + a = array.array('i', contents) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_double_write(self): + import array + contents = [(20-i)*0.1 for i in range(20)] + a = array.array('d', contents) + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_float_write(self): + import array + contents = [(20-i)*0.1 for i in range(20)] + a = array.array('f', contents) + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_char_write(self): + import array, string + a = array.array('u', string.ascii_letters) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join(a)+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + +class TestDialectValidity(unittest.TestCase): + def test_quoting(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.quoting, csv.QUOTE_NONE) + + mydialect.quoting = None + self.assertRaises(csv.Error, mydialect) + + mydialect.doublequote = True + mydialect.quoting = csv.QUOTE_ALL + mydialect.quotechar = '"' + d = mydialect() + self.assertEqual(d.quoting, csv.QUOTE_ALL) + self.assertEqual(d.quotechar, '"') + self.assertTrue(d.doublequote) + + mydialect.quotechar = "" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be a 1-character string') + + mydialect.quotechar = "''" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be a 1-character string') + + mydialect.quotechar = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be string or None, not int') + + def test_delimiter(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.delimiter, ";") + + mydialect.delimiter = ":::" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be a 1-character string') + + mydialect.delimiter = "" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be a 1-character string') + + mydialect.delimiter = b"," + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not bytes') + + mydialect.delimiter = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not int') + + mydialect.delimiter = None + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not NoneType') + + def test_escapechar(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.escapechar, "\\") + + mydialect.escapechar = "" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'): + mydialect() + + mydialect.escapechar = "**" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'): + mydialect() + + mydialect.escapechar = b"*" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not bytes'): + mydialect() + + mydialect.escapechar = 4 + with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not int'): + mydialect() + + def test_lineterminator(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.lineterminator, '\r\n') + + mydialect.lineterminator = ":::" + d = mydialect() + self.assertEqual(d.lineterminator, ":::") + + mydialect.lineterminator = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"lineterminator" must be a string') + + def test_invalid_chars(self): + def create_invalid(field_name, value): + class mydialect(csv.Dialect): + pass + setattr(mydialect, field_name, value) + d = mydialect() + + for field_name in ("delimiter", "escapechar", "quotechar"): + with self.subTest(field_name=field_name): + self.assertRaises(csv.Error, create_invalid, field_name, "") + self.assertRaises(csv.Error, create_invalid, field_name, "abc") + self.assertRaises(csv.Error, create_invalid, field_name, b'x') + self.assertRaises(csv.Error, create_invalid, field_name, 5) + + +class TestSniffer(unittest.TestCase): + sample1 = """\ +Harry's, Arlington Heights, IL, 2/1/03, Kimi Hayes +Shark City, Glendale Heights, IL, 12/28/02, Prezence +Tommy's Place, Blue Island, IL, 12/28/02, Blue Sunday/White Crow +Stonecutters Seafood and Chop House, Lemont, IL, 12/19/02, Week Back +""" + sample2 = """\ +'Harry''s':'Arlington Heights':'IL':'2/1/03':'Kimi Hayes' +'Shark City':'Glendale Heights':'IL':'12/28/02':'Prezence' +'Tommy''s Place':'Blue Island':'IL':'12/28/02':'Blue Sunday/White Crow' +'Stonecutters ''Seafood'' and Chop House':'Lemont':'IL':'12/19/02':'Week Back' +""" + header1 = '''\ +"venue","city","state","date","performers" +''' + sample3 = '''\ +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +''' + + sample4 = '''\ +2147483648;43.0e12;17;abc;def +147483648;43.0e2;17;abc;def +47483648;43.0;170;abc;def +''' + + sample5 = "aaa\tbbb\r\nAAA\t\r\nBBB\t\r\n" + sample6 = "a|b|c\r\nd|e|f\r\n" + sample7 = "'a'|'b'|'c'\r\n'd'|e|f\r\n" + +# Issue 18155: Use a delimiter that is a special char to regex: + + header2 = '''\ +"venue"+"city"+"state"+"date"+"performers" +''' + sample8 = """\ +Harry's+ Arlington Heights+ IL+ 2/1/03+ Kimi Hayes +Shark City+ Glendale Heights+ IL+ 12/28/02+ Prezence +Tommy's Place+ Blue Island+ IL+ 12/28/02+ Blue Sunday/White Crow +Stonecutters Seafood and Chop House+ Lemont+ IL+ 12/19/02+ Week Back +""" + sample9 = """\ +'Harry''s'+ Arlington Heights'+ 'IL'+ '2/1/03'+ 'Kimi Hayes' +'Shark City'+ Glendale Heights'+' IL'+ '12/28/02'+ 'Prezence' +'Tommy''s Place'+ Blue Island'+ 'IL'+ '12/28/02'+ 'Blue Sunday/White Crow' +'Stonecutters ''Seafood'' and Chop House'+ 'Lemont'+ 'IL'+ '12/19/02'+ 'Week Back' +""" + + sample10 = dedent(""" + abc,def + ghijkl,mno + ghi,jkl + """) + + sample11 = dedent(""" + abc,def + ghijkl,mnop + ghi,jkl + """) + + sample12 = dedent(""""time","forces" + 1,1.5 + 0.5,5+0j + 0,0 + 1+1j,6 + """) + + sample13 = dedent(""""time","forces" + 0,0 + 1,2 + a,b + """) + + sample14 = """\ +abc\0def +ghijkl\0mno +ghi\0jkl +""" + + def test_issue43625(self): + sniffer = csv.Sniffer() + self.assertTrue(sniffer.has_header(self.sample12)) + self.assertFalse(sniffer.has_header(self.sample13)) + + def test_has_header_strings(self): + "More to document existing (unexpected?) behavior than anything else." + sniffer = csv.Sniffer() + self.assertFalse(sniffer.has_header(self.sample10)) + self.assertFalse(sniffer.has_header(self.sample11)) + + def test_has_header(self): + sniffer = csv.Sniffer() + self.assertIs(sniffer.has_header(self.sample1), False) + self.assertIs(sniffer.has_header(self.header1 + self.sample1), True) + + def test_has_header_regex_special_delimiter(self): + sniffer = csv.Sniffer() + self.assertIs(sniffer.has_header(self.sample8), False) + self.assertIs(sniffer.has_header(self.header2 + self.sample8), True) + + def test_guess_quote_and_delimiter(self): + sniffer = csv.Sniffer() + for header in (";'123;4';", "'123;4';", ";'123;4'", "'123;4'"): + with self.subTest(header): + dialect = sniffer.sniff(header, ",;") + self.assertEqual(dialect.delimiter, ';') + self.assertEqual(dialect.quotechar, "'") + self.assertIs(dialect.doublequote, False) + self.assertIs(dialect.skipinitialspace, False) + + def test_sniff(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.sample1) + self.assertEqual(dialect.delimiter, ",") + self.assertEqual(dialect.quotechar, '"') + self.assertIs(dialect.skipinitialspace, True) + + dialect = sniffer.sniff(self.sample2) + self.assertEqual(dialect.delimiter, ":") + self.assertEqual(dialect.quotechar, "'") + self.assertIs(dialect.skipinitialspace, False) + + def test_delimiters(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.sample3) + # given that all three lines in sample3 are equal, + # I think that any character could have been 'guessed' as the + # delimiter, depending on dictionary order + self.assertIn(dialect.delimiter, self.sample3) + dialect = sniffer.sniff(self.sample3, delimiters="?,") + self.assertEqual(dialect.delimiter, "?") + dialect = sniffer.sniff(self.sample3, delimiters="/,") + self.assertEqual(dialect.delimiter, "/") + dialect = sniffer.sniff(self.sample4) + self.assertEqual(dialect.delimiter, ";") + dialect = sniffer.sniff(self.sample5) + self.assertEqual(dialect.delimiter, "\t") + dialect = sniffer.sniff(self.sample6) + self.assertEqual(dialect.delimiter, "|") + dialect = sniffer.sniff(self.sample7) + self.assertEqual(dialect.delimiter, "|") + self.assertEqual(dialect.quotechar, "'") + dialect = sniffer.sniff(self.sample8) + self.assertEqual(dialect.delimiter, '+') + dialect = sniffer.sniff(self.sample9) + self.assertEqual(dialect.delimiter, '+') + self.assertEqual(dialect.quotechar, "'") + dialect = sniffer.sniff(self.sample14) + self.assertEqual(dialect.delimiter, '\0') + + def test_doublequote(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.header1) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.header2) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.sample2) + self.assertTrue(dialect.doublequote) + dialect = sniffer.sniff(self.sample8) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.sample9) + self.assertTrue(dialect.doublequote) + +class NUL: + def write(s, *args): + pass + writelines = write + +@unittest.skipUnless(hasattr(sys, "gettotalrefcount"), + 'requires sys.gettotalrefcount()') +class TestLeaks(unittest.TestCase): + def test_create_read(self): + delta = 0 + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + csv.reader(["a,b,c\r\n"]) + csv.reader(["a,b,c\r\n"]) + csv.reader(["a,b,c\r\n"]) + delta = rc-lastrc + lastrc = rc + # if csv.reader() leaks, last delta should be 3 or more + self.assertLess(delta, 3) + + def test_create_write(self): + delta = 0 + lastrc = sys.gettotalrefcount() + s = NUL() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + csv.writer(s) + csv.writer(s) + csv.writer(s) + delta = rc-lastrc + lastrc = rc + # if csv.writer() leaks, last delta should be 3 or more + self.assertLess(delta, 3) + + def test_read(self): + delta = 0 + rows = ["a,b,c\r\n"]*5 + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + rdr = csv.reader(rows) + for row in rdr: + pass + delta = rc-lastrc + lastrc = rc + # if reader leaks during read, delta should be 5 or more + self.assertLess(delta, 5) + + def test_write(self): + delta = 0 + rows = [[1,2,3]]*5 + s = NUL() + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + writer = csv.writer(s) + for row in rows: + writer.writerow(row) + delta = rc-lastrc + lastrc = rc + # if writer leaks during write, last delta should be 5 or more + self.assertLess(delta, 5) + +class TestUnicode(unittest.TestCase): + + names = ["Martin von Löwis", + "Marc André Lemburg", + "Guido van Rossum", + "François Pinard"] + + def test_unicode_read(self): + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + fileobj.write(",".join(self.names) + "\r\n") + fileobj.seek(0) + reader = csv.reader(fileobj) + self.assertEqual(list(reader), [self.names]) + + + def test_unicode_write(self): + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + writer = csv.writer(fileobj) + writer.writerow(self.names) + expected = ",".join(self.names)+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + +class KeyOrderingTest(unittest.TestCase): + + def test_ordering_for_the_dict_reader_and_writer(self): + resultset = set() + for keys in permutations("abcde"): + with TemporaryFile('w+', newline='', encoding="utf-8") as fileobject: + dw = csv.DictWriter(fileobject, keys) + dw.writeheader() + fileobject.seek(0) + dr = csv.DictReader(fileobject) + kt = tuple(dr.fieldnames) + self.assertEqual(keys, kt) + resultset.add(kt) + # Final sanity check: were all permutations unique? + self.assertEqual(len(resultset), 120, "Key ordering: some key permutations not collected (expected 120)") + + def test_ordered_dict_reader(self): + data = dedent('''\ + FirstName,LastName + Eric,Idle + Graham,Chapman,Over1,Over2 + + Under1 + John,Cleese + ''').splitlines() + + self.assertEqual(list(csv.DictReader(data)), + [OrderedDict([('FirstName', 'Eric'), ('LastName', 'Idle')]), + OrderedDict([('FirstName', 'Graham'), ('LastName', 'Chapman'), + (None, ['Over1', 'Over2'])]), + OrderedDict([('FirstName', 'Under1'), ('LastName', None)]), + OrderedDict([('FirstName', 'John'), ('LastName', 'Cleese')]), + ]) + + self.assertEqual(list(csv.DictReader(data, restkey='OtherInfo')), + [OrderedDict([('FirstName', 'Eric'), ('LastName', 'Idle')]), + OrderedDict([('FirstName', 'Graham'), ('LastName', 'Chapman'), + ('OtherInfo', ['Over1', 'Over2'])]), + OrderedDict([('FirstName', 'Under1'), ('LastName', None)]), + OrderedDict([('FirstName', 'John'), ('LastName', 'Cleese')]), + ]) + + del data[0] # Remove the header row + self.assertEqual(list(csv.DictReader(data, fieldnames=['fname', 'lname'])), + [OrderedDict([('fname', 'Eric'), ('lname', 'Idle')]), + OrderedDict([('fname', 'Graham'), ('lname', 'Chapman'), + (None, ['Over1', 'Over2'])]), + OrderedDict([('fname', 'Under1'), ('lname', None)]), + OrderedDict([('fname', 'John'), ('lname', 'Cleese')]), + ]) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + extra = {'__doc__', '__version__'} + support.check__all__(self, csv, ('csv', '_csv'), extra=extra) + + def test_subclassable(self): + # issue 44089 + class Foo(csv.Error): ... + + @support.cpython_only + def test_disallow_instantiation(self): + _csv = import_helper.import_module("_csv") + for tp in _csv.Reader, _csv.Writer: + with self.subTest(tp=tp): + check_disallow_instantiation(self, tp) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index f6c11a4aad..187270d5b6 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -420,6 +420,9 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") + # TODO: RUSTPYTHON formatting does not support locales + # See https://github.com/RustPython/RustPython/issues/5181 + @unittest.skip("formatting does not support locales") def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index e8c126ddc4..7e632efa4c 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -21,6 +21,8 @@ from test.support import threading_helper from test.support import socket_helper from test.support import warnings_helper +from test.support import asynchat +from test.support import asyncore from test.support.socket_helper import HOST, HOSTv6 import sys @@ -992,11 +994,11 @@ def test_context(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, keyfile=CERTFILE, context=ctx) - self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, certfile=CERTFILE, context=ctx) - self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, certfile=CERTFILE, keyfile=CERTFILE, context=ctx) self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) @@ -1160,18 +1162,10 @@ def test__all__(self): support.check__all__(self, ftplib, not_exported=not_exported) -def test_main(): - tests = [TestFTPClass, TestTimeouts, - TestIPv6Environment, - TestTLS_FTPClassMixin, TestTLS_FTPClass, - MiscTestCase] - +def setUpModule(): thread_info = threading_helper.threading_setup() - try: - support.run_unittest(*tests) - finally: - threading_helper.threading_cleanup(*thread_info) + unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index bc2e02528d..8e1a4a204c 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -389,6 +389,18 @@ def test_with_digestmod_no_default(self): with self.assertRaisesRegex(TypeError, r'required.*digestmod'): hmac.HMAC(key, msg=data, digestmod='') + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_with_fallback(self): + cache = getattr(hashlib, '__builtin_constructor_cache') + try: + cache['foo'] = hashlib.sha256 + hexdigest = hmac.digest(b'key', b'message', 'foo').hex() + expected = '6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a' + self.assertEqual(hexdigest, expected) + finally: + cache.pop('foo') + class ConstructorTestCase(unittest.TestCase): diff --git a/Lib/test/test_importlib/builtin/test_loader.py b/Lib/test/test_importlib/builtin/test_loader.py index 7e9d1b1960..5719fd79c6 100644 --- a/Lib/test/test_importlib/builtin/test_loader.py +++ b/Lib/test/test_importlib/builtin/test_loader.py @@ -67,9 +67,10 @@ def test_already_imported(self): self.assertEqual(cm.exception.name, module_name) -(Frozen_LoaderTests, - Source_LoaderTests - ) = util.test_both(LoaderTests, machinery=machinery) +# TODO: RUSTPYTHON +# (Frozen_LoaderTests, +# Source_LoaderTests +# ) = util.test_both(LoaderTests, machinery=machinery) @unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index b7e38c2334..1db738d228 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -208,6 +208,9 @@ def test_indexOf(self): nan = float("nan") self.assertEqual(operator.indexOf([nan, nan, 21], nan), 0) self.assertEqual(operator.indexOf([{}, 1, {}, 2], {}), 0) + it = iter('leave the iterator at exactly the position after the match') + self.assertEqual(operator.indexOf(it, 'a'), 2) + self.assertEqual(next(it), 'v') def test_invert(self): operator = self.module diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py index c7b9893943..6ea7e7db2c 100644 --- a/Lib/test/test_pprint.py +++ b/Lib/test/test_pprint.py @@ -203,7 +203,7 @@ def test_knotted(self): def test_unreadable(self): # Not recursive but not readable anyway pp = pprint.PrettyPrinter() - for unreadable in type(3), pprint, pprint.isrecursive: + for unreadable in object(), int, pprint, pprint.isrecursive: # module-level convenience functions self.assertFalse(pprint.isrecursive(unreadable), "expected not isrecursive for %r" % (unreadable,)) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 35f94a4e22..17e9dae8c0 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2161,12 +2161,16 @@ def testCreateISOTPSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: pass + # TODO: RUSTPYTHON, OSError: bind(): bad family + @unittest.expectedFailure def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: with self.assertRaisesRegex(OSError, 'interface name too long'): s.bind(('x' * 1024, 1, 2)) + # TODO: RUSTPYTHON, OSError: bind(): bad family + @unittest.expectedFailure def testBind(self): try: with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: diff --git a/Lib/test/test_unpack.py b/Lib/test/test_unpack.py index f5ca1d455b..515ec128a0 100644 --- a/Lib/test/test_unpack.py +++ b/Lib/test/test_unpack.py @@ -162,7 +162,7 @@ def test_extended_oparg_not_ignored(self): ns = {} exec(code, ns) unpack_400 = ns["unpack_400"] - # Warm up the the function for quickening (PEP 659) + # Warm up the function for quickening (PEP 659) for _ in range(30): y = unpack_400(range(400)) self.assertEqual(y, 399) diff --git a/Lib/test/test_uu.py b/Lib/test/test_uu.py index f71d877365..a189d6bc4b 100644 --- a/Lib/test/test_uu.py +++ b/Lib/test/test_uu.py @@ -4,12 +4,13 @@ """ import unittest -from test.support import os_helper +from test.support import os_helper, warnings_helper + +uu = warnings_helper.import_deprecated("uu") import os import stat import sys -import uu import io plaintext = b"The symbols on top of your keyboard are !@#$%^&*()_+|~\n" @@ -57,8 +58,6 @@ def encodedtextwrapped(mode, filename, backtick=False): class UUTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode(self): inp = io.BytesIO(plaintext) out = io.BytesIO() @@ -75,6 +74,7 @@ def test_encode(self): with self.assertRaises(TypeError): uu.encode(inp, out, "t1", 0o644, True) + @os_helper.skip_unless_working_chmod def test_decode(self): for backtick in True, False: inp = io.BytesIO(encodedtextwrapped(0o666, "t1", backtick=backtick)) @@ -138,8 +138,6 @@ def test_garbage_padding(self): decoded = codecs.decode(encodedtext, "uu_codec") self.assertEqual(decoded, plaintext) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_newlines_escaped(self): # Test newlines are escaped with uu.encode inp = io.BytesIO(plaintext) @@ -149,6 +147,34 @@ def test_newlines_escaped(self): uu.encode(inp, out, filename) self.assertIn(safefilename, out.getvalue()) + def test_no_directory_traversal(self): + relative_bad = b"""\ +begin 644 ../../../../../../../../tmp/test1 +$86)C"@`` +` +end +""" + with self.assertRaisesRegex(uu.Error, 'directory'): + uu.decode(io.BytesIO(relative_bad)) + if os.altsep: + relative_bad_bs = relative_bad.replace(b'/', b'\\') + with self.assertRaisesRegex(uu.Error, 'directory'): + uu.decode(io.BytesIO(relative_bad_bs)) + + absolute_bad = b"""\ +begin 644 /tmp/test2 +$86)C"@`` +` +end +""" + with self.assertRaisesRegex(uu.Error, 'directory'): + uu.decode(io.BytesIO(absolute_bad)) + if os.altsep: + absolute_bad_bs = absolute_bad.replace(b'/', b'\\') + with self.assertRaisesRegex(uu.Error, 'directory'): + uu.decode(io.BytesIO(absolute_bad_bs)) + + class UUStdIOTest(unittest.TestCase): def setUp(self): @@ -202,6 +228,8 @@ def test_encode(self): s = fout.read() self.assertEqual(s, encodedtextwrapped(0o644, self.tmpin)) + # decode() calls chmod() + @os_helper.skip_unless_working_chmod def test_decode(self): with open(self.tmpin, 'wb') as f: f.write(encodedtextwrapped(0o644, self.tmpout)) @@ -214,6 +242,7 @@ def test_decode(self): self.assertEqual(s, plaintext) # XXX is there an xp way to verify the mode? + @os_helper.skip_unless_working_chmod def test_decode_filename(self): with open(self.tmpin, 'wb') as f: f.write(encodedtextwrapped(0o644, self.tmpout)) @@ -224,6 +253,7 @@ def test_decode_filename(self): s = f.read() self.assertEqual(s, plaintext) + @os_helper.skip_unless_working_chmod def test_decodetwice(self): # Verify that decode() will refuse to overwrite an existing file with open(self.tmpin, 'wb') as f: @@ -234,8 +264,7 @@ def test_decodetwice(self): with open(self.tmpin, 'rb') as f: self.assertRaises(uu.Error, uu.decode, f) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @os_helper.skip_unless_working_chmod def test_decode_mode(self): # Verify that decode() will set the given mode for the out_file expected_mode = 0o444 diff --git a/Lib/uu.py b/Lib/uu.py index d68d29374a..26bb59ae07 100755 --- a/Lib/uu.py +++ b/Lib/uu.py @@ -26,20 +26,23 @@ """Implementation of the UUencode and UUdecode functions. -encode(in_file, out_file [,name, mode]) -decode(in_file [, out_file, mode]) +encode(in_file, out_file [,name, mode], *, backtick=False) +decode(in_file [, out_file, mode, quiet]) """ import binascii import os import sys +import warnings + +warnings._deprecated(__name__, remove=(3, 13)) __all__ = ["Error", "encode", "decode"] class Error(Exception): pass -def encode(in_file, out_file, name=None, mode=None): +def encode(in_file, out_file, name=None, mode=None, *, backtick=False): """Uuencode file""" # # If in_file is a pathname open it and change defaults @@ -73,15 +76,25 @@ def encode(in_file, out_file, name=None, mode=None): name = '-' if mode is None: mode = 0o666 + + # + # Remove newline chars from name + # + name = name.replace('\n','\\n') + name = name.replace('\r','\\r') + # # Write the data # out_file.write(('begin %o %s\n' % ((mode & 0o777), name)).encode("ascii")) data = in_file.read(45) while len(data) > 0: - out_file.write(binascii.b2a_uu(data)) + out_file.write(binascii.b2a_uu(data, backtick=backtick)) data = in_file.read(45) - out_file.write(b' \nend\n') + if backtick: + out_file.write(b'`\nend\n') + else: + out_file.write(b' \nend\n') finally: for f in opened_files: f.close() @@ -120,7 +133,14 @@ def decode(in_file, out_file=None, mode=None, quiet=False): # If the filename isn't ASCII, what's up with that?!? out_file = hdrfields[2].rstrip(b' \t\r\n\f').decode("ascii") if os.path.exists(out_file): - raise Error('Cannot overwrite existing file: %s' % out_file) + raise Error(f'Cannot overwrite existing file: {out_file}') + if (out_file.startswith(os.sep) or + f'..{os.sep}' in out_file or ( + os.altsep and + (out_file.startswith(os.altsep) or + f'..{os.altsep}' in out_file)) + ): + raise Error(f'Refusing to write to {out_file} due to directory traversal') if mode is None: mode = int(hdrfields[1], 8) # @@ -130,10 +150,7 @@ def decode(in_file, out_file=None, mode=None, quiet=False): out_file = sys.stdout.buffer elif isinstance(out_file, str): fp = open(out_file, 'wb') - try: - os.path.chmod(out_file, mode) - except AttributeError: - pass + os.chmod(out_file, mode) out_file = fp opened_files.append(out_file) # diff --git a/architecture/architecture.md b/architecture/architecture.md index 5b1ae9cc68..a59b6498bf 100644 --- a/architecture/architecture.md +++ b/architecture/architecture.md @@ -101,7 +101,7 @@ Part of the Python standard library that's implemented in Rust. The modules that ### Lib -Python side of the standard libary, copied over (with care) from CPython sourcecode. +Python side of the standard library, copied over (with care) from CPython sourcecode. #### Lib/test diff --git a/benches/benchmarks/pystone.py b/benches/benchmarks/pystone.py index 3faf675ae7..755b4ba85c 100644 --- a/benches/benchmarks/pystone.py +++ b/benches/benchmarks/pystone.py @@ -16,7 +16,7 @@ Version History: - Inofficial version 1.1.1 by Chris Arndt: + Unofficial version 1.1.1 by Chris Arndt: - Make it run under Python 2 and 3 by using "from __future__ import print_function". diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 0b0f2877c7..bf0fbe1dec 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -2941,7 +2941,8 @@ impl ToU32 for usize { #[cfg(test)] mod tests { use super::*; - use rustpython_parser as parser; + use rustpython_parser::ast::Suite; + use rustpython_parser::Parse; use rustpython_parser_core::source_code::LinearLocator; fn compile_exec(source: &str) -> CodeObject { @@ -2952,7 +2953,7 @@ mod tests { "source_path".to_owned(), "".to_owned(), ); - let ast = parser::parse_program(source, "").unwrap(); + let ast = Suite::parse(source, "").unwrap(); let ast = locator.fold(ast).unwrap(); let symbol_scope = SymbolTable::scan_program(&ast).unwrap(); compiler.compile_program(&ast, symbol_scope).unwrap(); diff --git a/examples/parse_folder.rs b/examples/parse_folder.rs index 7774b8afbf..7055a6f831 100644 --- a/examples/parse_folder.rs +++ b/examples/parse_folder.rs @@ -12,7 +12,7 @@ extern crate env_logger; extern crate log; use clap::{App, Arg}; -use rustpython_parser::{self as parser, ast}; +use rustpython_parser::{self as parser, ast, Parse}; use std::{ path::Path, time::{Duration, Instant}, @@ -85,8 +85,8 @@ fn parse_python_file(filename: &Path) -> ParsedFile { }, Ok(source) => { let num_lines = source.lines().count(); - let result = parser::parse_program(&source, &filename.to_string_lossy()) - .map_err(|e| e.to_string()); + let result = + ast::Suite::parse(&source, &filename.to_string_lossy()).map_err(|e| e.to_string()); ParsedFile { // filename: Box::new(filename.to_path_buf()), // code: source.to_string(), diff --git a/stdlib/src/csv.rs b/stdlib/src/csv.rs index bee3fd5faa..96aa1c1fe0 100644 --- a/stdlib/src/csv.rs +++ b/stdlib/src/csv.rs @@ -4,15 +4,18 @@ pub(crate) use _csv::make_module; mod _csv { use crate::common::lock::PyMutex; use crate::vm::{ - builtins::{PyStr, PyTypeRef}, - function::{ArgIterable, ArgumentError, FromArgs, FuncArgs}, - match_class, + builtins::{PyBaseExceptionRef, PyInt, PyNone, PyStr, PyType, PyTypeError, PyTypeRef}, + function::{ArgIterable, ArgumentError, FromArgs, FuncArgs, OptionalArg}, protocol::{PyIter, PyIterReturn}, - types::{IterNext, Iterable, SelfIter}, - AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, + types::{Constructor, IterNext, Iterable, SelfIter}, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; + use csv_core::Terminator; use itertools::{self, Itertools}; - use std::fmt; + use once_cell::sync::Lazy; + use parking_lot::Mutex; + use rustpython_vm::match_class; + use std::{collections::HashMap, fmt}; #[pyattr] const QUOTE_MINIMAL: i32 = QuoteStyle::Minimal as i32; @@ -22,6 +25,12 @@ mod _csv { const QUOTE_NONNUMERIC: i32 = QuoteStyle::Nonnumeric as i32; #[pyattr] const QUOTE_NONE: i32 = QuoteStyle::None as i32; + #[pyattr] + const QUOTE_STRINGS: i32 = QuoteStyle::Strings as i32; + #[pyattr] + const QUOTE_NOTNULL: i32 = QuoteStyle::Notnull as i32; + #[pyattr(name = "__version__")] + const __VERSION__: &str = "1.0"; #[pyattr(name = "Error", once)] fn error(vm: &VirtualMachine) -> PyTypeRef { @@ -32,13 +41,334 @@ mod _csv { ) } + static GLOBAL_HASHMAP: Lazy>> = Lazy::new(|| { + let m = HashMap::new(); + Mutex::new(m) + }); + static GLOBAL_FIELD_LIMIT: Lazy> = Lazy::new(|| Mutex::new(131072)); + + fn new_csv_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { + vm.new_exception_msg(super::_csv::error(vm), msg) + } + + #[pyattr] + #[pyclass(module = "csv", name = "Dialect")] + #[derive(Debug, PyPayload, Clone, Copy)] + struct PyDialect { + delimiter: u8, + quotechar: Option, + escapechar: Option, + doublequote: bool, + skipinitialspace: bool, + lineterminator: csv_core::Terminator, + quoting: QuoteStyle, + strict: bool, + } + impl Constructor for PyDialect { + type Args = PyObjectRef; + + fn py_new(cls: PyTypeRef, ctx: Self::Args, vm: &VirtualMachine) -> PyResult { + PyDialect::try_from_object(vm, ctx)? + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + #[pyclass(with(Constructor))] + impl PyDialect { + #[pygetset] + fn delimiter(&self, vm: &VirtualMachine) -> PyRef { + vm.ctx.new_str(format!("{}", self.delimiter as char)) + } + #[pygetset] + fn quotechar(&self, vm: &VirtualMachine) -> Option> { + Some(vm.ctx.new_str(format!("{}", self.quotechar? as char))) + } + #[pygetset] + fn doublequote(&self) -> bool { + self.doublequote + } + #[pygetset] + fn skipinitialspace(&self) -> bool { + self.skipinitialspace + } + #[pygetset] + fn lineterminator(&self, vm: &VirtualMachine) -> PyRef { + match self.lineterminator { + Terminator::CRLF => vm.ctx.new_str("\r\n".to_string()).to_owned(), + Terminator::Any(t) => vm.ctx.new_str(format!("{}", t as char)).to_owned(), + _ => unreachable!(), + } + } + #[pygetset] + fn quoting(&self) -> isize { + self.quoting.into() + } + #[pygetset] + fn escapechar(&self, vm: &VirtualMachine) -> Option> { + Some(vm.ctx.new_str(format!("{}", self.escapechar? as char))) + } + #[pygetset(name = "strict")] + fn get_strict(&self) -> bool { + self.strict + } + } + /// Parses the delimiter from a Python object and returns its ASCII value. + /// + /// This function attempts to extract the 'delimiter' attribute from the given Python object and ensures that the attribute is a single-character string. If successful, it returns the ASCII value of the character. If the attribute is not a single-character string, an error is returned. + /// + /// # Arguments + /// + /// * `vm` - A reference to the VirtualMachine, used for executing Python code and manipulating Python objects. + /// * `obj` - A reference to the PyObjectRef from which the 'delimiter' attribute is to be parsed. + /// + /// # Returns + /// + /// If successful, returns a `PyResult` representing the ASCII value of the 'delimiter' attribute. If unsuccessful, returns a `PyResult` containing an error message. + /// + /// # Errors + /// + /// This function can return the following errors: + /// + /// * If the 'delimiter' attribute is not a single-character string, a type error is returned. + /// * If the 'obj' is not of string type and does not have a 'delimiter' attribute, a type error is returned. + fn parse_delimiter_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + if let Ok(attr) = obj.get_attr("delimiter", vm) { + parse_delimiter_from_obj(vm, &attr) + } else { + match_class!(match obj.clone() { + s @ PyStr => { + Ok(s.as_str().bytes().exactly_one().map_err(|_| { + let msg = r#""delimiter" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?) + } + attr => { + let msg = format!("\"delimiter\" must be string, not {}", attr.class()); + Err(vm.new_type_error(msg)) + } + }) + } + } + fn parse_quotechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + match_class!(match obj.get_attr("quotechar", vm)? { + s @ PyStr => { + Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { + vm.new_exception_msg( + super::_csv::error(vm), + r#""quotechar" must be a 1-character string"#.to_owned(), + ) + })?)) + } + _n @ PyNone => { + Ok(None) + } + _ => { + Err(vm.new_exception_msg( + super::_csv::error(vm), + r#""quotechar" must be string or None, not int"#.to_owned(), + )) + } + }) + } + fn parse_escapechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + match_class!(match obj.get_attr("escapechar", vm)? { + s @ PyStr => { + Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { + vm.new_exception_msg( + super::_csv::error(vm), + r#""escapechar" must be a 1-character string"#.to_owned(), + ) + })?)) + } + _n @ PyNone => { + Ok(None) + } + attr => { + let msg = format!( + "\"escapechar\" must be string or None, not {}", + attr.class() + ); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + fn prase_lineterminator_from_obj( + vm: &VirtualMachine, + obj: &PyObjectRef, + ) -> PyResult { + match_class!(match obj.get_attr("lineterminator", vm)? { + s @ PyStr => { + Ok(if s.as_str().as_bytes().eq(b"\r\n") { + csv_core::Terminator::CRLF + } else if let Some(t) = s.as_str().as_bytes().first() { + // Due to limitations in the current implementation within csv_core + // the support for multiple characters in lineterminator is not complete. + // only capture the first character + csv_core::Terminator::Any(*t) + } else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + r#""lineterminator" must be a string"#.to_owned(), + )); + }) + } + _ => { + let msg = "\"lineterminator\" must be a string".to_string(); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + fn prase_quoting_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + match_class!(match obj.get_attr("quoting", vm)? { + i @ PyInt => { + Ok(i.try_to_primitive::(vm)?.try_into().map_err(|_| { + let msg = r#"bad "quoting" value"#; + vm.new_type_error(msg.to_owned()) + })?) + } + attr => { + let msg = format!("\"quoting\" must be string or None, not {}", attr.class()); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + impl TryFromObject for PyDialect { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let delimiter = parse_delimiter_from_obj(vm, &obj)?; + let quotechar = parse_quotechar_from_obj(vm, &obj)?; + let escapechar = parse_escapechar_from_obj(vm, &obj)?; + let doublequote = obj.get_attr("doublequote", vm)?.try_to_bool(vm)?; + let skipinitialspace = obj.get_attr("skipinitialspace", vm)?.try_to_bool(vm)?; + let lineterminator = prase_lineterminator_from_obj(vm, &obj)?; + let quoting = prase_quoting_from_obj(vm, &obj)?; + let strict = if let Ok(t) = obj.get_attr("strict", vm) { + t.try_to_bool(vm).unwrap_or(false) + } else { + false + }; + + Ok(Self { + delimiter, + quotechar, + escapechar, + doublequote, + skipinitialspace, + lineterminator, + quoting, + strict, + }) + } + } + + #[pyfunction] + fn register_dialect( + name: PyObjectRef, + dialect: OptionalArg, + opts: FormatOptions, + // TODO: handle quote style, etc + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_type_error("argument 0 must be a string".to_string())); + }; + let mut dialect = match dialect { + OptionalArg::Present(d) => PyDialect::try_from_object(vm, d) + .map_err(|_| vm.new_type_error("argument 1 must be a dialect object".to_owned()))?, + OptionalArg::Missing => opts.result(vm)?, + }; + opts.update_pydialect(&mut dialect); + GLOBAL_HASHMAP + .lock() + .insert(name.as_str().to_owned(), dialect); + Ok(()) + } + + #[pyfunction] + fn get_dialect( + name: PyObjectRef, + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + format!("argument 0 must be a string, not '{}'", name.class()), + )); + }; + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name.as_str()) { + return Ok(*dialect); + } + Err(vm.new_exception_msg(super::_csv::error(vm), "unknown dialect".to_string())) + } + + #[pyfunction] + fn unregister_dialect( + name: PyObjectRef, + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + format!("argument 0 must be a string, not '{}'", name.class()), + )); + }; + let mut g = GLOBAL_HASHMAP.lock(); + if let Some(_removed) = g.remove(name.as_str()) { + return Ok(()); + } + Err(vm.new_exception_msg(super::_csv::error(vm), "unknown dialect".to_string())) + } + + #[pyfunction] + fn list_dialects( + rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult { + if !rest.args.is_empty() || !rest.kwargs.is_empty() { + return Err(vm.new_type_error("too many argument".to_string())); + } + let g = GLOBAL_HASHMAP.lock(); + let t = g + .keys() + .cloned() + .map(|x| vm.ctx.new_str(x).into()) + .collect_vec(); + // .iter().map(|x| vm.ctx.new_str(x.clone()).into_pyobject(vm)).collect_vec(); + Ok(vm.ctx.new_list(t)) + } + + #[pyfunction] + fn field_size_limit(rest: FuncArgs, vm: &VirtualMachine) -> PyResult { + let old_size = GLOBAL_FIELD_LIMIT.lock().to_owned(); + if !rest.args.is_empty() { + let arg_len = rest.args.len(); + if arg_len != 1 { + return Err(vm.new_type_error( + format!( + "field_size_limit() takes at most 1 argument ({} given)", + arg_len + ) + .to_string(), + )); + } + let Ok(new_size) = rest.args.first().unwrap().try_int(vm) else { + return Err(vm.new_type_error("limit must be an integer".to_string())); + }; + *GLOBAL_FIELD_LIMIT.lock() = new_size.try_to_primitive::(vm)?; + } + Ok(old_size) + } + #[pyfunction] fn reader( iter: PyIter, options: FormatOptions, // TODO: handle quote style, etc _rest: FuncArgs, - _vm: &VirtualMachine, + vm: &VirtualMachine, ) -> PyResult { Ok(Reader { iter, @@ -46,7 +376,11 @@ mod _csv { buffer: vec![0; 1024], output_ends: vec![0; 16], reader: options.to_reader(), + skipinitialspace: options.get_skipinitialspace(), + delimiter: options.get_delimiter(), + line_num: 0, }), + dialect: options.result(vm)?, }) } @@ -72,6 +406,7 @@ mod _csv { buffer: vec![0; 1024], writer: options.to_writer(), }), + dialect: options.result(vm)?, }) } @@ -82,67 +417,482 @@ mod _csv { } #[repr(i32)] + #[derive(Debug, Clone, Copy)] pub enum QuoteStyle { Minimal = 0, All = 1, Nonnumeric = 2, None = 3, + Strings = 4, + Notnull = 5, + } + impl From for csv_core::QuoteStyle { + fn from(val: QuoteStyle) -> Self { + match val { + QuoteStyle::Minimal => csv_core::QuoteStyle::Always, + QuoteStyle::All => csv_core::QuoteStyle::Always, + QuoteStyle::Nonnumeric => csv_core::QuoteStyle::NonNumeric, + QuoteStyle::None => csv_core::QuoteStyle::Never, + QuoteStyle::Strings => todo!(), + QuoteStyle::Notnull => todo!(), + } + } + } + impl TryFromObject for QuoteStyle { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let num = obj.try_int(vm)?.try_to_primitive::(vm)?; + num.try_into().map_err(|_| { + vm.new_value_error( + "can not convert to QuoteStyle enum from input argument".to_string(), + ) + }) + } + } + impl TryFrom for QuoteStyle { + type Error = PyTypeError; + fn try_from(num: isize) -> Result { + match num { + 0 => Ok(QuoteStyle::Minimal), + 1 => Ok(QuoteStyle::All), + 2 => Ok(QuoteStyle::Nonnumeric), + 3 => Ok(QuoteStyle::None), + 4 => Ok(QuoteStyle::Strings), + 5 => Ok(QuoteStyle::Notnull), + _ => Err(PyTypeError {}), + } + } + } + impl From for isize { + fn from(val: QuoteStyle) -> Self { + match val { + QuoteStyle::Minimal => 0, + QuoteStyle::All => 1, + QuoteStyle::Nonnumeric => 2, + QuoteStyle::None => 3, + QuoteStyle::Strings => 4, + QuoteStyle::Notnull => 5, + } + } + } + + enum DialectItem { + Str(String), + Obj(PyDialect), + None, } struct FormatOptions { - delimiter: u8, - quotechar: u8, + dialect: DialectItem, + delimiter: Option, + quotechar: Option>, + escapechar: Option, + doublequote: Option, + skipinitialspace: Option, + lineterminator: Option, + quoting: Option, + strict: Option, + } + impl Default for FormatOptions { + fn default() -> Self { + FormatOptions { + dialect: DialectItem::None, + delimiter: None, + quotechar: None, + escapechar: None, + doublequote: None, + skipinitialspace: None, + lineterminator: None, + quoting: None, + strict: None, + } + } + } + /// prase a dialect item from a Python argument and returns a `DialectItem` or an `ArgumentError`. + /// + /// This function takes a reference to the VirtualMachine and a PyObjectRef as input and attempts to parse a dialect item from the provided Python argument. It returns a `DialectItem` if successful, or an `ArgumentError` if unsuccessful. + /// + /// # Arguments + /// + /// * `vm` - A reference to the VirtualMachine, used for executing Python code and manipulating Python objects. + /// * `obj` - The PyObjectRef from which the dialect item is to be parsed. + /// + /// # Returns + /// + /// If successful, returns a `Result` representing the parsed dialect item. If unsuccessful, returns an `ArgumentError`. + /// + /// # Errors + /// + /// This function can return the following errors: + /// + /// * If the provided object is a PyStr, it returns a `DialectItem::Str` containing the string value. + /// * If the provided object is PyNone, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + /// * If the provided object is a PyType, it attempts to create a PyDialect from the object and returns a `DialectItem::Obj` containing the PyDialect if successful. If unsuccessful, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + /// * If the provided object is none of the above types, it attempts to create a PyDialect from the object and returns a `DialectItem::Obj` containing the PyDialect if successful. If unsuccessful, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + fn prase_dialect_item_from_arg( + vm: &VirtualMachine, + obj: PyObjectRef, + ) -> Result { + match_class!(match obj { + s @ PyStr => { + Ok(DialectItem::Str(s.as_str().to_string())) + } + PyNone => { + Err(ArgumentError::InvalidKeywordArgument("dialect".to_string())) + } + t @ PyType => { + let temp = t + .as_object() + .call(vec![], vm) + .map_err(|_e| ArgumentError::InvalidKeywordArgument("dialect".to_string()))?; + Ok(DialectItem::Obj( + PyDialect::try_from_object(vm, temp).map_err(|_| { + ArgumentError::InvalidKeywordArgument("dialect".to_string()) + })?, + )) + } + obj => { + if let Ok(cur_dialect_item) = PyDialect::try_from_object(vm, obj) { + Ok(DialectItem::Obj(cur_dialect_item)) + } else { + let msg = "dialect".to_string(); + Err(ArgumentError::InvalidKeywordArgument(msg)) + } + } + }) } impl FromArgs for FormatOptions { fn from_args(vm: &VirtualMachine, args: &mut FuncArgs) -> Result { - let delimiter = if let Some(delimiter) = args.kwargs.remove("delimiter") { - delimiter - .try_to_value::<&str>(vm)? - .bytes() - .exactly_one() - .map_err(|_| { - let msg = r#""delimiter" must be a 1-character string"#; - vm.new_type_error(msg.to_owned()) - })? + let mut res = FormatOptions::default(); + if let Some(dialect) = args.kwargs.remove("dialect") { + res.dialect = prase_dialect_item_from_arg(vm, dialect)?; + } else if let Some(dialect) = args.args.first() { + res.dialect = prase_dialect_item_from_arg(vm, dialect.clone())?; } else { - b',' + res.dialect = DialectItem::None; }; - let quotechar = if let Some(quotechar) = args.kwargs.remove("quotechar") { - quotechar - .try_to_value::<&str>(vm)? - .bytes() - .exactly_one() - .map_err(|_| { + if let Some(delimiter) = args.kwargs.remove("delimiter") { + res.delimiter = Some(parse_delimiter_from_obj(vm, &delimiter)?); + } + + if let Some(escapechar) = args.kwargs.remove("escapechar") { + res.escapechar = match_class!(match escapechar { + s @ PyStr => Some(s.as_str().bytes().exactly_one().map_err(|_| { + let msg = r#""escapechar" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?), + _ => None, + }) + }; + if let Some(lineterminator) = args.kwargs.remove("lineterminator") { + res.lineterminator = Some(csv_core::Terminator::Any( + lineterminator + .try_to_value::<&str>(vm)? + .bytes() + .exactly_one() + .map_err(|_| { + let msg = r#""lineterminator" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?, + )) + }; + if let Some(doublequote) = args.kwargs.remove("doublequote") { + res.doublequote = Some(doublequote.try_to_bool(vm).map_err(|_| { + let msg = r#""doublequote" must be a bool"#; + vm.new_type_error(msg.to_owned()) + })?) + }; + if let Some(skipinitialspace) = args.kwargs.remove("skipinitialspace") { + res.skipinitialspace = Some(skipinitialspace.try_to_bool(vm).map_err(|_| { + let msg = r#""skipinitialspace" must be a bool"#; + vm.new_type_error(msg.to_owned()) + })?) + }; + if let Some(quoting) = args.kwargs.remove("quoting") { + res.quoting = match_class!(match quoting { + i @ PyInt => + Some(i.try_to_primitive::(vm)?.try_into().map_err(|_e| { + ArgumentError::InvalidKeywordArgument("quoting".to_string()) + })?), + _ => { + // let msg = r#""quoting" must be a int enum"#; + return Err(ArgumentError::InvalidKeywordArgument("quoting".to_string())); + } + }); + }; + if let Some(quotechar) = args.kwargs.remove("quotechar") { + res.quotechar = match_class!(match quotechar { + s @ PyStr => Some(Some(s.as_str().bytes().exactly_one().map_err(|_| { let msg = r#""quotechar" must be a 1-character string"#; vm.new_type_error(msg.to_owned()) - })? - } else { - b'"' + })?)), + PyNone => { + if let Some(QuoteStyle::All) = res.quoting { + let msg = "quotechar must be set if quoting enabled"; + return Err(ArgumentError::Exception( + vm.new_type_error(msg.to_owned()), + )); + } + Some(None) + } + _o => { + let msg = r#"quotechar"#; + return Err( + rustpython_vm::function::ArgumentError::InvalidKeywordArgument( + msg.to_string(), + ), + ); + } + }) + }; + if let Some(strict) = args.kwargs.remove("strict") { + res.strict = Some(strict.try_to_bool(vm).map_err(|_| { + let msg = r#""strict" must be a int enum"#; + vm.new_type_error(msg.to_owned()) + })?) }; - Ok(FormatOptions { - delimiter, - quotechar, - }) + if let Some(last_arg) = args.kwargs.pop() { + let msg = format!( + r#"'{}' is an invalid keyword argument for this function"#, + last_arg.0 + ); + return Err(rustpython_vm::function::ArgumentError::InvalidKeywordArgument(msg)); + } + Ok(res) } } impl FormatOptions { + fn update_pydialect<'b>(&self, res: &'b mut PyDialect) -> &'b mut PyDialect { + macro_rules! check_and_fill { + ($res:ident, $e:ident) => {{ + if let Some(t) = self.$e { + $res.$e = t; + } + }}; + } + check_and_fill!(res, delimiter); + // check_and_fill!(res, quotechar); + check_and_fill!(res, delimiter); + check_and_fill!(res, doublequote); + check_and_fill!(res, skipinitialspace); + if let Some(t) = self.escapechar { + res.escapechar = Some(t); + }; + if let Some(t) = self.quotechar { + if let Some(u) = t { + res.quotechar = Some(u); + } else { + res.quotechar = None; + } + }; + check_and_fill!(res, quoting); + check_and_fill!(res, lineterminator); + check_and_fill!(res, strict); + res + } + + fn result(&self, vm: &VirtualMachine) -> PyResult { + match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut dialect = *dialect; + self.update_pydialect(&mut dialect); + Ok(dialect) + } else { + Err(new_csv_error(vm, format!("{} is not registed.", name))) + } + // TODO + // Maybe need to update the obj from HashMap + } + DialectItem::Obj(mut o) => { + self.update_pydialect(&mut o); + Ok(o) + } + DialectItem::None => { + let g = GLOBAL_HASHMAP.lock(); + let mut res = *g.get("excel").unwrap(); + self.update_pydialect(&mut res); + Ok(res) + } + } + } + fn get_skipinitialspace(&self) -> bool { + let mut skipinitialspace = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + dialect.skipinitialspace + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + false + } + } + DialectItem::Obj(obj) => obj.skipinitialspace, + _ => false, + }; + if let Some(attr) = self.skipinitialspace { + skipinitialspace = attr + } + skipinitialspace + } + fn get_delimiter(&self) -> u8 { + let mut delimiter = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + dialect.delimiter + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + b',' + } + } + DialectItem::Obj(obj) => obj.delimiter, + _ => b',', + }; + if let Some(attr) = self.delimiter { + delimiter = attr + } + delimiter + } fn to_reader(&self) -> csv_core::Reader { - csv_core::ReaderBuilder::new() - .delimiter(self.delimiter) - .quote(self.quotechar) - .terminator(csv_core::Terminator::CRLF) - .build() + let mut builder = csv_core::ReaderBuilder::new(); + let mut reader = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote); + if let Some(t) = dialect.quotechar { + builder = builder.quote(t); + } + builder + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + &mut builder + } + } + DialectItem::Obj(obj) => { + let mut builder = builder + .delimiter(obj.delimiter) + .double_quote(obj.doublequote); + if let Some(t) = obj.quotechar { + builder = builder.quote(t); + } + builder + } + _ => { + let name = "excel"; + let g = GLOBAL_HASHMAP.lock(); + let dialect = g.get(name).unwrap(); + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote); + if let Some(quotechar) = dialect.quotechar { + builder = builder.quote(quotechar); + } + builder + } + }; + + if let Some(t) = self.delimiter { + reader = reader.delimiter(t); + } + if let Some(t) = self.quotechar { + if let Some(u) = t { + reader = reader.quote(u); + } else { + reader = reader.quoting(false); + } + } else { + match self.quoting { + Some(QuoteStyle::None) => { + reader = reader.quoting(false); + } + // None => reader = reader.quoting(true), + _ => reader = reader.quoting(true), + } + } + + if let Some(t) = self.lineterminator { + reader = reader.terminator(t); + } + if let Some(t) = self.doublequote { + reader = reader.double_quote(t); + } + if self.escapechar.is_some() { + reader = reader.escape(self.escapechar); + } + reader = match self.lineterminator { + Some(u) => reader.terminator(u), + None => reader.terminator(Terminator::CRLF), + }; + reader.build() } fn to_writer(&self) -> csv_core::Writer { - csv_core::WriterBuilder::new() - .delimiter(self.delimiter) - .quote(self.quotechar) - .terminator(csv_core::Terminator::CRLF) - .build() + let mut builder = csv_core::WriterBuilder::new(); + let mut writer = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote) + .terminator(dialect.lineterminator); + if let Some(t) = dialect.quotechar { + builder = builder.quote(t); + } + builder + + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + &mut builder + } + } + DialectItem::Obj(obj) => { + let mut builder = builder + .delimiter(obj.delimiter) + .double_quote(obj.doublequote) + .terminator(obj.lineterminator); + if let Some(t) = obj.quotechar { + builder = builder.quote(t); + } + builder + } + _ => &mut builder, + }; + if let Some(t) = self.delimiter { + writer = writer.delimiter(t); + } + if let Some(t) = self.quotechar { + if let Some(u) = t { + writer = writer.quote(u); + } else { + todo!() + } + } + if let Some(t) = self.doublequote { + writer = writer.double_quote(t); + } + writer = match self.lineterminator { + Some(u) => writer.terminator(u), + None => writer.terminator(Terminator::CRLF), + }; + if let Some(e) = self.escapechar { + writer = writer.escape(e); + } + if let Some(e) = self.quoting { + writer = writer.quote_style(e.into()); + } + writer.build() } } @@ -150,6 +900,9 @@ mod _csv { buffer: Vec, output_ends: Vec, reader: csv_core::Reader, + skipinitialspace: bool, + delimiter: u8, + line_num: u64, } #[pyclass(no_attr, module = "_csv", name = "reader", traverse)] @@ -158,6 +911,8 @@ mod _csv { iter: PyIter, #[pytraverse(skip)] state: PyMutex, + #[pytraverse(skip)] + dialect: PyDialect, } impl fmt::Debug for Reader { @@ -167,7 +922,16 @@ mod _csv { } #[pyclass(with(IterNext, Iterable))] - impl Reader {} + impl Reader { + #[pygetset] + fn line_num(&self) -> u64 { + self.state.lock().line_num + } + #[pygetset] + fn dialect(&self, _vm: &VirtualMachine) -> PyDialect { + self.dialect + } + } impl SelfIter for Reader {} impl IterNext for Reader { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -176,27 +940,55 @@ mod _csv { PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), }; let string = string.downcast::().map_err(|obj| { - vm.new_type_error(format!( + new_csv_error( + vm, + format!( "iterator should return strings, not {} (the file should be opened in text mode)", obj.class().name() - )) + ), + ) })?; let input = string.as_str().as_bytes(); - + if input.is_empty() || input.starts_with(b"\n") { + return Ok(PyIterReturn::Return(vm.ctx.new_list(vec![]).into())); + } let mut state = zelf.state.lock(); let ReadState { buffer, output_ends, reader, + skipinitialspace, + delimiter, + line_num, } = &mut *state; let mut input_offset = 0; let mut output_offset = 0; let mut output_ends_offset = 0; - + let field_limit = GLOBAL_FIELD_LIMIT.lock().to_owned(); + #[inline] + fn trim_spaces(input: &[u8]) -> &[u8] { + let trimmed_start = input.iter().position(|&x| x != b' ').unwrap_or(input.len()); + let trimmed_end = input + .iter() + .rposition(|&x| x != b' ') + .map(|i| i + 1) + .unwrap_or(0); + &input[trimmed_start..trimmed_end] + } + let input = if *skipinitialspace { + let t = input.split(|x| x == delimiter); + t.map(|x| { + let trimmed = trim_spaces(x); + String::from_utf8(trimmed.to_vec()).unwrap() + }) + .join(format!("{}", *delimiter as char).as_str()) + } else { + String::from_utf8(input.to_vec()).unwrap() + }; loop { let (res, nread, nwritten, nends) = reader.read_record( - &input[input_offset..], + input[input_offset..].as_bytes(), &mut buffer[output_offset..], &mut output_ends[output_ends_offset..], ); @@ -213,9 +1005,10 @@ mod _csv { } } } - let rest = &input[input_offset..]; + let rest = input[input_offset..].as_bytes(); if !rest.iter().all(|&c| matches!(c, b'\r' | b'\n')) { - return Err(vm.new_value_error( + return Err(new_csv_error( + vm, "new-line character seen in unquoted field - \ do you need to open the file in universal-newline mode?" .to_owned(), @@ -223,17 +1016,40 @@ mod _csv { } let mut prev_end = 0; - let out = output_ends[..output_ends_offset] + let out: Vec = output_ends[..output_ends_offset] .iter() .map(|&end| { let range = prev_end..end; + if range.len() > field_limit as usize { + return Err(new_csv_error(vm, "filed too long to read".to_string())); + } prev_end = end; - let s = std::str::from_utf8(&buffer[range]) + let s = std::str::from_utf8(&buffer[range.clone()]) // not sure if this is possible - the input was all strings .map_err(|_e| vm.new_unicode_decode_error("csv not utf8".to_owned()))?; - Ok(vm.ctx.new_str(s).into()) + // Rustpython TODO! + // Incomplete implementation + if let QuoteStyle::Nonnumeric = zelf.dialect.quoting { + if let Ok(t) = + String::from_utf8(trim_spaces(&buffer[range.clone()]).to_vec()) + .unwrap() + .parse::() + { + Ok(vm.ctx.new_int(t).into()) + } else { + Ok(vm.ctx.new_str(s).into()) + } + } else { + Ok(vm.ctx.new_str(s).into()) + } }) .collect::>()?; + // Removes the last null item before the line terminator, if there is a separator before the line terminator, + // todo! + // if out.last().unwrap().length(vm).unwrap() == 0 { + // out.pop(); + // } + *line_num += 1; Ok(PyIterReturn::Return(vm.ctx.new_list(out).into())) } } @@ -249,6 +1065,8 @@ mod _csv { write: PyObjectRef, #[pytraverse(skip)] state: PyMutex, + #[pytraverse(skip)] + dialect: PyDialect, } impl fmt::Debug for Writer { @@ -259,6 +1077,10 @@ mod _csv { #[pyclass] impl Writer { + #[pygetset(name = "dialect")] + fn get_dialect(&self, _vm: &VirtualMachine) -> PyDialect { + self.dialect + } #[pymethod] fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut state = self.state.lock(); @@ -277,7 +1099,10 @@ mod _csv { }}; } - let row = ArgIterable::try_from_object(vm, row)?; + let row = ArgIterable::try_from_object(vm, row.clone()).map_err(|_e| { + new_csv_error(vm, format!("\'{}\' object is not iterable", row.class())) + })?; + let mut first_flag = true; for field in row.iter(vm)? { let field: PyObjectRef = field?; let stringified; @@ -289,8 +1114,14 @@ mod _csv { stringified.as_str().as_bytes() } }); - let mut input_offset = 0; + if first_flag { + first_flag = false; + } else { + loop { + handle_res!(writer.delimiter(&mut buffer[buffer_offset..])); + } + } loop { let (res, nread, nwritten) = @@ -298,16 +1129,11 @@ mod _csv { input_offset += nread; handle_res!((res, nwritten)); } - - loop { - handle_res!(writer.delimiter(&mut buffer[buffer_offset..])); - } } loop { handle_res!(writer.terminator(&mut buffer[buffer_offset..])); } - let s = std::str::from_utf8(&buffer[..buffer_offset]) .map_err(|_| vm.new_unicode_decode_error("csv not utf8".to_owned()))?;