converters.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. import datetime
  2. from decimal import Decimal
  3. import re
  4. import time
  5. from .err import ProgrammingError
  6. from .constants import FIELD_TYPE
  7. def escape_item(val, charset, mapping=None):
  8. if mapping is None:
  9. mapping = encoders
  10. encoder = mapping.get(type(val))
  11. # Fallback to default when no encoder found
  12. if not encoder:
  13. try:
  14. encoder = mapping[str]
  15. except KeyError:
  16. raise TypeError("no default type converter defined")
  17. if encoder in (escape_dict, escape_sequence):
  18. val = encoder(val, charset, mapping)
  19. else:
  20. val = encoder(val, mapping)
  21. return val
  22. def escape_dict(val, charset, mapping=None):
  23. n = {}
  24. for k, v in val.items():
  25. quoted = escape_item(v, charset, mapping)
  26. n[k] = quoted
  27. return n
  28. def escape_sequence(val, charset, mapping=None):
  29. n = []
  30. for item in val:
  31. quoted = escape_item(item, charset, mapping)
  32. n.append(quoted)
  33. return "(" + ",".join(n) + ")"
  34. def escape_set(val, charset, mapping=None):
  35. return ",".join([escape_item(x, charset, mapping) for x in val])
  36. def escape_bool(value, mapping=None):
  37. return str(int(value))
  38. def escape_int(value, mapping=None):
  39. return str(value)
  40. def escape_float(value, mapping=None):
  41. s = repr(value)
  42. if s in ("inf", "nan"):
  43. raise ProgrammingError("%s can not be used with MySQL" % s)
  44. if "e" not in s:
  45. s += "e0"
  46. return s
  47. _escape_table = [chr(x) for x in range(128)]
  48. _escape_table[0] = "\\0"
  49. _escape_table[ord("\\")] = "\\\\"
  50. _escape_table[ord("\n")] = "\\n"
  51. _escape_table[ord("\r")] = "\\r"
  52. _escape_table[ord("\032")] = "\\Z"
  53. _escape_table[ord('"')] = '\\"'
  54. _escape_table[ord("'")] = "\\'"
  55. def escape_string(value, mapping=None):
  56. """escapes *value* without adding quote.
  57. Value should be unicode
  58. """
  59. return value.translate(_escape_table)
  60. def escape_bytes_prefixed(value, mapping=None):
  61. return "_binary'%s'" % value.decode("ascii", "surrogateescape").translate(
  62. _escape_table
  63. )
  64. def escape_bytes(value, mapping=None):
  65. return "'%s'" % value.decode("ascii", "surrogateescape").translate(_escape_table)
  66. def escape_str(value, mapping=None):
  67. return "'%s'" % escape_string(str(value), mapping)
  68. def escape_None(value, mapping=None):
  69. return "NULL"
  70. def escape_timedelta(obj, mapping=None):
  71. seconds = int(obj.seconds) % 60
  72. minutes = int(obj.seconds // 60) % 60
  73. hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
  74. if obj.microseconds:
  75. fmt = "'{0:02d}:{1:02d}:{2:02d}.{3:06d}'"
  76. else:
  77. fmt = "'{0:02d}:{1:02d}:{2:02d}'"
  78. return fmt.format(hours, minutes, seconds, obj.microseconds)
  79. def escape_time(obj, mapping=None):
  80. if obj.microsecond:
  81. fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
  82. else:
  83. fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
  84. return fmt.format(obj)
  85. def escape_datetime(obj, mapping=None):
  86. if obj.microsecond:
  87. fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
  88. else:
  89. fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
  90. return fmt.format(obj)
  91. def escape_date(obj, mapping=None):
  92. fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
  93. return fmt.format(obj)
  94. def escape_struct_time(obj, mapping=None):
  95. return escape_datetime(datetime.datetime(*obj[:6]))
  96. def Decimal2Literal(o, d):
  97. return format(o, "f")
  98. def _convert_second_fraction(s):
  99. if not s:
  100. return 0
  101. # Pad zeros to ensure the fraction length in microseconds
  102. s = s.ljust(6, "0")
  103. return int(s[:6])
  104. DATETIME_RE = re.compile(
  105. r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?"
  106. )
  107. def convert_datetime(obj):
  108. """Returns a DATETIME or TIMESTAMP column value as a datetime object:
  109. >>> datetime_or_None('2007-02-25 23:06:20')
  110. datetime.datetime(2007, 2, 25, 23, 6, 20)
  111. >>> datetime_or_None('2007-02-25T23:06:20')
  112. datetime.datetime(2007, 2, 25, 23, 6, 20)
  113. Illegal values are returned as None:
  114. >>> datetime_or_None('2007-02-31T23:06:20') is None
  115. True
  116. >>> datetime_or_None('0000-00-00 00:00:00') is None
  117. True
  118. """
  119. if isinstance(obj, (bytes, bytearray)):
  120. obj = obj.decode("ascii")
  121. m = DATETIME_RE.match(obj)
  122. if not m:
  123. return convert_date(obj)
  124. try:
  125. groups = list(m.groups())
  126. groups[-1] = _convert_second_fraction(groups[-1])
  127. return datetime.datetime(*[int(x) for x in groups])
  128. except ValueError:
  129. return convert_date(obj)
  130. TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
  131. def convert_timedelta(obj):
  132. """Returns a TIME column as a timedelta object:
  133. >>> timedelta_or_None('25:06:17')
  134. datetime.timedelta(1, 3977)
  135. >>> timedelta_or_None('-25:06:17')
  136. datetime.timedelta(-2, 83177)
  137. Illegal values are returned as None:
  138. >>> timedelta_or_None('random crap') is None
  139. True
  140. Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
  141. can accept values as (+|-)DD HH:MM:SS. The latter format will not
  142. be parsed correctly by this function.
  143. """
  144. if isinstance(obj, (bytes, bytearray)):
  145. obj = obj.decode("ascii")
  146. m = TIMEDELTA_RE.match(obj)
  147. if not m:
  148. return obj
  149. try:
  150. groups = list(m.groups())
  151. groups[-1] = _convert_second_fraction(groups[-1])
  152. negate = -1 if groups[0] else 1
  153. hours, minutes, seconds, microseconds = groups[1:]
  154. tdelta = (
  155. datetime.timedelta(
  156. hours=int(hours),
  157. minutes=int(minutes),
  158. seconds=int(seconds),
  159. microseconds=int(microseconds),
  160. )
  161. * negate
  162. )
  163. return tdelta
  164. except ValueError:
  165. return obj
  166. TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
  167. def convert_time(obj):
  168. """Returns a TIME column as a time object:
  169. >>> time_or_None('15:06:17')
  170. datetime.time(15, 6, 17)
  171. Illegal values are returned as None:
  172. >>> time_or_None('-25:06:17') is None
  173. True
  174. >>> time_or_None('random crap') is None
  175. True
  176. Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
  177. can accept values as (+|-)DD HH:MM:SS. The latter format will not
  178. be parsed correctly by this function.
  179. Also note that MySQL's TIME column corresponds more closely to
  180. Python's timedelta and not time. However if you want TIME columns
  181. to be treated as time-of-day and not a time offset, then you can
  182. use set this function as the converter for FIELD_TYPE.TIME.
  183. """
  184. if isinstance(obj, (bytes, bytearray)):
  185. obj = obj.decode("ascii")
  186. m = TIME_RE.match(obj)
  187. if not m:
  188. return obj
  189. try:
  190. groups = list(m.groups())
  191. groups[-1] = _convert_second_fraction(groups[-1])
  192. hours, minutes, seconds, microseconds = groups
  193. return datetime.time(
  194. hour=int(hours),
  195. minute=int(minutes),
  196. second=int(seconds),
  197. microsecond=int(microseconds),
  198. )
  199. except ValueError:
  200. return obj
  201. def convert_date(obj):
  202. """Returns a DATE column as a date object:
  203. >>> date_or_None('2007-02-26')
  204. datetime.date(2007, 2, 26)
  205. Illegal values are returned as None:
  206. >>> date_or_None('2007-02-31') is None
  207. True
  208. >>> date_or_None('0000-00-00') is None
  209. True
  210. """
  211. if isinstance(obj, (bytes, bytearray)):
  212. obj = obj.decode("ascii")
  213. try:
  214. return datetime.date(*[int(x) for x in obj.split("-", 2)])
  215. except ValueError:
  216. return obj
  217. def through(x):
  218. return x
# BIT values are deliberately passed through unconverted. Decoding them to an
# integer would look like this:
#
#     def convert_bit(b):
#         b = "\x00" * (8 - len(b)) + b  # pad w/ zeroes
#         return struct.unpack(">Q", b)[0]
#
# The snippet above is right, but MySQLdb doesn't process bits,
# so we shouldn't either.
convert_bit = through
# Default Python-type -> encoder table used by escape_item() when no mapping
# is passed. Lookup is by exact type(val), so subclasses of these types do not
# match and fall back to the `str` entry.
encoders = {
    bool: escape_bool,
    int: escape_int,
    float: escape_float,
    str: escape_str,  # also the fallback for types not listed here
    bytes: escape_bytes,
    # All built-in containers below are rendered as a parenthesised,
    # comma-separated list; note that sets use escape_sequence, not escape_set.
    tuple: escape_sequence,
    list: escape_sequence,
    set: escape_sequence,
    frozenset: escape_sequence,
    dict: escape_dict,
    type(None): escape_None,
    datetime.date: escape_date,
    datetime.datetime: escape_datetime,
    datetime.timedelta: escape_timedelta,
    datetime.time: escape_time,
    time.struct_time: escape_struct_time,
    Decimal: Decimal2Literal,
}
# FIELD_TYPE code -> converter callable table.
# NOTE(review): presumably applied by the connection/cursor code to raw column
# values; the consumer is not part of this file -- confirm against callers.
decoders = {
    FIELD_TYPE.BIT: convert_bit,  # alias of through(): left unconverted
    FIELD_TYPE.TINY: int,
    FIELD_TYPE.SHORT: int,
    FIELD_TYPE.LONG: int,
    FIELD_TYPE.FLOAT: float,
    FIELD_TYPE.DOUBLE: float,
    FIELD_TYPE.LONGLONG: int,
    FIELD_TYPE.INT24: int,
    FIELD_TYPE.YEAR: int,
    FIELD_TYPE.TIMESTAMP: convert_datetime,
    FIELD_TYPE.DATETIME: convert_datetime,
    # TIME maps to timedelta by default; convert_time is the time-of-day
    # alternative described in its docstring.
    FIELD_TYPE.TIME: convert_timedelta,
    FIELD_TYPE.DATE: convert_date,
    # Blob and string types are returned unchanged.
    FIELD_TYPE.BLOB: through,
    FIELD_TYPE.TINY_BLOB: through,
    FIELD_TYPE.MEDIUM_BLOB: through,
    FIELD_TYPE.LONG_BLOB: through,
    FIELD_TYPE.STRING: through,
    FIELD_TYPE.VAR_STRING: through,
    FIELD_TYPE.VARCHAR: through,
    FIELD_TYPE.DECIMAL: Decimal,
    FIELD_TYPE.NEWDECIMAL: Decimal,
}
  269. # for MySQLdb compatibility
  270. conversions = encoders.copy()
  271. conversions.update(decoders)
  272. Thing2Literal = escape_str