12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367 |
- # Python implementation of the MySQL client-server protocol
- # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
- # Error codes:
- # https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
- import errno
- import os
- import socket
- import struct
- import sys
- import traceback
- import warnings
- from . import _auth
- from .charset import charset_by_name, charset_by_id
- from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
- from . import converters
- from .cursors import Cursor
- from .optionfile import Parser
- from .protocol import (
- dump_packet,
- MysqlPacket,
- FieldDescriptorPacket,
- OKPacketWrapper,
- EOFPacketWrapper,
- LoadLocalPacketWrapper,
- )
- from . import err, VERSION_STRING
- try:
- import ssl
- SSL_ENABLED = True
- except ImportError:
- ssl = None
- SSL_ENABLED = False
- try:
- import getpass
- DEFAULT_USER = getpass.getuser()
- del getpass
- except (ImportError, KeyError):
- # KeyError occurs when there's no entry in OS database for a current user.
- DEFAULT_USER = None
# When True, packets and connection diagnostics are printed to stdout.
DEBUG = False

# Field types grouped as text/blob-like columns (consumers of this set live
# outside this chunk; presumably used when decoding result columns).
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.GEOMETRY,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.VAR_STRING,
}

# Charset used when neither the caller nor the option file names one.
DEFAULT_CHARSET = "utf8mb4"

# Largest payload one wire packet can carry: the length header is 3 bytes.
MAX_PACKET_LEN = 0xFFFFFF
def _pack_int24(n):
    """Encode *n* as a 3-byte little-endian unsigned integer.

    This is the payload-length field of a MySQL packet header.

    :raise ValueError: If *n* does not fit in 24 bits. Without this check,
        values >= 2**24 were silently truncated into a corrupt header.
    """
    if not 0 <= n <= 0xFFFFFF:
        raise ValueError("%d does not fit in an unsigned 24-bit integer" % n)
    return struct.pack("<I", n)[:3]
def _lenenc_int(i):
    """Encode *i* as a MySQL length-encoded integer.

    https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger

    Values below 0xFB are a single byte. Larger values are a marker byte
    (0xFC, 0xFD or 0xFE) followed by a 2-, 3- or 8-byte little-endian payload.

    :raise ValueError: If *i* is negative or needs more than 64 bits.
    """
    if i < 0:
        raise ValueError(
            "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
        )
    if i < 0xFB:
        return bytes([i])

    # (exclusive upper bound, marker byte, struct format, bytes kept)
    encodings = (
        (1 << 16, b"\xfc", "<H", 2),
        (1 << 24, b"\xfd", "<I", 3),
        (1 << 64, b"\xfe", "<Q", 8),
    )
    for bound, marker, fmt, width in encodings:
        if i < bound:
            return marker + struct.pack(fmt, i)[:width]

    raise ValueError(
        "Encoding %x is larger than %x - no representation in LengthEncodedInteger"
        % (i, (1 << 64))
    )
class Connection:
    """
    Representation of a socket with a mysql server.

    The proper way to get an instance of this class is to call
    connect().

    Establish a connection to the MySQL database. Accepts several
    arguments:

    :param host: Host where the database server is located. (default: "localhost")
    :param user: Username to log in as. (default: the current OS user, when it
        can be determined)
    :param password: Password to use.
    :param database: Database to use, None to not use a particular one.
    :param port: MySQL port to use, default is usually OK. (default: 3306)
    :param bind_address: When the client has multiple network interfaces, specify
        the interface from which to connect to the host. Argument can be
        a hostname or an IP address.
    :param unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
    :param read_timeout: The timeout for reading from the connection in seconds (default: None - no timeout)
    :param write_timeout: The timeout for writing to the connection in seconds (default: None - no timeout)
    :param charset: Charset you want to use. (default: utf8mb4)
    :param sql_mode: Default SQL_MODE to use.
    :param read_default_file:
        Specifies my.cnf file to read these parameters from under the [client] section.
    :param conv:
        Conversion dictionary to use instead of the default one.
        This is used to provide custom marshalling and unmarshalling of types.
        See converters.
    :param use_unicode:
        Whether or not to default to unicode strings.
        This option defaults to true.
    :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
    :param cursorclass: Custom cursor class to use.
    :param init_command: Initial SQL statement to run when connection is established.
    :param connect_timeout: Timeout before throwing an exception when connecting.
        (default: 10; must be > 0 and <= 31536000)
    :param ssl:
        A dict of arguments similar to mysql_ssl_set()'s parameters.
    :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate
    :param ssl_cert: Path to the file that contains a PEM-formatted client certificate
    :param ssl_disabled: A boolean value that disables usage of TLS
    :param ssl_key: Path to the file that contains a PEM-formatted private key for the client certificate
    :param ssl_verify_cert: Set to true to check the validity of server certificates
    :param ssl_verify_identity: Set to true to check the server's identity
    :param read_default_group: Group to read from in the configuration file.
    :param autocommit: Autocommit mode. None means use server default. (default: False)
    :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
    :param defer_connect: Don't explicitly connect on construction - wait for connect call.
        (default: False)
    :param auth_plugin_map: A dict of plugin names to a class that processes that plugin.
        The class will take the Connection object as the argument to the constructor.
        The class needs an authenticate method taking an authentication packet as
        an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
        (if no authenticate method) for returning a string from the user. (experimental)
    :param program_name: Reported to the server as the ``program_name`` connection
        attribute. (default: None - not sent)
    :param server_public_key: SHA256 authentication plugin public key value. (default: None)
    :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
    :param compress: Not supported
    :param named_pipe: Not supported
    :param db: **DEPRECATED** Alias for database.
    :param passwd: **DEPRECATED** Alias for password.

    See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the
    specification.
    """

    # Class-level defaults keep these lookups safe even if __init__ raises
    # before connect() runs (e.g. _force_close, which is also __del__, reads
    # _sock).
    _sock = None
    _auth_plugin_name = ""
    _closed = False
    _secure = False
