diff --git a/tests/test_upgrade_pythoncapi.py b/tests/test_upgrade_pythoncapi.py index bf6f648..48bad12 100644 --- a/tests/test_upgrade_pythoncapi.py +++ b/tests/test_upgrade_pythoncapi.py @@ -43,33 +43,12 @@ def reformat(source): class Tests(unittest.TestCase): maxDiff = 80 * 30 - def _test_patch_file(self, tmp_dir): - # test Patcher.patcher() - source = """ - PyTypeObject* - test_type(PyObject *obj, PyTypeObject *type) - { - Py_TYPE(obj) = type; - return Py_TYPE(obj); - } - """ - expected = """ - #include "pythoncapi_compat.h" - - PyTypeObject* - test_type(PyObject *obj, PyTypeObject *type) - { - Py_SET_TYPE(obj, type); - return Py_TYPE(obj); - } - """ - source = reformat(source) - expected = reformat(expected) - + def _patch_file(self, source, tmp_dir=None): + # test Patcher.patcher() filename = tempfile.mktemp(suffix='.c', dir=tmp_dir) old_filename = filename + ".old" try: - with open(filename, "w", encoding="utf-8") as fp: + with open(filename, "w", encoding="utf-8", newline="") as fp: fp.write(source) old_stderr = sys.stderr @@ -93,10 +72,10 @@ def _test_patch_file(self, tmp_dir): sys.stderr = old_stderr sys.argv = old_argv - with open(filename, encoding="utf-8") as fp: + with open(filename, encoding="utf-8", newline="") as fp: new_contents = fp.read() - with open(old_filename, encoding="utf-8") as fp: + with open(old_filename, encoding="utf-8", newline="") as fp: old_contents = fp.read() finally: try: @@ -108,15 +87,53 @@ def _test_patch_file(self, tmp_dir): except FileNotFoundError: pass - self.assertEqual(new_contents, expected) self.assertEqual(old_contents, source) + return new_contents def test_patch_file(self): - self._test_patch_file(None) + source = """ + PyTypeObject* + test_type(PyObject *obj, PyTypeObject *type) + { + Py_TYPE(obj) = type; + return Py_TYPE(obj); + } + """ + expected = """ + #include "pythoncapi_compat.h" + + PyTypeObject* + test_type(PyObject *obj, PyTypeObject *type) + { + Py_SET_TYPE(obj, type); + return Py_TYPE(obj); + } + """ + source = reformat(source) + expected = reformat(expected) + + new_contents = self._patch_file(source) + self.assertEqual(new_contents, expected) - def test_patch_directory(self): with tempfile.TemporaryDirectory() as tmp_dir: - self._test_patch_file(tmp_dir) + new_contents = self._patch_file(source, tmp_dir) + self.assertEqual(new_contents, expected) + + def test_patch_file_preserve_newlines(self): + source = """ + Py_ssize_t get_size(PyVarObject *obj)\r\n\ + \n\ + { return obj->ob_size; }\r\ + """ + expected = """ + Py_ssize_t get_size(PyVarObject *obj)\r\n\ + \n\ + { return Py_SIZE(obj); }\r\ + """ + source = reformat(source) + expected = reformat(expected) + new_contents = self._patch_file(source) + self.assertEqual(new_contents, expected) def check_replace(self, source, expected, **kwargs): source = reformat(source) diff --git a/upgrade_pythoncapi.py b/upgrade_pythoncapi.py index 6640cec..68b5c53 100755 --- a/upgrade_pythoncapi.py +++ b/upgrade_pythoncapi.py @@ -554,13 +554,17 @@ def _get_operations(self, parser): return operations def add_line(self, content, line): - line = line + '\n' + # Use the first matching newline + match = re.search(r'(?:\r\n|\n|\r)', content) + newline = match.group(0) if match else '\n' + + line = line + newline # FIXME: tolerate trailing spaces if line not in content: # FIXME: add macro after the first header comment # FIXME: add macro after includes # FIXME: add macro after: #define PY_SSIZE_T_CLEAN - return line + '\n' + content + return line + newline + content else: return content @@ -601,7 +605,7 @@ def patch_file(self, filename): encoding = "utf-8" errors = "surrogateescape" - with open(filename, encoding=encoding, errors=errors) as fp: + with open(filename, encoding=encoding, errors=errors, newline="") as fp: old_contents = fp.read() new_contents, operations = self._patch(old_contents) @@ -620,7 +624,7 @@ def patch_file(self, filename): # If old_filename already exists, replace it os.replace(filename, old_filename) - with open(filename, "w", encoding=encoding, errors=errors) as fp: + with open(filename, "w", encoding=encoding, errors=errors, newline="") as fp: fp.write(new_contents) self.applied_operations |= set(operations)