Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cf76e1a

Browse files
committed
Issue #6218: Make io.BytesIO and io.StringIO picklable.
1 parent d2bb18b commit cf76e1a

5 files changed

Lines changed: 410 additions & 16 deletions

File tree

Lib/_pyio.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,11 @@ def __init__(self, initial_bytes=None):
765765
self._buffer = buf
766766
self._pos = 0
767767

768+
def __getstate__(self):
769+
if self.closed:
770+
raise ValueError("__getstate__ on closed file")
771+
return self.__dict__.copy()
772+
768773
def getvalue(self):
769774
"""Return the bytes value (contents) of the buffer
770775
"""

Lib/test/test_memoryio.py

Lines changed: 145 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import io
1010
import _pyio as pyio
1111
import sys
12+
import pickle
1213

1314
class MemorySeekTestMixin:
1415

@@ -346,6 +347,42 @@ def test_instance_dict_leak(self):
346347
memio = self.ioclass()
347348
memio.foo = 1
348349

350+
def test_pickling(self):
351+
buf = self.buftype("1234567890")
352+
memio = self.ioclass(buf)
353+
memio.foo = 42
354+
memio.seek(2)
355+
356+
class PickleTestMemIO(self.ioclass):
357+
def __init__(me, initvalue, foo):
358+
self.ioclass.__init__(me, initvalue)
359+
me.foo = foo
360+
# __getnewargs__ is undefined on purpose. This checks that PEP 307
361+
# is used to provide pickling support.
362+
363+
# Pickle expects the class to be on the module level. Here we use a
364+
# little hack to allow the PickleTestMemIO class to derive from
365+
# self.ioclass without having to define all combinations explicitly on
366+
# the module-level.
367+
import __main__
368+
PickleTestMemIO.__module__ = '__main__'
369+
__main__.PickleTestMemIO = PickleTestMemIO
370+
submemio = PickleTestMemIO(buf, 80)
371+
submemio.seek(2)
372+
373+
# We only support pickle protocol 2 and onward since we use extended
374+
# __reduce__ API of PEP 307 to provide pickling support.
375+
for proto in range(2, pickle.HIGHEST_PROTOCOL):
376+
for obj in (memio, submemio):
377+
obj2 = pickle.loads(pickle.dumps(obj, protocol=proto))
378+
self.assertEqual(obj.getvalue(), obj2.getvalue())
379+
self.assertEqual(obj.__class__, obj2.__class__)
380+
self.assertEqual(obj.foo, obj2.foo)
381+
self.assertEqual(obj.tell(), obj2.tell())
382+
obj.close()
383+
self.assertRaises(ValueError, pickle.dumps, obj, proto)
384+
del __main__.PickleTestMemIO
385+
349386

350387
class PyBytesIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
351388

@@ -425,13 +462,26 @@ def test_bytes_array(self):
425462
self.assertEqual(memio.getvalue(), buf)
426463

427464

428-
class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
429-
buftype = str
430-
ioclass = pyio.StringIO
431-
UnsupportedOperation = pyio.UnsupportedOperation
432-
EOF = ""
465+
class TextIOTestMixin:
433466

434-
# TextIO-specific behaviour.
467+
def test_relative_seek(self):
468+
memio = self.ioclass()
469+
470+
self.assertRaises(IOError, memio.seek, -1, 1)
471+
self.assertRaises(IOError, memio.seek, 3, 1)
472+
self.assertRaises(IOError, memio.seek, -3, 1)
473+
self.assertRaises(IOError, memio.seek, -1, 2)
474+
self.assertRaises(IOError, memio.seek, 1, 1)
475+
self.assertRaises(IOError, memio.seek, 1, 2)
476+
477+
def test_textio_properties(self):
478+
memio = self.ioclass()
479+
480+
# These are just dummy values but we nevertheless check them for fear
481+
# of unexpected breakage.
482+
self.assertTrue(memio.encoding is None)
483+
self.assertEqual(memio.errors, "strict")
484+
self.assertEqual(memio.line_buffering, False)
435485

