protocol.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # Python implementation of low level MySQL client-server protocol
  2. # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
  3. from .charset import MBLENGTH
  4. from .constants import FIELD_TYPE, SERVER_STATUS
  5. from . import err
  6. import struct
  7. import sys
  8. DEBUG = False
  9. NULL_COLUMN = 251
  10. UNSIGNED_CHAR_COLUMN = 251
  11. UNSIGNED_SHORT_COLUMN = 252
  12. UNSIGNED_INT24_COLUMN = 253
  13. UNSIGNED_INT64_COLUMN = 254
  14. def dump_packet(data): # pragma: no cover
  15. def printable(data):
  16. if 32 <= data < 127:
  17. return chr(data)
  18. return "."
  19. try:
  20. print("packet length:", len(data))
  21. for i in range(1, 7):
  22. f = sys._getframe(i)
  23. print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
  24. print("-" * 66)
  25. except ValueError:
  26. pass
  27. dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)]
  28. for d in dump_data:
  29. print(
  30. " ".join("{:02X}".format(x) for x in d)
  31. + " " * (16 - len(d))
  32. + " " * 2
  33. + "".join(printable(x) for x in d)
  34. )
  35. print("-" * 66)
  36. print()
  37. class MysqlPacket:
  38. """Representation of a MySQL response packet.
  39. Provides an interface for reading/parsing the packet results.
  40. """
  41. __slots__ = ("_position", "_data")
  42. def __init__(self, data, encoding):
  43. self._position = 0
  44. self._data = data
  45. def get_all_data(self):
  46. return self._data
  47. def read(self, size):
  48. """Read the first 'size' bytes in packet and advance cursor past them."""
  49. result = self._data[self._position : (self._position + size)]
  50. if len(result) != size:
  51. error = (
  52. "Result length not requested length:\n"
  53. "Expected=%s. Actual=%s. Position: %s. Data Length: %s"
  54. % (size, len(result), self._position, len(self._data))
  55. )
  56. if DEBUG:
  57. print(error)
  58. self.dump()
  59. raise AssertionError(error)
  60. self._position += size
  61. return result
  62. def read_all(self):
  63. """Read all remaining data in the packet.
  64. (Subsequent read() will return errors.)
  65. """
  66. result = self._data[self._position :]
  67. self._position = None # ensure no subsequent read()
  68. return result
  69. def advance(self, length):
  70. """Advance the cursor in data buffer 'length' bytes."""
  71. new_position = self._position + length
  72. if new_position < 0 or new_position > len(self._data):
  73. raise Exception(
  74. "Invalid advance amount (%s) for cursor. "
  75. "Position=%s" % (length, new_position)
  76. )
  77. self._position = new_position
  78. def rewind(self, position=0):
  79. """Set the position of the data buffer cursor to 'position'."""
  80. if position < 0 or position > len(self._data):
  81. raise Exception("Invalid position to rewind cursor to: %s." % position)
  82. self._position = position
  83. def get_bytes(self, position, length=1):
  84. """Get 'length' bytes starting at 'position'.
  85. Position is start of payload (first four packet header bytes are not
  86. included) starting at index '0'.
  87. No error checking is done. If requesting outside end of buffer
  88. an empty string (or string shorter than 'length') may be returned!
  89. """
  90. return self._data[position : (position + length)]
  91. def read_uint8(self):
  92. result = self._data[self._position]
  93. self._position += 1
  94. return result
  95. def read_uint16(self):
  96. result = struct.unpack_from("<H", self._data, self._position)[0]
  97. self._position += 2
  98. return result
  99. def read_uint24(self):
  100. low, high = struct.unpack_from("<HB", self._data, self._position)
  101. self._position += 3
  102. return low + (high << 16)
  103. def read_uint32(self):
  104. result = struct.unpack_from("<I", self._data, self._position)[0]
  105. self._position += 4
  106. return result
  107. def read_uint64(self):
  108. result = struct.unpack_from("<Q", self._data, self._position)[0]
  109. self._position += 8
  110. return result
  111. def read_string(self):
  112. end_pos = self._data.find(b"\0", self._position)
  113. if end_pos < 0:
  114. return None
  115. result = self._data[self._position : end_pos]
  116. self._position = end_pos + 1
  117. return result
  118. def read_length_encoded_integer(self):
  119. """Read a 'Length Coded Binary' number from the data buffer.
  120. Length coded numbers can be anywhere from 1 to 9 bytes depending
  121. on the value of the first byte.
  122. """
  123. c = self.read_uint8()
  124. if c == NULL_COLUMN:
  125. return None
  126. if c < UNSIGNED_CHAR_COLUMN:
  127. return c
  128. elif c == UNSIGNED_SHORT_COLUMN:
  129. return self.read_uint16()
  130. elif c == UNSIGNED_INT24_COLUMN:
  131. return self.read_uint24()
  132. elif c == UNSIGNED_INT64_COLUMN:
  133. return self.read_uint64()
  134. def read_length_coded_string(self):
  135. """Read a 'Length Coded String' from the data buffer.
  136. A 'Length Coded String' consists first of a length coded
  137. (unsigned, positive) integer represented in 1-9 bytes followed by
  138. that many bytes of binary data. (For example "cat" would be "3cat".)
  139. """
  140. length = self.read_length_encoded_integer()
  141. if length is None:
  142. return None
  143. return self.read(length)
  144. def read_struct(self, fmt):
  145. s = struct.Struct(fmt)
  146. result = s.unpack_from(self._data, self._position)
  147. self._position += s.size
  148. return result
  149. def is_ok_packet(self):
  150. # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
  151. return self._data[0] == 0 and len(self._data) >= 7
  152. def is_eof_packet(self):
  153. # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
  154. # Caution: \xFE may be LengthEncodedInteger.
  155. # If \xFE is LengthEncodedInteger header, 8bytes followed.
  156. return self._data[0] == 0xFE and len(self._data) < 9
  157. def is_auth_switch_request(self):
  158. # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
  159. return self._data[0] == 0xFE
  160. def is_extra_auth_data(self):
  161. # https://dev.mysql.com/doc/internals/en/successful-authentication.html
  162. return self._data[0] == 1
  163. def is_resultset_packet(self):
  164. field_count = self._data[0]
  165. return 1 <= field_count <= 250
  166. def is_load_local_packet(self):
  167. return self._data[0] == 0xFB
  168. def is_error_packet(self):
  169. return self._data[0] == 0xFF
  170. def check_error(self):
  171. if self.is_error_packet():
  172. self.raise_for_error()
  173. def raise_for_error(self):
  174. self.rewind()
  175. self.advance(1) # field_count == error (we already know that)
  176. errno = self.read_uint16()
  177. if DEBUG:
  178. print("errno =", errno)
  179. err.raise_mysql_exception(self._data)
  180. def dump(self):
  181. dump_packet(self._data)
  182. class FieldDescriptorPacket(MysqlPacket):
  183. """A MysqlPacket that represents a specific column's metadata in the result.
  184. Parsing is automatically done and the results are exported via public
  185. attributes on the class such as: db, table_name, name, length, type_code.
  186. """
  187. def __init__(self, data, encoding):
  188. MysqlPacket.__init__(self, data, encoding)
  189. self._parse_field_descriptor(encoding)
  190. def _parse_field_descriptor(self, encoding):
  191. """Parse the 'Field Descriptor' (Metadata) packet.
  192. This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
  193. """
  194. self.catalog = self.read_length_coded_string()
  195. self.db = self.read_length_coded_string()
  196. self.table_name = self.read_length_coded_string().decode(encoding)
  197. self.org_table = self.read_length_coded_string().decode(encoding)
  198. self.name = self.read_length_coded_string().decode(encoding)
  199. self.org_name = self.read_length_coded_string().decode(encoding)
  200. (
  201. self.charsetnr,
  202. self.length,
  203. self.type_code,
  204. self.flags,
  205. self.scale,
  206. ) = self.read_struct("<xHIBHBxx")
  207. # 'default' is a length coded binary and is still in the buffer?
  208. # not used for normal result sets...
  209. def description(self):
  210. """Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
  211. return (
  212. self.name,
  213. self.type_code,
  214. None, # TODO: display_length; should this be self.length?
  215. self.get_column_length(), # 'internal_size'
  216. self.get_column_length(), # 'precision' # TODO: why!?!?
  217. self.scale,
  218. self.flags % 2 == 0,
  219. )
  220. def get_column_length(self):
  221. if self.type_code == FIELD_TYPE.VAR_STRING:
  222. mblen = MBLENGTH.get(self.charsetnr, 1)
  223. return self.length // mblen
  224. return self.length
  225. def __str__(self):
  226. return "%s %r.%r.%r, type=%s, flags=%x" % (
  227. self.__class__,
  228. self.db,
  229. self.table_name,
  230. self.name,
  231. self.type_code,
  232. self.flags,
  233. )
  234. class OKPacketWrapper:
  235. """
  236. OK Packet Wrapper. It uses an existing packet object, and wraps
  237. around it, exposing useful variables while still providing access
  238. to the original packet objects variables and methods.
  239. """
  240. def __init__(self, from_packet):
  241. if not from_packet.is_ok_packet():
  242. raise ValueError(
  243. "Cannot create "
  244. + str(self.__class__.__name__)
  245. + " object from invalid packet type"
  246. )
  247. self.packet = from_packet
  248. self.packet.advance(1)
  249. self.affected_rows = self.packet.read_length_encoded_integer()
  250. self.insert_id = self.packet.read_length_encoded_integer()
  251. self.server_status, self.warning_count = self.read_struct("<HH")
  252. self.message = self.packet.read_all()
  253. self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
  254. def __getattr__(self, key):
  255. return getattr(self.packet, key)
  256. class EOFPacketWrapper:
  257. """
  258. EOF Packet Wrapper. It uses an existing packet object, and wraps
  259. around it, exposing useful variables while still providing access
  260. to the original packet objects variables and methods.
  261. """
  262. def __init__(self, from_packet):
  263. if not from_packet.is_eof_packet():
  264. raise ValueError(
  265. f"Cannot create '{self.__class__}' object from invalid packet type"
  266. )
  267. self.packet = from_packet
  268. self.warning_count, self.server_status = self.packet.read_struct("<xhh")
  269. if DEBUG:
  270. print("server_status=", self.server_status)
  271. self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
  272. def __getattr__(self, key):
  273. return getattr(self.packet, key)
  274. class LoadLocalPacketWrapper:
  275. """
  276. Load Local Packet Wrapper. It uses an existing packet object, and wraps
  277. around it, exposing useful variables while still providing access
  278. to the original packet objects variables and methods.
  279. """
  280. def __init__(self, from_packet):
  281. if not from_packet.is_load_local_packet():
  282. raise ValueError(
  283. f"Cannot create '{self.__class__}' object from invalid packet type"
  284. )
  285. self.packet = from_packet
  286. self.filename = self.packet.get_all_data()[1:]
  287. if DEBUG:
  288. print("filename=", self.filename)
  289. def __getattr__(self, key):
  290. return getattr(self.packet, key)