548 lines
20 KiB
Python
548 lines
20 KiB
Python
from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
|
|
NAME_MAPPING, REVERSE_NAME_MAPPING)
|
|
import builtins
|
|
import pickle
|
|
import io
|
|
import collections
|
|
import struct
|
|
import sys
|
|
import warnings
|
|
import weakref
|
|
|
|
import doctest
|
|
import unittest
|
|
from test import support
|
|
from test.support import import_helper
|
|
|
|
from test.pickletester import AbstractHookTests
|
|
from test.pickletester import AbstractUnpickleTests
|
|
from test.pickletester import AbstractPickleTests
|
|
from test.pickletester import AbstractPickleModuleTests
|
|
from test.pickletester import AbstractPersistentPicklerTests
|
|
from test.pickletester import AbstractIdentityPersistentPicklerTests
|
|
from test.pickletester import AbstractPicklerUnpicklerObjectTests
|
|
from test.pickletester import AbstractDispatchTableTests
|
|
from test.pickletester import AbstractCustomPicklerClass
|
|
from test.pickletester import BigmemPickleTests
|
|
|
|
# Probe for the optional C accelerator module.  The C-specific test classes
# further down are only defined when this flag is true.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
|
|
|
|
|
|
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Run the module-level API tests against the underscore-prefixed
    (pure-Python) entry points of the pickle module."""

    # staticmethod() keeps the plain functions from being bound to the
    # test-case instance when accessed as self.dump(...), etc.
    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)
    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler
|
|
|
|
|
|
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Run the unpickling tests against pickle._Unpickler."""

    unpickler = pickle._Unpickler
    # Exception types the shared tests accept for a corrupted stack and for
    # truncated input, respectively.
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* (bytes) through self.unpickler and return the object.

        Extra keyword arguments are forwarded to the unpickler constructor.
        """
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()
|
|
|
|
|
|
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Run the pickling round-trip tests through explicit Pickler/Unpickler
    objects (pickle._Pickler / pickle._Unpickler by default; subclasses
    override the two class attributes)."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with self.pickler and return the result as bytes.

        *proto* and any extra keyword arguments are forwarded to the pickler
        constructor.
        """
        f = io.BytesIO()
        p = self.pickler(f, proto, **kwargs)
        p.dump(arg)
        f.seek(0)
        return bytes(f.read())

    def loads(self, buf, **kwds):
        """Unpickle *buf* (bytes) with self.unpickler and return the object.

        Extra keyword arguments are forwarded to the unpickler constructor.
        """
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()
|
|
|
|
|
|
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Run the pickling, unpickling and big-memory tests through the public
    pickle.dumps() / pickle.loads() functions."""

    # Exception types the shared tests accept for a corrupted stack and for
    # truncated input, respectively.
    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None, **kwargs):
        """Return pickle.dumps(arg, protocol, **kwargs)."""
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Return pickle.loads(buf, **kwds)."""
        return pickle.loads(buf, **kwds)

    # Overriding the inherited test with None disables it for this class.
    # NOTE(review): presumably because dumps() exposes no file object that
    # a delayed writer could wrap -- confirm against test.pickletester.
    test_framed_write_sizes_with_delayed_writer = None