436486
def test_newlines_property(self):
437487
memio = self.ioclass(newline=None)
@@ -513,15 +563,13 @@ def test_newline_lf(self):
513563
def test_newline_cr(self):
514564
# newline="\r"
515565
memio = self.ioclass("a\nb\r\nc\rd", newline="\r")
516-
memio.seek(0)
517566
self.assertEqual(memio.read(), "a\rb\r\rc\rd")
518567
memio.seek(0)
519568
self.assertEqual(list(memio), ["a\r", "b\r", "\r", "c\r", "d"])
520569

521570
def test_newline_crlf(self):
522571
# newline="\r\n"
523572
memio = self.ioclass("a\nb\r\nc\rd", newline="\r\n")
524-
memio.seek(0)
525573
self.assertEqual(memio.read(), "a\r\nb\r\r\nc\rd")
526574
memio.seek(0)
527575
self.assertEqual(list(memio), ["a\r\n", "b\r\r\n", "c\rd"])
@@ -539,10 +587,59 @@ def test_newline_argument(self):
539587
self.ioclass(newline=newline)
540588

541589

590+
class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin,
591+
TextIOTestMixin, unittest.TestCase):
592+
buftype = str
593+
ioclass = pyio.StringIO
594+
UnsupportedOperation = pyio.UnsupportedOperation
595+
EOF = ""
596+
597+
598+
class PyStringIOPickleTest(TextIOTestMixin, unittest.TestCase):
599+
"""Test if pickle restores properly the internal state of StringIO.
600+
"""
601+
buftype = str
602+
UnsupportedOperation = pyio.UnsupportedOperation
603+
EOF = ""
604+
605+
class ioclass(pyio.StringIO):
606+
def __new__(cls, *args, **kwargs):
607+
return pickle.loads(pickle.dumps(pyio.StringIO(*args, **kwargs)))
608+
def __init__(self, *args, **kwargs):
609+
pass
610+
611+
542612
class CBytesIOTest(PyBytesIOTest):
543613
ioclass = io.BytesIO
544614
UnsupportedOperation = io.UnsupportedOperation
545615

616+
def test_getstate(self):
617+
memio = self.ioclass()
618+
state = memio.__getstate__()
619+
self.assertEqual(len(state), 3)
620+
bytearray(state[0]) # Check if state[0] supports the buffer interface.
621+
self.assert_(isinstance(state[1], int))
622+
self.assert_(isinstance(state[2], dict) or state[2] is None)
623+
memio.close()
624+
self.assertRaises(ValueError, memio.__getstate__)
625+
626+
def test_setstate(self):
627+
# This checks whether __setstate__ does proper input validation.
628+
memio = self.ioclass()
629+
memio.__setstate__((b"no error", 0, None))
630+
memio.__setstate__((bytearray(b"no error"), 0, None))
631+
memio.__setstate__((b"no error", 0, {'spam': 3}))
632+
self.assertRaises(ValueError, memio.__setstate__, (b"", -1, None))
633+
self.assertRaises(TypeError, memio.__setstate__, ("unicode", 0, None))
634+
self.assertRaises(TypeError, memio.__setstate__, (b"", 0.0, None))
635+
self.assertRaises(TypeError, memio.__setstate__, (b"", 0, 0))
636+
self.assertRaises(TypeError, memio.__setstate__, (b"len-test", 0))
637+
self.assertRaises(TypeError, memio.__setstate__)
638+
self.assertRaises(TypeError, memio.__setstate__, 0)
639+
memio.close()
640+
self.assertRaises(ValueError, memio.__setstate__, (b"closed", 0, None))
641+
642+
546643
class CStringIOTest(PyStringIOTest):
547644
ioclass = io.StringIO
548645
UnsupportedOperation = io.UnsupportedOperation
@@ -561,9 +658,48 @@ def test_widechar(self):
561658
self.assertEqual(memio.tell(), len(buf) * 2)
562659
self.assertEqual(memio.getvalue(), buf + buf)
563660

