serpent.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. """
  2. ast.literal_eval() compatible object tree serialization.
  3. Serpent serializes an object tree into bytes (utf-8 encoded string) that can
  4. be decoded and then passed as-is to ast.literal_eval() to rebuild it as the
  5. original object tree. As such it is safe to send serpent data to other
  6. machines over the network for instance (because only 'safe' literals are
  7. encoded).
  8. Compatible with recent Python 3 versions
  9. Serpent handles several special Python types to make life easier:
  10. - bytes, bytearray, memoryview --> dict with base-64 encoded 'data' (or a bytes literal when bytes_repr=True)
  11. (use tobytes() to turn such a base-64 dict back into actual bytes)
  12. - uuid.UUID, datetime.{datetime, date, time, timedelta} --> appropriate string/number
  13. - decimal.Decimal --> string (to not lose precision)
  14. - array.array typecode 'u' --> string
  15. - array.array other typecode --> list
  16. - Exception --> dict with some fields of the exception (message, args)
  17. - collections module types --> mostly equivalent primitive types or dict
  18. - enums --> the value of the enum
  19. - all other types --> dict with __getstate__ or vars() of the object
  20. Notes:
  21. The serializer is not thread-safe. Make sure you're not making changes
  22. to the object tree that is being serialized, and don't use the same
  23. serializer in different threads.
  24. Because the serialized format is just valid Python source code, it can
  25. contain comments.
  26. Floats +inf and -inf are handled via a trick; float 'nan' cannot be handled
  27. and is represented by the special value: {'__class__':'float','value':'nan'}
  28. We chose not to encode it as just the string 'NaN' because that could cause
  29. memory issues when used in multiplications.
  30. Copyright by Irmen de Jong (irmen@razorvine.net)
  31. Software license: "MIT software license". See http://opensource.org/licenses/MIT
  32. """
  33. import ast
  34. import base64
  35. import sys
  36. import gc
  37. import decimal
  38. import datetime
  39. import uuid
  40. import array
  41. import math
  42. import numbers
  43. import codecs
  44. import collections
  45. import enum
  46. from collections.abc import KeysView, ValuesView, ItemsView
  __version__ = "1.40"
  # public API of this module
  __all__ = [
      "dump", "dumps",
      "load", "loads",
      "register_class", "unregister_class",
      "tobytes",
  ]
  def dumps(obj, indent=False, module_in_classname=False, bytes_repr=False):
      """
      Serialize an object tree and return it as utf-8 encoded bytes.
      indent = spread the output over multiple indented lines (default=False)
      module_in_classname = prefix class names with their module name instead of using the bare class name
      bytes_repr = emit bytes-like values as bytes literals instead of base-64 encoded dicts
      """
      serializer = Serializer(indent, module_in_classname, bytes_repr)
      return serializer.serialize(obj)
  def dump(obj, file, indent=False, module_in_classname=False, bytes_repr=False):
      """
      Serialize an object tree and write the resulting bytes to a file.
      The data written is the bytes returned by dumps(), so the file must accept bytes
      (e.g. be opened in binary mode).
      indent, module_in_classname and bytes_repr have the same meaning as for dumps().
      """
      serialized = dumps(obj, indent=indent, module_in_classname=module_in_classname, bytes_repr=bytes_repr)
      file.write(serialized)
  def loads(serialized_bytes):
      """
      Deserialize serpent data back to an object tree. Uses ast.literal_eval (safe).
      serialized_bytes = utf-8 encoded serpent data (bytes, bytearray, memoryview);
      an already decoded str is accepted as well.
      Raises ValueError if the data contains 0-bytes; errors raised by decoding or by
      ast.literal_eval (e.g. ValueError, SyntaxError) propagate for malformed input.
      The garbage collector's enabled/disabled state is the same afterwards as before the call.
      """
      if isinstance(serialized_bytes, str):
          serialized = serialized_bytes
      else:
          serialized = codecs.decode(serialized_bytes, "utf-8")
      if '\x00' in serialized:
          raise ValueError(
              "The serpent data contains 0-bytes so it cannot be parsed by ast.literal_eval. Has it been corrupted?")
      # The cyclic GC is paused while the (possibly large) object tree is built, presumably for speed.
      # Only switch it back on if it was enabled before, so callers that disabled it keep it disabled.
      gc_was_enabled = gc.isenabled()
      gc.disable()
      try:
          return ast.literal_eval(serialized)
      finally:
          if gc_was_enabled:
              gc.enable()
  def load(file):
      """
      Read serpent data from a file and deserialize it back to an object tree.
      The whole file is read with file.read() and passed to loads(); uses ast.literal_eval (safe).
      """
      return loads(file.read())
  def _ser_OrderedDict(obj, serializer, outputstream, indentlevel):
      # An OrderedDict is written as a dict holding its class name and its (key, value)
      # pairs as a list, so that the item order is kept in the output.
      if serializer.module_in_classname:
          class_name = "collections.OrderedDict"
      else:
          class_name = "OrderedDict"
      replacement = {"__class__": class_name, "items": list(obj.items())}
      serializer._serialize(replacement, outputstream, indentlevel)
  def _ser_DictView(obj, serializer, outputstream, indentlevel):
      # dict views (keys/values/items) have no literal syntax, so they are written as a list
      write_list = serializer.ser_builtins_list
      write_list(obj, outputstream, indentlevel)
  # Special-case serializers keyed by class. Serializer._serialize() tests them with isinstance()
  # in insertion order (first match wins), so this mapping must preserve insertion order.
  _special_classes_registry = collections.OrderedDict()


  def _reset_special_classes_registry():
      """Restore the registry of special-case serializers to its built-in default entries."""

      def _ser_Enum(obj, serializer, outputstream, indentlevel):
          # an enum member is written as its plain value
          serializer._serialize(obj.value, outputstream, indentlevel)

      _special_classes_registry.clear()
      for view_class in (KeysView, ValuesView, ItemsView):
          _special_classes_registry[view_class] = _ser_DictView
      _special_classes_registry[collections.OrderedDict] = _ser_OrderedDict
      _special_classes_registry[enum.Enum] = _ser_Enum


  _reset_special_classes_registry()
  def unregister_class(clazz):
      """Remove the special-case serializer registered for the given class; no-op if there is none."""
      _special_classes_registry.pop(clazz, None)
  def register_class(clazz, serializer):
      """
      Register a special-case serializer function for objects of the given class.
      Objects are matched with isinstance(), so instances of subclasses are included too.
      Entries are tried in registration order, so classes registered earlier take precedence.
      The function is called as serializer(obj, serpent_serializer, outputstream, indentlevel);
      it must append its serialized output to outputstream, and its return value is ignored.
      Registering an already registered class replaces its function.
      """
      registry = _special_classes_registry
      registry[clazz] = serializer
  # Types whose repr() already is a literal that ast.literal_eval understands.
  _repr_types = {
      str,
      int,
      bool,
      type(None),
  }

  # Types that are first converted to the mapped builtin type and then serialized as that.
  _translate_types = {
      collections.deque: list,
      collections.UserDict: dict,
      collections.UserList: list,
      collections.UserString: str,
  }

  # Bytes-like types; their serialized form is produced by _translate_byte_type().
  _bytes_types = bytes, bytearray, memoryview
  def _translate_byte_type(t, data, bytes_repr):
      """
      Return the serpent text for a bytes-like value.
      t = the exact type of data
      bytes_repr = True: a Python bytes literal such as b'ab' (t must be bytes, bytearray or memoryview);
                   False: a dict literal {'data': <base-64 text>, 'encoding': 'base64'}
      Raises TypeError when bytes_repr is set and t is not one of the three bytes types.
      """
      if not bytes_repr:
          encoded = base64.b64encode(data)
          if type(encoded) is not str:
              encoded = encoded.decode("ascii")
          return repr({"data": encoded, "encoding": "base64"})
      if t == bytes:
          return repr(data)
      if t == bytearray or t == memoryview:
          return repr(bytes(data))
      raise TypeError("invalid bytes type")
  def tobytes(obj):
      """
      Convert a serpent-encoded bytes dict back into actual bytes.
      obj = a dict with base-64 encoded 'data' and 'encoding' == 'base64' (what the serializer
            produces for bytes types by default), or a value that already is bytes, bytearray or
            memoryview, which is returned unmodified.
      Raises TypeError if obj is neither.
      Not needed when serializing with bytes_repr=True (available since serpent 1.40): then bytes
      are written as bytes literals directly. That is less efficient than the default base-64
      encoding, but a bit more convenient.
      """
      if isinstance(obj, _bytes_types):
          return obj
      is_encoded_dict = isinstance(obj, dict) and "data" in obj and obj.get("encoding") == "base64"
      if not is_encoded_dict:
          raise TypeError("argument is neither bytes nor serpent base64 encoded bytes dict")
      payload = obj["data"]
      try:
          return base64.b64decode(payload)
      except TypeError:
          # needed for certain older versions of pypy
          return base64.b64decode(payload.encode("ascii"))
  class Serializer(object):
      """
      Serialize an object tree to a byte stream.
      It is not thread-safe: make sure you're not making changes to the
      object tree that is being serialized, and don't use the same serializer
      across different threads.
      """
      # exact type -> function(serializer, obj, out, level), filled in after each ser_* method below.
      # Types that are not found here directly are resolved through their MRO in _serialize().
      dispatch = {}

      def __init__(self, indent=False, module_in_classname=False, bytes_repr=False):
          """
          Initialize the serializer.
          indent = indent the output over multiple lines (default=False)
          module_in_classname = include module prefix for class names or only use the class name itself
          bytes_repr = should the bytes literal value representation be used instead of base-64 encoding for bytes types?
          """
          self.indent = indent
          self.module_in_classname = module_in_classname
          self.serialized_obj_ids = set()
          self.special_classes_registry_copy = None
          # each nesting level uses more than one Python stack frame (_serialize plus a ser_* method),
          # so stay well below the interpreter's recursion limit
          self.maximum_level = min(sys.getrecursionlimit() // 5, 1000)
          self.bytes_repr = bytes_repr

      def serialize(self, obj):
          """
          Serialize the object tree to utf-8 encoded bytes.
          The output starts with a '# serpent utf-8 python3.2' comment line.
          Raises ValueError for circular references or too deep nesting, and TypeError
          for objects or dict/set keys that cannot be serialized.
          """
          # work on a snapshot of the registry so (un)registering classes meanwhile doesn't affect this run
          self.special_classes_registry_copy = _special_classes_registry.copy()
          out = ["# serpent utf-8 python3.2\n"]
          # The cyclic GC is paused while serializing; remember its state so that a caller
          # who disabled it on purpose does not get it switched back on.
          gc_was_enabled = gc.isenabled()
          gc.disable()
          try:
              self.serialized_obj_ids = set()
              self._serialize(obj, out, 0)
          finally:
              if gc_was_enabled:
                  gc.enable()
              self.special_classes_registry_copy = None
              self.serialized_obj_ids = set()
          return "".join(out).encode("utf-8")

      # builtin types dispatched directly, skipping the special-class and MRO lookups in _serialize()
      _shortcut_dispatch_types = {float, complex, tuple, list, dict, set, frozenset}

      def _serialize(self, obj, out, level):
          """
          Append the serialized text of obj to out (a list of str fragments).
          level is the current nesting depth; it drives indentation and the maximum_level check.
          """
          if level > self.maximum_level:
              raise ValueError(
                  "Object graph nesting too deep. Increase serializer.maximum_level if you think you need more,"
                  " but this may cause a RecursionError instead if Python's recursion limit doesn't allow it.")
          t = type(obj)
          if t in _bytes_types:
              out.append(_translate_byte_type(t, obj, self.bytes_repr))
              return
          if t in _translate_types:
              obj = _translate_types[t](obj)
              t = type(obj)
          if t in _repr_types:
              out.append(repr(obj))  # just a simple repr() is enough for these objects
              return
          if t in self._shortcut_dispatch_types:
              # we shortcut these builtins directly to the dispatch function to avoid type lookup overhead below
              return self.dispatch[t](self, obj, out, level)
          # check special registered types, in registration order (first isinstance match wins):
          special_classes = self.special_classes_registry_copy
          for clazz in special_classes:
              if isinstance(obj, clazz):
                  special_classes[clazz](obj, self, out, level)
                  return
          # serialize dispatch
          try:
              func = self.dispatch[t]
          except KeyError:
              # walk the MRO until we find a base class we recognise
              for type_ in t.__mro__:
                  if type_ in self.dispatch:
                      func = self.dispatch[type_]
                      break
              else:
                  # fall back to the default class serializer
                  func = Serializer.ser_default_class
          func(self, obj, out, level)

      def ser_builtins_float(self, float_obj, out, level):
          if math.isnan(float_obj):
              # there's no literal expression for a float NaN...
              out.append("{'__class__':'float','value':'nan'}")
          elif math.isinf(float_obj):
              # output a literal expression that overflows the float and results in +/-INF
              if float_obj > 0:
                  out.append("1e30000")
              else:
                  out.append("-1e30000")
          else:
              out.append(repr(float_obj))
      dispatch[float] = ser_builtins_float

      def ser_builtins_complex(self, complex_obj, out, level):
          out.append("(")
          self.ser_builtins_float(complex_obj.real, out, level)
          # The imaginary part supplies its own '-' when negative. Test the sign bit rather than
          # '>= 0': for -0.0 that would emit '+-0.0j', which ast.literal_eval rejects.
          # (NaN parts still have no literal representation.)
          if math.copysign(1.0, complex_obj.imag) > 0:
              out.append("+")
          self.ser_builtins_float(complex_obj.imag, out, level)
          out.append("j)")
      dispatch[complex] = ser_builtins_complex

      def ser_builtins_tuple(self, tuple_obj, out, level):
          append = out.append
          serialize = self._serialize
          if self.indent and tuple_obj:
              indent_chars = " " * level
              indent_chars_inside = indent_chars + " "
              append("(\n")
              for elt in tuple_obj:
                  append(indent_chars_inside)
                  serialize(elt, out, level + 1)
                  append(",\n")
              out[-1] = out[-1].rstrip()  # remove the last \n
              if len(tuple_obj) > 1:
                  del out[-1]  # undo the last , (a 1-tuple needs its trailing comma)
              append("\n" + indent_chars + ")")
          else:
              append("(")
              for elt in tuple_obj:
                  serialize(elt, out, level + 1)
                  append(",")
              if len(tuple_obj) > 1:
                  del out[-1]  # undo the last , (a 1-tuple needs its trailing comma)
              append(")")
      dispatch[tuple] = ser_builtins_tuple

      def ser_builtins_list(self, list_obj, out, level):
          if id(list_obj) in self.serialized_obj_ids:
              raise ValueError("Circular reference detected (list)")
          self.serialized_obj_ids.add(id(list_obj))
          append = out.append
          serialize = self._serialize
          if self.indent and list_obj:
              indent_chars = " " * level
              indent_chars_inside = indent_chars + " "
              append("[\n")
              for elt in list_obj:
                  append(indent_chars_inside)
                  serialize(elt, out, level + 1)
                  append(",\n")
              del out[-1]  # remove the last ,\n
              append("\n" + indent_chars + "]")
          else:
              append("[")
              for elt in list_obj:
                  serialize(elt, out, level + 1)
                  append(",")
              if list_obj:
                  del out[-1]  # remove the last ,
              append("]")
          self.serialized_obj_ids.discard(id(list_obj))
      dispatch[list] = ser_builtins_list

      def _check_hashable_type(self, t):
          # dict keys and set elements must be types that ast.literal_eval can rebuild as hashables
          if t not in (bool, bytes, str, tuple) and not issubclass(t, numbers.Number):
              if issubclass(t, enum.Enum):
                  return
              raise TypeError("one of the keys in a dict or set is not of a primitive hashable type: " +
                              str(t) + ". Use simple types as keys or use a list or tuple as container.")

      def ser_builtins_dict(self, dict_obj, out, level):
          if id(dict_obj) in self.serialized_obj_ids:
              raise ValueError("Circular reference detected (dict)")
          self.serialized_obj_ids.add(id(dict_obj))
          append = out.append
          serialize = self._serialize
          if self.indent and dict_obj:
              indent_chars = " " * level
              indent_chars_inside = indent_chars + " "
              append("{\n")
              dict_items = dict_obj.items()
              try:
                  sorted_items = sorted(dict_items)
              except TypeError:  # can occur when elements can't be ordered (Python 3.x)
                  sorted_items = dict_items
              for key, value in sorted_items:
                  append(indent_chars_inside)
                  self._check_hashable_type(type(key))
                  serialize(key, out, level + 1)
                  append(": ")
                  serialize(value, out, level + 1)
                  append(",\n")
              del out[-1]  # remove last ,\n
              append("\n" + indent_chars + "}")
          else:
              append("{")
              for key, value in dict_obj.items():
                  self._check_hashable_type(type(key))
                  serialize(key, out, level + 1)
                  append(":")
                  serialize(value, out, level + 1)
                  append(",")
              if dict_obj:
                  del out[-1]  # remove the last ,
              append("}")
          self.serialized_obj_ids.discard(id(dict_obj))
      dispatch[dict] = ser_builtins_dict

      def ser_builtins_set(self, set_obj, out, level):
          append = out.append
          serialize = self._serialize
          if self.indent and set_obj:
              indent_chars = " " * level
              indent_chars_inside = indent_chars + " "
              append("{\n")
              try:
                  sorted_elts = sorted(set_obj)
              except TypeError:  # can occur when elements can't be ordered (Python 3.x)
                  sorted_elts = set_obj
              for elt in sorted_elts:
                  append(indent_chars_inside)
                  self._check_hashable_type(type(elt))
                  serialize(elt, out, level + 1)
                  append(",\n")
              del out[-1]  # remove the last ,\n
              append("\n" + indent_chars + "}")
          elif set_obj:
              append("{")
              for elt in set_obj:
                  self._check_hashable_type(type(elt))
                  serialize(elt, out, level + 1)
                  append(",")
              del out[-1]  # remove the last ,
              append("}")
          else:
              # empty set literal doesn't exist unfortunately, replace with empty tuple
              self.ser_builtins_tuple((), out, level)
      dispatch[set] = ser_builtins_set

      def ser_builtins_frozenset(self, set_obj, out, level):
          self.ser_builtins_set(set_obj, out, level)
      dispatch[frozenset] = ser_builtins_frozenset

      def ser_decimal_Decimal(self, decimal_obj, out, level):
          # decimal is serialized as a string to avoid losing precision
          out.append(repr(str(decimal_obj)))
      dispatch[decimal.Decimal] = ser_decimal_Decimal

      def ser_datetime_datetime(self, datetime_obj, out, level):
          out.append(repr(datetime_obj.isoformat()))
      dispatch[datetime.datetime] = ser_datetime_datetime

      def ser_datetime_date(self, date_obj, out, level):
          out.append(repr(date_obj.isoformat()))
      dispatch[datetime.date] = ser_datetime_date

      def ser_datetime_timedelta(self, timedelta_obj, out, level):
          # a timedelta is written as its total number of seconds (a float)
          secs = timedelta_obj.total_seconds()
          out.append(repr(secs))
      dispatch[datetime.timedelta] = ser_datetime_timedelta

      def ser_datetime_time(self, time_obj, out, level):
          out.append(repr(str(time_obj)))
      dispatch[datetime.time] = ser_datetime_time

      def ser_uuid_UUID(self, uuid_obj, out, level):
          out.append(repr(str(uuid_obj)))
      dispatch[uuid.UUID] = ser_uuid_UUID

      def ser_exception_class(self, exc_obj, out, level):
          value = {
              "__class__": self.get_class_name(exc_obj),
              "__exception__": True,
              "args": exc_obj.args,
              "attributes": vars(exc_obj)  # add any custom attributes
          }
          self._serialize(value, out, level)
      dispatch[BaseException] = ser_exception_class

      def ser_array_array(self, array_obj, out, level):
          if array_obj.typecode == 'u':
              self._serialize(array_obj.tounicode(), out, level)
          else:
              self._serialize(array_obj.tolist(), out, level)
      dispatch[array.array] = ser_array_array

      def ser_default_class(self, obj, out, level):
          """
          Serialize an arbitrary object via its state.
          A __getstate__ defined by the object's class is used when there is one; the default
          object.__getstate__ that Python 3.11+ gives every object is ignored. A dict state is
          serialized as-is, any other state value is serialized as that value.
          Otherwise the state is vars(obj), or the __slots__ values, plus a '__class__' entry.
          Raises ValueError on a circular reference and TypeError if no state can be found.
          """
          if id(obj) in self.serialized_obj_ids:
              raise ValueError("Circular reference detected (class)")
          self.serialized_obj_ids.add(id(obj))
          try:
              class_getstate = getattr(type(obj), "__getstate__", None)
              if class_getstate is not None and class_getstate is not getattr(object, "__getstate__", None):
                  try:
                      value = obj.__getstate__()
                  except AttributeError:
                      value = self._default_state(obj)
                  else:
                      if value is None and isinstance(obj, tuple):
                          # collections.namedtuple specialcase (if it is not handled by the tuple serializer)
                          value = {
                              "__class__": self.get_class_name(obj),
                              "items": list(obj._asdict().items())
                          }
                      if isinstance(value, dict):
                          self.ser_builtins_dict(value, out, level)
                          return
              else:
                  value = self._default_state(obj)
              self._serialize(value, out, level)
          finally:
              self.serialized_obj_ids.discard(id(obj))

      def _default_state(self, obj):
          """
          Return a new dict with the state of obj taken from vars(obj), or from its __slots__
          when it has no __dict__, with a '__class__' entry added.
          Raises TypeError when obj has neither.
          """
          try:
              value = dict(vars(obj))  # make sure we can serialize anything that resembles a dict
          except TypeError:
              if not hasattr(obj, "__slots__"):
                  raise TypeError("don't know how to serialize class " +
                                  str(obj.__class__) + ". Give it vars() or an appropriate __getstate__")
              slots = obj.__slots__
              if isinstance(slots, str):
                  # __slots__ = "name" declares a single slot; don't iterate over its characters
                  slots = (slots,)
              value = {slot: getattr(obj, slot) for slot in slots}
          value["__class__"] = self.get_class_name(obj)
          return value

      def get_class_name(self, obj):
          if self.module_in_classname:
              return "%s.%s" % (obj.__class__.__module__, obj.__class__.__name__)
          else:
              return obj.__class__.__name__