def __init__(
    self,
    *,
    user=None,  # The first four arguments are based on the DB-API 2.0 recommendation.
    password="",
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset="",
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=Cursor,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
):
    """Store connection settings and, unless *defer_connect*, connect.

    All arguments are keyword-only and documented on the class. When an
    option file is in play (*read_default_file*, or *read_default_group*
    alone), falsy arguments are filled from that file before validation.

    :raise NotImplementedError: If *compress* or *named_pipe* is set, or SSL
        is requested while the ``ssl`` module is unavailable.
    :raise ValueError: If the port is not an int, or *connect_timeout*,
        *read_timeout* or *write_timeout* is out of range.
    """
    if db is not None and database is None:
        # We will raise a warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
        database = db
    if passwd is not None and not password:
        # We will raise a warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn(
        #     "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
        # )
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            "compress and named_pipe arguments are not supported"
        )

    self._local_infile = bool(local_infile)
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES

    # A group given without a file means: read the platform's default file.
    if read_default_group and not read_default_file:
        if sys.platform.startswith("win"):
            read_default_file = "c:\\my.ini"
        else:
            read_default_file = "/etc/my.cnf"

    if read_default_file:
        if not read_default_group:
            read_default_group = "client"

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # Truthy explicit arguments win; otherwise look the key up in the
            # option file, falling back to the original argument on any error.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config("user", user)
        password = _config("password", password)
        host = _config("host", host)
        database = _config("database", database)
        unix_socket = _config("socket", unix_socket)
        port = int(_config("port", port))
        bind_address = _config("bind-address", bind_address)
        charset = _config("default-character-set", charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            # NOTE(review): a caller-supplied ssl dict is updated in place here.
            for key in ["ca", "capath", "cert", "key", "cipher"]:
                value = _config("ssl-" + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Any flat ssl_* argument replaces whatever was in ``ssl``.
        if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
            ssl = {
                "ca": ssl_ca,
                "check_hostname": bool(ssl_verify_identity),
                "verify_mode": ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl["cert"] = ssl_cert
            if ssl_key is not None:
                ssl["key"] = ssl_key
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError("ssl module not found")
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or "localhost"
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError("port should be of type int")
    self.user = user or DEFAULT_USER
    self.password = password or b""
    if isinstance(self.password, str):
        # NOTE(review): str passwords are encoded as latin-1, so characters
        # outside latin-1 raise UnicodeEncodeError here.
        self.password = self.password.encode("latin1")
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError("connect_timeout should be >0 and <=31536000")
    # After the range check above, ``or None`` never changes the value.
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError("read_timeout should be > 0")
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError("write_timeout should be > 0")
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.use_unicode = use_unicode

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    if self.db:
        client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.cursorclass = cursorclass

    self._result = None
    self._affected_rows = 0
    self.host_info = "Not connected"

    # specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    # Needed for MySQLdb compatibility: int keys (presumably FIELD_TYPE codes)
    # are decoders, every other key is an encoder.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    # Connection attributes sent during the handshake when the server
    # advertises CLIENT.CONNECT_ATTRS.
    self._connect_attrs = {
        "_client_name": "pymysql",
        "_pid": str(os.getpid()),
        "_client_version": VERSION_STRING,
    }

    if program_name:
        self._connect_attrs["program_name"] = program_name

    if defer_connect:
        self._sock = None
    else:
        self.connect()
def __enter__(self):
    """Support ``with``: the bound value is the connection itself."""
    conn = self
    return conn
def __exit__(self, *exc_info):
    """Close the connection on leaving a ``with`` block.

    Returns None, so an exception raised inside the block is never suppressed.
    """
    self.close()
def _create_ssl_ctx(self, sslp):
    """Return an ``ssl.SSLContext`` built from *sslp*.

    :param sslp: Either a ready ``ssl.SSLContext`` (returned unchanged) or a
        dict with optional keys ``ca``, ``capath``, ``cert``, ``key``,
        ``cipher``, ``check_hostname`` and ``verify_mode``.

    Without ``ca``/``capath``, hostname checking is off and, unless
    ``verify_mode`` says otherwise, certificates are not verified.
    ``verify_mode`` may be a bool, None, or a case-insensitive string such as
    "required"/"optional"/"none". Unrecognised values fall back to the
    CA-based default.
    """
    if isinstance(sslp, ssl.SSLContext):
        return sslp

    cafile = sslp.get("ca")
    capath = sslp.get("capath")
    without_ca = cafile is None and capath is None

    context = ssl.create_default_context(cafile=cafile, capath=capath)
    # Set before verify_mode: CERT_NONE is rejected while check_hostname is on.
    context.check_hostname = not without_ca and sslp.get("check_hostname", True)

    default_mode = ssl.CERT_NONE if without_ca else ssl.CERT_REQUIRED
    requested = sslp.get("verify_mode")
    if isinstance(requested, bool):
        mode = ssl.CERT_REQUIRED if requested else ssl.CERT_NONE
    elif requested is None:
        mode = default_mode
    else:
        if isinstance(requested, str):
            requested = requested.lower()
        if requested in ("none", "0", "false", "no"):
            mode = ssl.CERT_NONE
        elif requested == "optional":
            mode = ssl.CERT_OPTIONAL
        elif requested in ("required", "1", "true", "yes"):
            mode = ssl.CERT_REQUIRED
        else:
            mode = default_mode
    context.verify_mode = mode

    if "cert" in sslp:
        context.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
    if "cipher" in sslp:
        context.set_ciphers(sslp["cipher"])
    # Refuse the obsolete SSLv2/SSLv3 protocols.
    context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
    return context
def close(self):
    """
    Send the quit message and close the socket.

    See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    :raise Error: If the connection is already closed.
    """
    if self._closed:
        raise err.Error("Already closed")
    self._closed = True

    if self._sock is not None:
        # One-byte COM_QUIT payload: "<i" packs length 1 (3 bytes) + seq id 0.
        quit_packet = struct.pack("<iB", 1, COMMAND.COM_QUIT)
        try:
            self._write_bytes(quit_packet)
        except Exception:
            # Best effort: the socket is torn down below regardless.
            pass
        finally:
            self._force_close()
@property
def open(self):
    """True while a socket is attached to this connection."""
    return not (self._sock is None)
def _force_close(self):
    """Close the connection without sending COM_QUIT.

    Best effort: errors from closing the socket are ignored, and the socket
    and its read file are always dropped afterwards. Also installed as
    ``__del__``.
    """
    if self._sock:
        try:
            self._sock.close()
        except Exception:
            # Ignore close failures (this also runs from __del__), but unlike
            # a bare ``except:`` let KeyboardInterrupt/SystemExit propagate.
            pass
    self._sock = None
    self._rfile = None

__del__ = _force_close
def autocommit(self, value):
    """Enable or disable autocommit.

    :param value: Any truthy/falsy value; it is normalized with ``bool()``.

    SET AUTOCOMMIT is only sent when the requested mode differs from the
    flag reported by :meth:`get_autocommit`.
    """
    self.autocommit_mode = bool(value)
    current = self.get_autocommit()
    # Compare the normalized bool. Comparing the raw value (e.g. "on" or 2)
    # with the server's bool would trigger a needless round trip.
    if self.autocommit_mode != current:
        self._send_autocommit_mode()
def get_autocommit(self):
    """Return the autocommit flag from the last server status received."""
    flags = self.server_status
    return (flags & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT) != 0
def _read_ok_packet(self):
    """Read the next packet, require it to be OK, and record its status.

    :return: The packet wrapped in ``OKPacketWrapper``.
    :raise OperationalError: (2014) If the packet is not an OK packet.
    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        ok = OKPacketWrapper(packet)
        self.server_status = ok.server_status
        return ok
    raise err.OperationalError(2014, "Command Out of Sync")
def _send_autocommit_mode(self):
    """Push ``self.autocommit_mode`` to the server via SET AUTOCOMMIT."""
    statement = "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def begin(self):
    """Start a transaction with ``BEGIN`` and wait for the server's OK."""
    statement = "BEGIN"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def commit(self):
    """
    Commit the current transaction (sends ``COMMIT`` and expects OK).

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.
    """
    statement = "COMMIT"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def rollback(self):
    """
    Roll back the current transaction (sends ``ROLLBACK`` and expects OK).

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.
    """
    statement = "ROLLBACK"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
def show_warnings(self):
    """Run ``SHOW WARNINGS`` and return the rows of its result set."""
    self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
    warning_result = MySQLResult(self)
    warning_result.read()
    return warning_result.rows
def select_db(self, db):
    """
    Switch the session's default database (COM_INIT_DB).

    :param db: Name of the database to make current.
    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
def escape(self, obj, mapping=None):
    """Escape whatever value you pass to it.

    Non-standard, for internal use; do not use this in your applications.

    ``str`` values are quoted via :meth:`escape_string`. ``bytes`` and
    ``bytearray`` go through :meth:`_quote_bytes`, optionally prefixed with
    ``_binary``. Everything else is delegated to ``converters.escape_item``.
    """
    if isinstance(obj, str):
        return "'" + self.escape_string(obj) + "'"
    if not isinstance(obj, (bytes, bytearray)):
        return converters.escape_item(obj, self.charset, mapping=mapping)
    quoted = self._quote_bytes(obj)
    return "_binary" + quoted if self._binary_prefix else quoted
def literal(self, obj):
    """Escape *obj* using this connection's encoders (alias for escape()).

    Non-standard, for internal use; do not use this in your applications.
    """
    encoders = self.encoders
    return self.escape(obj, encoders)
def escape_string(self, s):
    """Escape *s* for use inside a single-quoted SQL string literal.

    When the server reports NO_BACKSLASH_ESCAPES, only single quotes are
    doubled; otherwise ``converters.escape_string`` is used.
    """
    no_backslash = self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    if not no_backslash:
        return converters.escape_string(s)
    return s.replace("'", "''")
def _quote_bytes(self, s):
    """Return bytes *s* as a quoted SQL string literal.

    Under NO_BACKSLASH_ESCAPES only quotes are doubled, and the result is
    decoded with ``surrogateescape`` so non-ASCII bytes survive until the
    query is re-encoded. Otherwise ``converters.escape_bytes`` is used.
    """
    if not self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
        return converters.escape_bytes(s)
    doubled = s.replace(b"'", b"''")
    return "'%s'" % (doubled.decode("ascii", "surrogateescape"),)
def cursor(self, cursor=None):
    """
    Create a new cursor to execute queries with.

    :param cursor: The type of cursor to create; one of :py:class:`Cursor`,
        :py:class:`SSCursor`, :py:class:`DictCursor`, or :py:class:`SSDictCursor`.
        None means use the connection's ``cursorclass``.
    """
    factory = cursor if cursor else self.cursorclass
    return factory(self)
- # The following methods are INTERNAL USE ONLY (called from Cursor)
def query(self, sql, unbuffered=False):
    """Send *sql* as COM_QUERY, read the result, and return affected rows.

    ``str`` queries are encoded with the connection encoding using
    ``surrogateescape``, matching how escaped bytes are decoded.
    """
    if isinstance(sql, str):
        payload = sql.encode(self.encoding, "surrogateescape")
    else:
        payload = sql
    self._execute_command(COMMAND.COM_QUERY, payload)
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected
def next_result(self, unbuffered=False):
    """Read the next result set of a multi-result response; return affected rows."""
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected
def affected_rows(self):
    """Row count recorded from the most recently read result."""
    count = self._affected_rows
    return count
def kill(self, thread_id):
    """Ask the server to kill thread *thread_id* (COM_PROCESS_KILL).

    :return: The OK packet returned by :meth:`_read_ok_packet`.
    """
    payload = struct.pack("<I", thread_id)
    self._execute_command(COMMAND.COM_PROCESS_KILL, payload)
    return self._read_ok_packet()
def ping(self, reconnect=True):
    """
    Check if the server is alive.

    :param reconnect: If the connection is closed, reconnect.
    :raise Error: If the connection is closed and reconnect=False.
    """
    if self._sock is None:
        if not reconnect:
            raise err.Error("Already closed")
        self.connect()
        # Fresh connection: don't reconnect again if this ping fails.
        reconnect = False

    try:
        self._execute_command(COMMAND.COM_PING, "")
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)
def set_charset(self, charset):
    """Switch the session charset with SET NAMES and update local state."""
    # Resolve the encoding before sending anything; this also checks that
    # the charset is supported.
    encoding = charset_by_name(charset).encoding
    statement = "SET NAMES %s" % self.escape(charset)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_packet()
    self.charset, self.encoding = charset, encoding
def connect(self, sock=None):
    """Open the transport, authenticate, and apply session settings.

    :param sock: An already-connected socket to use. When None, a UNIX socket
        (if ``unix_socket`` is set) or a TCP connection to ``host``/``port``
        is opened.
    :raise OperationalError: (2003) If the socket layer fails with OSError.
        Any other exception is re-raised unchanged after cleanup.
    """
    self._closed = False
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = "Localhost via UNIX socket"
                self._secure = True
                if DEBUG:
                    print("connected using unix_socket")
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs["source_address"] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs
                        )
                        break
                    except OSError as e:
                        # Retry a connect interrupted by a signal.
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = "socket %s:%d" % (self.host, self.port)
                if DEBUG:
                    print("connected using socket")
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.settimeout(None)

        self._sock = sock
        self._rfile = sock.makefile("rb")
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute("SET sql_mode=%s", (self.sql_mode,))

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()
            self.commit()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)
    except BaseException as e:
        self._rfile = None
        # After TLS wrapping, self._sock is the SSL wrapper and owns the file
        # descriptor, so closing only the raw ``sock`` would leak it. Close
        # both, and clear self._sock so ``open`` stops reporting True.
        wrapped, self._sock = self._sock, None
        for s in (wrapped, sock):
            if s is not None:
                try:
                    s.close()
                except Exception:
                    pass

        if isinstance(e, OSError):
            exc = err.OperationalError(
                2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e)
            )
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc from e

        # If e is neither DatabaseError or IOError, It's a bug.
        # But raising AssertionError hides original error.
        # So just reraise it.
        raise
def write_packet(self, payload):
    """Frame *payload* as one MySQL packet and send it.

    The packet is the 3-byte payload length, then the current sequence id,
    then the payload. The sequence id then advances modulo 256.

    Internal note: code that builds a packet by hand and calls
    _write_bytes() directly must keep self._next_seq_id in sync itself.
    """
    header = _pack_int24(len(payload)) + bytes((self._next_seq_id,))
    packet = header + payload
    if DEBUG:
        dump_packet(packet)
    self._write_bytes(packet)
    self._next_seq_id = (self._next_seq_id + 1) % 256
def _read_packet(self, packet_type=MysqlPacket):
    """Read one logical MySQL packet and wrap it in *packet_type*.

    A chunk of exactly MAX_PACKET_LEN bytes means the payload continues in
    the next packet
    (https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html).
    Chunks are concatenated until a shorter one arrives. On an error packet,
    any active unbuffered result is marked inactive before the error is
    raised.

    :raise OperationalError: If the connection to the MySQL server is lost.
    :raise InternalError: If the packet sequence number is wrong.
    """
    payload = bytearray()
    while True:
        header = self._read_bytes(4)
        length_low, length_high, seq_id = struct.unpack("<HBB", header)
        chunk_len = length_low + (length_high << 16)

        expected = self._next_seq_id
        if seq_id != expected:
            self._force_close()
            if seq_id == 0:
                # MariaDB sends error packet with seqno==0 when shutdown
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query",
                )
            raise err.InternalError(
                "Packet sequence number wrong - got %d expected %d"
                % (seq_id, expected)
            )
        self._next_seq_id = (expected + 1) % 256

        chunk = self._read_bytes(chunk_len)
        if DEBUG:
            dump_packet(chunk)
        payload += chunk

        # Only a maximum-size chunk is followed by a continuation packet.
        if chunk_len < MAX_PACKET_LEN:
            break

    packet = packet_type(bytes(payload), self.encoding)
    if packet.is_error_packet():
        current = self._result
        if current is not None and current.unbuffered_active is True:
            current.unbuffered_active = False
        packet.raise_for_error()
    return packet
def _read_bytes(self, num_bytes):
    """Read exactly *num_bytes* from the server, using the read timeout.

    Reads interrupted by a signal (EINTR) are retried. Any other failure, or
    a short read, force-closes the connection.

    :raise OperationalError: CR_SERVER_LOST on socket errors or a short read.
    """
    self._sock.settimeout(self._read_timeout)
    while True:
        try:
            data = self._rfile.read(num_bytes)
        except OSError as e:
            if e.errno != errno.EINTR:
                self._force_close()
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query (%s)" % (e,),
                )
            # EINTR: retry the read.
        except BaseException:
            # Don't convert unknown exception to MySQLError.
            self._force_close()
            raise
        else:
            break

    if len(data) < num_bytes:
        self._force_close()
        raise err.OperationalError(
            CR.CR_SERVER_LOST, "Lost connection to MySQL server during query"
        )
    return data
def _write_bytes(self, data):
    """Send all of *data* to the server, using the write timeout.

    :raise OperationalError: CR_SERVER_GONE_ERROR if the send fails; the
        connection is force-closed first.
    """
    sock = self._sock
    sock.settimeout(self._write_timeout)
    try:
        sock.sendall(data)
    except OSError as e:
        self._force_close()
        raise err.OperationalError(
            CR.CR_SERVER_GONE_ERROR, "MySQL server has gone away (%r)" % (e,)
        )
def _read_query_result(self, unbuffered=False):
    """Read the result of the command just sent and make it current.

    :param unbuffered: Only initialise the result for streaming instead of
        reading it fully.
    :return: The result's ``affected_rows``.

    ``self.server_status`` is refreshed when the result carries a status.
    """
    self._result = None
    # Build the result outside the try block. Previously a constructor
    # failure surfaced as UnboundLocalError from the cleanup handler.
    result = MySQLResult(self)
    if unbuffered:
        try:
            result.init_unbuffered_query()
        except BaseException:
            # Detach the half-initialised result so it cannot keep reading
            # from this connection, then propagate the original error.
            result.unbuffered_active = False
            result.connection = None
            raise
    else:
        result.read()
    self._result = result
    if result.server_status is not None:
        self.server_status = result.server_status
    return result.affected_rows
def insert_id(self):
    """Return the current result's ``insert_id``, or 0 when there is none."""
    current = self._result
    return current.insert_id if current else 0
def _execute_command(self, command, sql):
    """Send *command* with argument *sql* to the server.

    An argument too large for one packet is split across consecutive
    packets. If the last chunk exactly fills a packet, an empty packet
    follows to terminate the sequence.

    :param command: A ``COMMAND.*`` code.
    :param sql: The argument, as ``str`` (encoded with the connection
        encoding) or bytes.
    :raise InterfaceError: If the connection is closed.
    """
    if not self._sock:
        raise err.InterfaceError(0, "")

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            self._result._finish_unbuffered_query()
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self.encoding)

    packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # tiny optimization: build first packet manually instead of
    # calling self.write_packet()
    prelude = struct.pack("<iB", packet_size, command)
    packet = prelude + sql[: packet_size - 1]
    self._write_bytes(packet)
    if DEBUG:
        dump_packet(packet)
    self._next_seq_id = 1

    if packet_size < MAX_PACKET_LEN:
        return

    # Send the remainder in MAX_PACKET_LEN chunks. Track an offset rather
    # than re-slicing the remainder, which copied the whole tail each time.
    pos = packet_size - 1
    total = len(sql)
    while True:
        chunk = sql[pos : pos + MAX_PACKET_LEN]
        self.write_packet(chunk)
        pos += len(chunk)
        if pos >= total and len(chunk) < MAX_PACKET_LEN:
            break
def _request_authentication(self):
    """Send the HandshakeResponse and complete authentication.

    https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse

    If SSL was requested and the server advertises CLIENT.SSL, the socket is
    upgraded to TLS first. Credentials for the announced auth plugin are then
    sent, and an auth-switch request or extra auth data is handled.

    :raise ValueError: If no username was specified.
    :raise OperationalError: If the server sends extra auth data for a plugin
        this method does not handle.
    """
    if int(self.server_version.split(".", 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError("Did not specify a username")

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, str):
        self.user = self.user.encode(self.encoding)

    data_init = struct.pack(
        "<iIB23s", self.client_flag, MAX_PACKET_LEN, charset_id, b""
    )

    # NOTE(review): if SSL was requested but the server does not advertise
    # CLIENT.SSL, the handshake continues in plaintext without an error.
    if self.ssl and self.server_capabilities & CLIENT.SSL:
        self.write_packet(data_init)

        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
        self._rfile = self._sock.makefile("rb")
        self._secure = True

    data = data_init + self.user + b"\0"

    authresp = b""
    plugin_name = None

    if self._auth_plugin_name == "":
        plugin_name = b""
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == "mysql_native_password":
        plugin_name = b"mysql_native_password"
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == "caching_sha2_password":
        plugin_name = b"caching_sha2_password"
        if self.password:
            if DEBUG:
                print("caching_sha2: trying fast path")
            authresp = _auth.scramble_caching_sha2(self.password, self.salt)
        else:
            if DEBUG:
                print("caching_sha2: empty password")
    elif self._auth_plugin_name == "sha256_password":
        plugin_name = b"sha256_password"
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            authresp = self.password + b"\0"
        elif self.password:
            authresp = b"\1"  # request public key
        else:
            authresp = b"\0"  # empty password

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack("B", len(authresp)) + authresp
    else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
        data += authresp + b"\0"

    if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        if isinstance(self.db, str):
            self.db = self.db.encode(self.encoding)
        data += self.db + b"\0"

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        data += (plugin_name or b"") + b"\0"

    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        # Keys, values and the total length are length-encoded integers.
        # A single "B" byte is wrong for 251..255 (those bytes are lenenc
        # markers) and raises struct.error from 256 up, e.g. with a long
        # program_name. Below 251 both encodings are identical.
        connect_attrs = b""
        for k, v in self._connect_attrs.items():
            k = k.encode("utf-8")
            connect_attrs += _lenenc_int(len(k)) + k
            v = v.encode("utf-8")
            connect_attrs += _lenenc_int(len(v)) + v
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        if DEBUG:
            print("received auth switch")
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (
            self.server_capabilities & CLIENT.PLUGIN_AUTH
            and plugin_name is not None
        ):
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            # send legacy handshake
            data = _auth.scramble_old_password(self.password, self.salt) + b"\0"
            self.write_packet(data)
            auth_packet = self._read_packet()
    elif auth_packet.is_extra_auth_data():
        if DEBUG:
            print("received extra data")
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        if self._auth_plugin_name == "caching_sha2_password":
            auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
        elif self._auth_plugin_name == "sha256_password":
            auth_packet = _auth.sha256_password_auth(self, auth_packet)
        else:
            # Previously the format string was passed unformatted as the
            # errno argument; use the same 2059 code as _process_auth.
            raise err.OperationalError(
                2059,
                "Received extra packet for auth method %r"
                % (self._auth_plugin_name,),
            )

    if DEBUG:
        print("Succeed to auth")
- def _process_auth(self, plugin_name, auth_packet):
- handler = self._get_auth_plugin_handler(plugin_name)
- if handler:
- try:
- return handler.authenticate(auth_packet)
- except AttributeError:
- if plugin_name != b"dialog":
- raise err.OperationalError(
- 2059,
- "Authentication plugin '%s'"
- " not loaded: - %r missing authenticate method"
- % (plugin_name, type(handler)),
- )
- if plugin_name == b"caching_sha2_password":
- return _auth.caching_sha2_password_auth(self, auth_packet)
- elif plugin_name == b"sha256_password":
- return _auth.sha256_password_auth(self, auth_packet)
- elif plugin_name == b"mysql_native_password":
- data = _auth.scramble_native_password(self.password, auth_packet.read_all())
- elif plugin_name == b"client_ed25519":
- data = _auth.ed25519_password(self.password, auth_packet.read_all())
- elif plugin_name == b"mysql_old_password":
- data = (
- _auth.scramble_old_password(self.password, auth_packet.read_all())
- + b"\0"
- )
- elif plugin_name == b"mysql_clear_password":
- # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
- data = self.password + b"\0"
- elif plugin_name == b"dialog":
- pkt = auth_packet
- while True:
- flag = pkt.read_uint8()
- echo = (flag & 0x06) == 0x02
- last = (flag & 0x01) == 0x01
- prompt = pkt.read_all()
- if prompt == b"Password: ":
- self.write_packet(self.password + b"\0")
- elif handler:
- resp = "no response - TypeError within plugin.prompt method"
- try:
- resp = handler.prompt(echo, prompt)
- self.write_packet(resp + b"\0")
- except AttributeError:
- raise err.OperationalError(
- 2059,
- "Authentication plugin '%s'"
- " not loaded: - %r missing prompt method"
- % (plugin_name, handler),
- )
- except TypeError:
- raise err.OperationalError(
- 2061,
- "Authentication plugin '%s'"
- " %r didn't respond with string. Returned '%r' to prompt %r"
- % (plugin_name, handler, resp, prompt),
- )
- else:
- raise err.OperationalError(
- 2059,
- "Authentication plugin '%s' (%r) not configured"
- % (plugin_name, handler),
- )
- pkt = self._read_packet()
- pkt.check_error()
- if pkt.is_ok_packet() or last:
- break
- return pkt
- else:
- raise err.OperationalError(
- 2059, "Authentication plugin '%s' not configured" % plugin_name
- )
- self.write_packet(data)
- pkt = self._read_packet()
- pkt.check_error()
- return pkt
- def _get_auth_plugin_handler(self, plugin_name):
- plugin_class = self._auth_plugin_map.get(plugin_name)
- if not plugin_class and isinstance(plugin_name, bytes):
- plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii"))
- if plugin_class:
- try:
- handler = plugin_class(self)
- except TypeError:
- raise err.OperationalError(
- 2059,
- "Authentication plugin '%s'"
- " not loaded: - %r cannot be constructed with connection object"
- % (plugin_name, plugin_class),
- )
- else:
- handler = None
- return handler
- # _mysql support
def thread_id(self):
    """Return the thread id from the handshake (first item of ``server_thread_id``)."""
    handshake_ids = self.server_thread_id
    return handshake_ids[0]
def character_set_name(self):
    """Return this connection's ``charset`` attribute (_mysql-compatible API)."""
    name = self.charset
    return name
def get_host_info(self):
    """Return this connection's ``host_info`` attribute (_mysql-compatible API)."""
    info = self.host_info
    return info
def get_proto_info(self):
    """Return the protocol version byte read from the server handshake."""
    version = self.protocol_version
    return version
def _get_server_information(self):
    """Read the server's initial handshake packet and store what it advertises.

    The packet is parsed at fixed offsets (layout per the Protocol::Handshake
    link below). Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id`` (the 1-tuple returned by ``struct.unpack``), ``salt``
    and the low 16 bits of ``server_capabilities``. When at least 6 more bytes
    follow, it also sets ``server_language``, ``server_charset``,
    ``server_status``, the upper capability bits and the rest of ``salt``.
    When ``CLIENT.PLUGIN_AUTH`` is advertised, it sets ``_auth_plugin_name``.
    """
    i = 0
    packet = self._read_packet()
    data = packet.get_all_data()
    # 1 byte: protocol version.
    self.protocol_version = data[i]
    i += 1
    # NUL-terminated server version string, decoded as latin1.
    server_end = data.find(b"\0", i)
    self.server_version = data[i:server_end].decode("latin1")
    i = server_end + 1
    # 4 bytes little-endian unsigned thread id; kept as the tuple unpack returns.
    self.server_thread_id = struct.unpack("<I", data[i : i + 4])
    i += 4
    # First 8 bytes of the auth salt, then a one-byte filler that is skipped.
    self.salt = data[i : i + 8]
    i += 9  # 8 + 1(filler)
    # Lower 16 bits of the capability flags.
    self.server_capabilities = struct.unpack("<H", data[i : i + 2])[0]
    i += 2
    # NOTE(review): on a packet that ends here, server_language, server_charset
    # and server_status are not assigned by this method.
    if len(data) >= i + 6:
        # 1 byte charset/collation id, 2 bytes status, 2 bytes upper
        # capability bits, 1 byte reported salt length.
        lang, stat, cap_h, salt_len = struct.unpack("<BHHB", data[i : i + 6])
        i += 6
        # TODO: deprecate server_language and server_charset.
        # mysqlclient-python doesn't provide it.
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None
        self.server_status = stat
        if DEBUG:
            print("server_status: %x" % stat)
        # Merge the upper 16 capability bits above the lower 16 read earlier.
        self.server_capabilities |= cap_h << 16
        if DEBUG:
            print("salt_len:", salt_len)
        # Bytes of salt still to read: reported length minus the 8 bytes
        # already read and one filler byte, but never fewer than 12.
        salt_len = max(12, salt_len - 9)
        # reserved
        i += 10
        if len(data) >= i + salt_len:
            # salt_len includes auth_plugin_data_part_1 and filler
            self.salt += data[i : i + salt_len]
            i += salt_len
    # Skip the byte that follows the salt.
    i += 1
    # AUTH PLUGIN NAME may appear here.
    # NOTE(review): `len(data) >= i` is also true when no bytes remain, in which
    # case the plugin name becomes "" — confirm whether `>` was intended.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        server_end = data.find(b"\0", i)
        if server_end < 0:  # pragma: no cover - very specific upstream bug
            # not found \0 and last field so take it all
            self._auth_plugin_name = data[i:].decode("utf-8")
        else:
            self._auth_plugin_name = data[i:server_end].decode("utf-8")
def get_server_info(self):
    """Return the server version string read from the handshake."""
    version_string = self.server_version
    return version_string
# Exception classes from ``err`` re-exported as attributes (presumably of the
# enclosing connection class), so they are reachable from an instance.
# PEP 249 lists connection-level exception attributes as an optional
# DB-API extension.
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
DatabaseError = err.DatabaseError
DataError = err.DataError
OperationalError = err.OperationalError
IntegrityError = err.IntegrityError
InternalError = err.InternalError
ProgrammingError = err.ProgrammingError
NotSupportedError = err.NotSupportedError
class MySQLResult:
    """Reads one query response from a connection.

    ``read()`` consumes the whole response at once. ``init_unbuffered_query()``
    reads only the header of a result set; rows are then fetched one at a time
    with ``_read_rowdata_packet_unbuffered()``, and leftovers are drained by
    ``_finish_unbuffered_query()``.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        # True while an unbuffered result set still has rows to be read.
        self.unbuffered_active = False

    def __del__(self):
        # Rows of an unbuffered result are still pending; drain them.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read a complete response: an OK packet, a LOAD_LOCAL request, or a result set."""
        try:
            first_packet = self.connection._read_packet()
            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            # Drop the back-reference whether or not reading succeeded.
            self.connection = None

    def init_unbuffered_query(self):
        """
        Read the first packet of a response, leaving result rows unread.

        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError:
        """
        # unbuffered_active is set only once a result set has actually begun.
        # Setting it before this read meant a failed read (e.g. an error reply)
        # left it True, and __del__ would then try to drain packets from a
        # connection that has no result pending.
        first_packet = self.connection._read_packet()
        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.unbuffered_active = True
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()
            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """Answer a LOAD_LOCAL request by sending the named local file.

        :raise RuntimeError: if the connection's ``_local_infile`` option is false.
        :raise OperationalError: 2014 if the reply after the upload is not OK.
        """
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except BaseException:
            # Consume the server's reply before propagating the error.
            self.connection._read_packet()  # skip ok packet
            raise
        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (recording warning_count/has_next) if ``packet`` is an EOF packet."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read a buffered result set whose header is ``first_packet``."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Fetch the next row of an unbuffered result set.

        Returns the row tuple, or None once the EOF packet is read or when no
        unbuffered query is active.
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return
        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return
        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should tuple of row for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Discard remaining rows of an unbuffered result until the EOF packet."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))
        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row using the (encoding, converter) pairs in ``self.converters``.

        NULL columns stay None; a row shorter than the converter list stops early.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if DEBUG:
                    print("DEBUG: DATA = ", data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []
        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = "ascii"
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f"DEBUG: field={field}, converter={converter}")
            self.converters.append((encoding, converter))
        eof_packet = self.connection._read_packet()
        # Explicit raise instead of `assert` so the protocol check also runs
        # under `python -O`; exception type and message are unchanged.
        if not eof_packet.is_eof_packet():
            raise AssertionError("Protocol error, expecting EOF")
        self.description = tuple(description)
class LoadLocalFile:
    """Streams a local file to the server in reply to a LOAD_LOCAL request."""

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server.

        The file is sent in chunks of at most ``min(max_allowed_packet, 16KB)``
        bytes, followed by an empty packet marking the end of the data. The
        empty packet is sent even when opening or reading the file fails.

        :raise InterfaceError: if the connection has no socket.
        :raise OperationalError: 1017 if the file cannot be opened or read.
        """
        if not self.connection._sock:
            raise err.InterfaceError(0, "")
        conn = self.connection
        try:
            # Only file-system errors are reported as "Can't find file".
            # Errors raised by conn.write_packet (which may themselves be
            # OSError subclasses, e.g. ConnectionResetError) propagate as-is.
            try:
                open_file = open(self.filename, "rb")
            except OSError as exc:
                raise self._file_error() from exc
            with open_file:
                packet_size = min(
                    conn.max_allowed_packet, 16 * 1024
                )  # 16KB is efficient enough
                while True:
                    try:
                        chunk = open_file.read(packet_size)
                    except OSError as exc:
                        raise self._file_error() from exc
                    if not chunk:
                        break
                    conn.write_packet(chunk)
        finally:
            # send the empty packet to signify we are done sending data
            conn.write_packet(b"")

    def _file_error(self):
        """Build the error reported when the local file cannot be opened or read."""
        return err.OperationalError(1017, f"Can't find file '{self.filename}'")
|