661+
def test_getstate(self):
662+
memio = self.ioclass()
663+
state = memio.__getstate__()
664+
self.assertEqual(len(state), 4)
665+
self.assert_(isinstance(state[0], str))
666+
self.assert_(isinstance(state[1], str))
667+
self.assert_(isinstance(state[2], int))
668+
self.assert_(isinstance(state[3], dict) or state[3] is None)
669+
memio.close()
670+
self.assertRaises(ValueError, memio.__getstate__)
671+
672+
def test_setstate(self):
673+
# This checks whether __setstate__ does proper input validation.
674+
memio = self.ioclass()
675+
memio.__setstate__(("no error", "\n", 0, None))
676+
memio.__setstate__(("no error", "", 0, {'spam': 3}))
677+
self.assertRaises(ValueError, memio.__setstate__, ("", "f", 0, None))
678+
self.assertRaises(ValueError, memio.__setstate__, ("", "", -1, None))
679+
self.assertRaises(TypeError, memio.__setstate__, (b"", "", 0, None))
680+
self.assertRaises(TypeError, memio.__setstate__, ("", b"", 0, None))
681+
self.assertRaises(TypeError, memio.__setstate__, ("", "", 0.0, None))
682+
self.assertRaises(TypeError, memio.__setstate__, ("", "", 0, 0))
683+
self.assertRaises(TypeError, memio.__setstate__, ("len-test", 0))
684+
self.assertRaises(TypeError, memio.__setstate__)
685+
self.assertRaises(TypeError, memio.__setstate__, 0)
686+
memio.close()
687+
self.assertRaises(ValueError, memio.__setstate__, ("closed", "", 0, None))
688+
689+
690+
class CStringIOPickleTest(PyStringIOPickleTest):
691+
UnsupportedOperation = io.UnsupportedOperation
692+
693+
class ioclass(io.StringIO):
694+
def __new__(cls, *args, **kwargs):
695+
return pickle.loads(pickle.dumps(io.StringIO(*args, **kwargs)))
696+
def __init__(self, *args, **kwargs):
697+
pass
698+
564699

565700
def test_main():
566-
tests = [PyBytesIOTest, PyStringIOTest, CBytesIOTest, CStringIOTest]
701+
tests = [PyBytesIOTest, PyStringIOTest, CBytesIOTest, CStringIOTest,
702+
PyStringIOPickleTest, CStringIOPickleTest]
567703
support.run_unittest(*tests)
568704

569705
if __name__ == '__main__':

Misc/NEWS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ Library
8585
- Issue #4005: Fixed a crash of pydoc when there was a zip file present in
8686
sys.path.
8787

88+
- Issue #6218: io.StringIO and io.BytesIO instances are now picklable.
89+
8890
Extension Modules
8991
-----------------
9092

Modules/_io/bytesio.c

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,120 @@ bytesio_close(bytesio *self)
606606
Py_RETURN_NONE;
607607
}
608608