|
|
|
|
|
|
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent IDs through the test case.

    The host class must supply ``pickler`` and ``unpickler`` class attributes
    plus ``persistent_id(obj)`` and ``persistent_load(pid)`` methods; the
    throw-away subclasses created here delegate to those methods.
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg* with protocol *proto* and return the bytes.

        persistent_id() calls made by the pickler are forwarded to
        self.persistent_id().
        """
        class PersPickler(self.pickler):
            # ``subself`` is the pickler instance; ``self`` is the enclosing
            # test case captured by the closure.
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* (bytes) and return the object.

        persistent_load() calls made by the unpickler are forwarded to
        self.persistent_load(); extra keyword arguments go to the unpickler
        constructor.
        """
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
|
|
|
|
|
|
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Run the persistent-ID tests with pickle._Pickler / pickle._Unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler
|
|
|
|
|
|
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Run the identity persistent-ID tests with pickle._Pickler /
    pickle._Unpickler, plus CPython-only object-lifetime checks."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # A pickler whose subclass defines persistent_id() as an instance
        # method, classmethod or staticmethod must be freed as soon as the
        # last reference is dropped (the weakref goes dead right after
        # ``del``, with no explicit garbage collection).
        def check(Pickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f = io.BytesIO()
                pickler = Pickler(f, proto)
                pickler.dump('abc')
                self.assertEqual(self.loads(f.getvalue()), 'abc')
            pickler = Pickler(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            r = weakref.ref(pickler)
            del pickler
            self.assertIsNone(r())

        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_custom_pickler_dispatch_table_memleak(self):
        # See https://github.com/python/cpython/issues/89988

        class Pickler(self.pickler):
            def __init__(self, *args, **kwargs):
                # ``table`` is resolved from the enclosing function scope at
                # call time; it is bound below, before Pickler is created.
                self.dispatch_table = table
                super().__init__(*args, **kwargs)

        class DispatchTable:
            pass

        table = DispatchTable()
        pickler = Pickler(io.BytesIO())
        self.assertIs(pickler.dispatch_table, table)
        table_ref = weakref.ref(table)
        self.assertIsNotNone(table_ref())
        del pickler
        del table
        support.gc_collect()
        # With both strong references gone the table must have been freed.
        self.assertIsNone(table_ref())

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Same lifetime check as test_pickler_reference_cycle, for
        # unpicklers whose subclass defines persistent_load().
        def check(Unpickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(unpickler.load(), 'abc')
            unpickler = Unpickler(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            r = weakref.ref(unpickler)
            del unpickler
            self.assertIsNone(r())

        class PersUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)
|
|
|
|
|
|
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
    """Run the Pickler/Unpickler object tests with pickle._Pickler and
    pickle._Unpickler."""

    pickler_class = pickle._Pickler
    unpickler_class = pickle._Unpickler
|
|
|
|
|
|
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Run the dispatch-table tests with pickle._Pickler and a plain copy
    of pickle.dispatch_table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh copy of pickle.dispatch_table."""
        return pickle.dispatch_table.copy()
|
|
|
|
|
|
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Run the dispatch-table tests with pickle._Pickler and a ChainMap
    layered over pickle.dispatch_table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap with an empty dict in front of
        pickle.dispatch_table."""
        return collections.ChainMap({}, pickle.dispatch_table)
|
|
|
|
|
|
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Run the pickler hook tests with a custom pickler class built on
    pickle._Pickler."""

    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        pass
    pickler_class = CustomPyPicklerClass
|
|
|
|
|
|
# Everything below needs the _pickle extension module, so it is only defined
# when the import probe at the top of the file succeeded.
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        """Run the module-level API tests against the _pickle module."""
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        """Run the unpickling tests against _pickle.Unpickler."""
        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        """Round-trip tests: _pickle.Pickler paired with _pickle.Unpickler."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        """Persistent-ID tests with the _pickle classes."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        """Identity persistent-ID tests with the _pickle classes."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CDumpPickle_LoadPickle(PyPicklerTests):
        """Round-trip tests: dump with _pickle.Pickler, load with
        pickle._Unpickler."""
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        """Round-trip tests: dump with pickle._Pickler, load with
        _pickle.Unpickler."""
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
        """Pickler/Unpickler object tests with the _pickle classes."""
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # Assigning an invalid memo must raise instead of crashing.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        """Dispatch-table tests with pickle.Pickler and a plain copy of
        pickle.dispatch_table."""
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        """Dispatch-table tests with pickle.Pickler and a ChainMap layered
        over pickle.dispatch_table."""
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        """Pickler hook tests with a custom class built on _pickle.Pickler."""
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """Check the sizes reported for _pickle.Pickler and _pickle.Unpickler
        objects against sizes computed from struct format strings."""
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            def check_unpickler(data, memo_size, marks_size):
                # Load a pickle of *data* and compare the unpickler's size
                # with the expected memo-table and mark-stack capacities.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            def recurse(deep):
                # Build a list nested *deep* levels: 0, [0, 0], [[0, 0], [0, 0]], ...
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
|
|
|
|
|
|
# (module2, module3) pairs that test_reverse_import_mapping accepts even
# though REVERSE_IMPORT_MAPPING does not map module3 back to module2.
ALT_IMPORT_MAPPING = {
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
}

# (module2, name2, module3, name3) entries that test_reverse_name_mapping
# exempts from the exact reverse-mapping check.
ALT_NAME_MAPPING = {
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
|
|
|
|
def mapping(module, name):
    """Translate a ``(module, name)`` pair through the forward tables.

    An exact entry in NAME_MAPPING wins; otherwise only the module is
    translated through IMPORT_MAPPING.  A pair found in neither table is
    returned unchanged.
    """
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name
|
|
|
|
def reverse_mapping(module, name):
    """Translate a ``(module, name)`` pair through the reverse tables.

    An exact entry in REVERSE_NAME_MAPPING wins; otherwise only the module
    is translated through REVERSE_IMPORT_MAPPING.  A pair found in neither
    table is returned unchanged.
    """
    if (module, name) in REVERSE_NAME_MAPPING:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    elif module in REVERSE_IMPORT_MAPPING:
        module = REVERSE_IMPORT_MAPPING[module]
    return module, name
|
|
|
|
def getmodule(module):
    """Return the module object named *module*, importing it if necessary.

    DeprecationWarnings emitted during the import are shown only when
    ``support.verbose`` is set and ignored otherwise.

    Raises ImportError if the module cannot be imported.  An AttributeError
    raised while importing is converted to ImportError, with the original
    exception chained as its cause.
    """
    try:
        return sys.modules[module]
    except KeyError:
        pass  # Not imported yet; fall through and import it.

    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        message = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(message)
        # Callers only handle ImportError; keep the message and the cause
        # instead of raising a bare, context-free ImportError.
        raise ImportError(message) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
|
|
|
|
def getattribute(module, name):
    """Return the object reached by the dotted path *name* inside *module*.

    The module is obtained with getmodule(), and each dot-separated
    component of *name* is then looked up with getattr().
    """
    obj = getmodule(module)
    for n in name.split('.'):
        obj = getattr(obj, n)
    return obj
|
|
|
|
def get_exceptions(mod):
    """Yield ``(name, cls)`` for every attribute of *mod* that is a class
    derived from BaseException, in dir() order."""
    for name in dir(mod):
        attr = getattr(mod, name)
        if isinstance(attr, type) and issubclass(attr, BaseException):
            yield name, attr
|
|
|
|
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the _compat_pickle translation tables
    (IMPORT_MAPPING, NAME_MAPPING and their reverse counterparts)."""

    def test_import(self):
        # Try to import every Python 3 side module named by the tables;
        # modules that cannot be imported are tolerated.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Modules whose name starts with an underscore are exempt
                # from the round-trip check.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    # No module-level reverse entry: a name-level entry
                    # linking the same two modules is accepted instead.
                    if not any((module3, module2) == (m3, m2)
                               for (m3, _n3), (m2, _n2)
                               in REVERSE_NAME_MAPPING.items()):
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only checks that the target resolves when its module is
                # importable; the resolved object itself is not needed.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # These built-in exceptions are excluded from the mapping
                # checks below.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
|
|
|
|
|
|
def load_tests(loader, tests, pattern):
    """unittest ``load_tests`` hook: add this module's doctests to *tests*.

    *loader* and *pattern* are unused.  Returns the same *tests* suite that
    was passed in.
    """
    # DocTestSuite() with no argument collects doctests from the module
    # that calls it.
    tests.addTest(doctest.DocTestSuite())
    return tests
|
|
|
|
|
|
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|