609+
/* Pickling support.
610+
611+
Note that only pickle protocol 2 and onward are supported since we use
612+
extended __reduce__ API of PEP 307 to make BytesIO instances picklable.
613+
614+
Providing support for protocol < 2 would require the __reduce_ex__ method
615+
which is notably long-winded when defined properly.
616+
617+
For BytesIO, the implementation would be similar to the one coded for
618+
object.__reduce_ex__, but slightly less general. To be more specific, we
619+
could call bytesio_getstate directly and avoid checking for the presence of
620+
a fallback __reduce__ method. However, we would still need a __newobj__
621+
function to use the efficient instance representation of PEP 307.
622+
*/
623+
624+
static PyObject *
625+
bytesio_getstate(bytesio *self)
626+
{
627+
PyObject *initvalue = bytesio_getvalue(self);
628+
PyObject *dict;
629+
PyObject *state;
630+
631+
if (initvalue == NULL)
632+
return NULL;
633+
if (self->dict == NULL) {
634+
Py_INCREF(Py_None);
635+
dict = Py_None;
636+
}
637+
else {
638+
dict = PyDict_Copy(self->dict);
639+
if (dict == NULL)
640+
return NULL;
641+
}
642+
643+
state = Py_BuildValue("(OnN)", initvalue, self->pos, dict);
644+
Py_DECREF(initvalue);
645+
return state;
646+
}
647+
648+
static PyObject *
649+
bytesio_setstate(bytesio *self, PyObject *state)
650+
{
651+
PyObject *result;
652+
PyObject *position_obj;
653+
PyObject *dict;
654+
Py_ssize_t pos;
655+
656+
assert(state != NULL);
657+
658+
/* We allow the state tuple to be longer than 3, because we may need
659+
someday to extend the object's state without breaking
660+
backward-compatibility. */
661+
if (!PyTuple_Check(state) || Py_SIZE(state) < 3) {
662+
PyErr_Format(PyExc_TypeError,
663+
"%.200s.__setstate__ argument should be 3-tuple, got %.200s",
664+
Py_TYPE(self)->tp_name, Py_TYPE(state)->tp_name);
665+
return NULL;
666+
}
667+
/* Reset the object to its default state. This is only needed to handle
668+
the case of repeated calls to __setstate__. */
669+
self->string_size = 0;
670+
self->pos = 0;
671+
672+
/* Set the value of the internal buffer. If state[0] does not support the
673+
buffer protocol, bytesio_write will raise the appropriate TypeError. */
674+
result = bytesio_write(self, PyTuple_GET_ITEM(state, 0));
675+
if (result == NULL)
676+
return NULL;
677+
Py_DECREF(result);
678+
679+
/* Carefully set the position value. Alternatively, we could use the seek
680+
method instead of modifying self->pos directly to better protect the
681+
object internal state against erroneous (or malicious) inputs. */
682+
position_obj = PyTuple_GET_ITEM(state, 1);
683+
if (!PyLong_Check(position_obj)) {
684+
PyErr_Format(PyExc_TypeError,
685+
"second item of state must be an integer, not %.200s",
686+
Py_TYPE(position_obj)->tp_name);
687+
return NULL;
688+
}
689+
pos = PyLong_AsSsize_t(position_obj);
690+
if (pos == -1 && PyErr_Occurred())
691+
return NULL;
692+
if (pos < 0) {
693+
PyErr_SetString(PyExc_ValueError,
694+
"position value cannot be negative");
695+
return NULL;
696+
}
697+
self->pos = pos;
698+
699+
/* Set the dictionary of the instance variables. */
700+
dict = PyTuple_GET_ITEM(state, 2);
701+
if (dict != Py_None) {
702+
if (!PyDict_Check(dict)) {
703+
PyErr_Format(PyExc_TypeError,
704+
"third item of state should be a dict, got a %.200s",
705+
Py_TYPE(dict)->tp_name);
706+
return NULL;
707+
}
708+
if (self->dict) {
709+
/* Alternatively, we could replace the internal dictionary
710+
completely. However, it seems more practical to just update it. */
711+
if (PyDict_Update(self->dict, dict) < 0)
712+
return NULL;
713+
}
714+
else {
715+
Py_INCREF(dict);
716+
self->dict = dict;
717+
}
718+
}
719+
720+
Py_RETURN_NONE;
721+
}
722+
609723
static void
610724
bytesio_dealloc(bytesio *self)
611725
{
@@ -630,9 +744,9 @@ bytesio_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
630744
if (self == NULL)
631745
return NULL;
632746

633-
self->string_size = 0;
634-
self->pos = 0;
635-
self->buf_size = 0;
747+
/* tp_alloc initializes all the fields to zero. So we don't have to
748+
initialize them here. */
749+
636750
self->buf = (char *)PyMem_Malloc(0);
637751
if (self->buf == NULL) {
638752
Py_DECREF(self);
@@ -705,6 +819,8 @@ static struct PyMethodDef bytesio_methods[] = {
705819
{"getvalue", (PyCFunction)bytesio_getvalue, METH_VARARGS, getval_doc},
706820
{"seek", (PyCFunction)bytesio_seek, METH_VARARGS, seek_doc},
707821
{"truncate", (PyCFunction)bytesio_truncate, METH_VARARGS, truncate_doc},
822+
{"__getstate__", (PyCFunction)bytesio_getstate, METH_NOARGS, NULL},
823+
{"__setstate__", (PyCFunction)bytesio_setstate, METH_O, NULL},
708824
{NULL, NULL} /* sentinel */
709825
};
710826

0 commit comments

Comments
 (0)