123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496 |
- import re
- from . import err
#: Pattern used by :meth:`Cursor.executemany` to recognise a simple bulk
#: ``INSERT``/``REPLACE ... VALUES (...)`` statement.
#:
#: Group 1 is everything up to and including ``VALUES``, group 2 is the
#: parenthesised placeholder tuple, and group 3 is an optional
#: ``ON DUPLICATE ...`` tail. Only statements of this shape are batched.
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    flags=re.IGNORECASE | re.DOTALL,
)
class Cursor:
    """
    This is the object you use to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connections.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        """Bind the cursor to *connection* and reset all result state.

        :param connection: Connection object the cursor sends queries through.
        """
        self.connection = connection
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        # Set up front so that reading ``lastrowid`` before the first query
        # yields None instead of raising AttributeError.
        self.lastrowid = None
        self._executed = None
        self._last_executed = None
        self._result = None
        self._rows = None

    def close(self):
        """
        Closing a cursor just exhausts all remaining data.

        Calling it on an already closed cursor is a no-op.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            # Detach even if draining failed so the cursor counts as closed.
            self.connection = None

    def __enter__(self):
        """Return the cursor itself for use in a ``with`` block."""
        return self

    def __exit__(self, *exc_info):
        """Close the cursor on leaving a ``with`` block; exceptions propagate."""
        del exc_info
        self.close()

    def _get_db(self):
        """Return the connection, or raise ProgrammingError if the cursor is closed."""
        if not self.connection:
            raise err.ProgrammingError("Cursor closed")
        return self.connection

    def _check_executed(self):
        """Raise ProgrammingError unless a statement has been executed."""
        if not self._executed:
            raise err.ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Return *row* unchanged; subclasses override this to convert rows."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    def _nextset(self, unbuffered=False):
        """Get the next query set

        :param bool unbuffered: Passed through to ``conn.next_result``.
        :return: True if the cursor moved to another result set, None if
            there is none or the connection's current result is not this
            cursor's.
        """
        conn = self._get_db()
        current_result = self._result
        # Only advance when the connection's active result is ours.
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        """Move to the next result set using buffered reads.

        :return: True if another result set was loaded, otherwise None.
        """
        return self._nextset(False)

    def _ensure_bytes(self, x, encoding=None):
        """Encode ``str`` values (recursing into tuples and lists) to bytes.

        :param x: Value to convert; anything that is not a str, tuple or
            list is returned unchanged.
        :param encoding: Codec name. When None, the connection's encoding
            is used the first time a str actually has to be encoded.
        :return: The converted value; containers keep their type.
        """
        if isinstance(x, str):
            if encoding is None:
                # ``str.encode(None)`` raises TypeError, so fall back to the
                # connection's encoding instead of crashing on the default.
                encoding = self._get_db().encoding
            return x.encode(encoding)
        if isinstance(x, (tuple, list)):
            return type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
        return x

    def _escape_args(self, args, conn):
        """Escape query parameters through *conn*.

        :param args: tuple/list (positional) or dict (named) of parameters,
            or a single value.
        :param conn: Connection providing ``literal`` and ``escape``.
        :return: Tuple of literals for a sequence, dict of literals for a
            mapping, otherwise ``conn.escape(args)``.
        """
        if isinstance(args, (tuple, list)):
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.literal(val) for (key, val) in args.items()}
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string that is sent to the database by calling the
        execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.

        :param query: Query, optionally containing %-style placeholders.
        :param args: Parameters to interpolate; when None the query is
            returned unchanged.
        """
        conn = self._get_db()
        if args is not None:
            query = query % self._escape_args(args, conn)
        return query

    def execute(self, query, args=None):
        """Execute a query

        :param str query: Query to execute.
        :param args: parameters used with query. (optional)
        :type args: tuple, list or dict
        :return: Number of affected rows
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        # Discard any result sets still pending from the previous statement.
        while self.nextset():
            pass
        query = self.mogrify(query, args)
        result = self._query(query)
        self._executed = query
        return result

    def executemany(self, query, args):
        # type: (str, list) -> int
        """Run several data against one query

        :param query: query to execute on server
        :param args: Sequence of sequences or mappings. It is used as parameter.
        :return: Number of rows affected, if any. Returns None when *args*
            is empty (nothing is executed).

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return
        m = RE_INSERT_VALUES.match(query)
        if m:
            # ``% ()`` collapses any literal ``%%`` in the prefix to ``%``.
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ""
            assert q_values[0] == "(" and q_values[-1] == ")"
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )
        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding
    ):
        """Send *args* as multi-row statements no longer than *max_stmt_length*.

        :param prefix: Statement text up to and including ``VALUES``.
        :param values: Placeholder tuple template, e.g. ``(%s, %s)``.
        :param postfix: Text appended after the value tuples.
        :param args: Non-empty iterable of parameter sets.
        :param max_stmt_length: Upper bound on the size of one statement.
        :param encoding: Codec used to turn str pieces into bytes.
        :return: Total affected-row count, also stored in ``rowcount``.
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        # The first tuple is appended without a leading comma.
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, "surrogateescape")
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, "surrogateescape")
            # The extra 1 accounts for the separating comma.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                # Flush what is accumulated and start a new statement.
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b","
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        procname -- string, name of procedure to execute on server

        args -- Sequence of parameters to use with procedure

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        if args:
            # Store each argument in a server variable @_<procname>_<index>.
            fmt = f"@_{procname}_%d=%s"
            self._query(
                "SET %s"
                % ",".join(
                    fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
                )
            )
            self.nextset()
        q = "CALL %s(%s)" % (
            procname,
            ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]),
        )
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row

        :return: The next buffered row, or None when no rows remain.
        :raises ProgrammingError: if nothing has been executed yet.
        """
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows

        :param size: Maximum number of rows; a falsy value (None or 0)
            means ``arraysize``.
        :return: Slice of the buffered rows, or ``()`` when there is no
            row buffer.
        :raises ProgrammingError: if nothing has been executed yet.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        end = self.rownumber + (size or self.arraysize)
        result = self._rows[self.rownumber : end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the rows

        :return: All rows from the current position onward, or ``()`` when
            there is no row buffer.
        :raises ProgrammingError: if nothing has been executed yet.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        if self.rownumber:
            result = self._rows[self.rownumber :]
        else:
            # Nothing consumed yet: hand back the buffer without copying.
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode="relative"):
        """Move the cursor position within the buffered rows.

        :param int value: Offset (``relative``) or target index (``absolute``).
        :param str mode: ``"relative"`` or ``"absolute"``.
        :raises ProgrammingError: if nothing was executed or *mode* is unknown.
        :raises IndexError: if the target position is outside the rows.
        """
        self._check_executed()
        if mode == "relative":
            r = self.rownumber + value
        elif mode == "absolute":
            r = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)
        # A statement without a row buffer has nothing to scroll over; treat
        # it as empty so the caller gets IndexError rather than a TypeError
        # from len(None).
        row_count = 0 if self._rows is None else len(self._rows)
        if not (0 <= r < row_count):
            raise IndexError("out of range")
        self.rownumber = r

    def _query(self, q):
        """Send *q* through the connection and load its result.

        :return: The resulting ``rowcount``.
        """
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        """Reset every per-result attribute before a new result is loaded."""
        self.rownumber = 0
        self._result = None
        self.rowcount = 0
        self.description = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        """Copy the connection's current result onto this cursor."""
        conn = self._get_db()
        self._result = result = conn._result
        self.rowcount = result.affected_rows
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        """Iterate by calling :meth:`fetchone` until it returns None."""
        return iter(self.fetchone, None)

    # Exception classes exposed as attributes of the cursor.
    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError
class DictCursorMixin:
    """Mixin that turns each fetched row into a mapping keyed by column name."""

    # Mapping type built for every row; override it to use OrderedDict or
    # another dict-like type.
    dict_type = dict

    def _do_get_result(self):
        """Load the result via the next class in the MRO, then convert rows.

        A column whose name was already seen is stored under
        ``<table_name>.<name>`` so that it does not overwrite the earlier one.
        """
        super()._do_get_result()
        column_names = []
        if self.description:
            for field in self._result.fields:
                label = field.name
                if label in column_names:
                    label = ".".join((field.table_name, label))
                column_names.append(label)
        self._fields = column_names
        # Convert only when there are both column names and buffered rows.
        if column_names and self._rows:
            self._rows = [self._conv_row(raw) for raw in self._rows]

    def _conv_row(self, row):
        """Return *row* as a ``dict_type`` keyed by field name; None stays None."""
        return None if row is None else self.dict_type(zip(self._fields, row))
class DictCursor(DictCursorMixin, Cursor):
    """Buffered cursor whose rows are returned as dictionaries."""
class SSCursor(Cursor):
    """
    Unbuffered Cursor, mainly useful for queries that return a lot of data,
    or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    def _conv_row(self, row):
        """Return *row* unchanged."""
        return row

    def close(self):
        """Finish the pending unbuffered query, drain result sets and detach.

        Calling it on an already closed cursor is a no-op.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            # Finishing the active query is inside the ``try`` so that a
            # failure here still leaves the cursor marked closed, exactly
            # like a failure while draining the remaining result sets.
            if self._result is not None and self._result is conn._result:
                self._result._finish_unbuffered_query()
            while self.nextset():
                pass
        finally:
            self.connection = None

    # Closing is also attempted when the cursor object is finalized.
    __del__ = close

    def _query(self, q):
        """Send *q* with ``unbuffered=True`` and load the result metadata.

        :return: The resulting ``rowcount``.
        """
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        conn.query(q, unbuffered=True)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        """Move to the next result set using unbuffered reads.

        :return: True if another result set was loaded, otherwise None.
        """
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row"""
        return self._conv_row(self._result._read_rowdata_packet_unbuffered())

    def fetchone(self):
        """Fetch next row

        :return: The next row, or None once :meth:`read_next` returns None.
        :raises ProgrammingError: if nothing has been executed yet.
        """
        self._check_executed()
        row = self.read_next()
        if row is None:
            return None
        self.rownumber += 1
        return row

    def fetchall(self):
        """
        Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered. See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.
        """
        return list(self.fetchall_unbuffered())

    def fetchall_unbuffered(self):
        """
        Fetch all, implemented as a generator, which isn't to standard,
        however, it doesn't make sense to return everything in a list, as that
        would use ridiculous memory for large result sets.
        """
        return iter(self.fetchone, None)

    def __iter__(self):
        """Iterate over the remaining rows without buffering them."""
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch many

        :param size: Maximum number of rows; None means ``arraysize``.
        :return: List of up to *size* rows (shorter if the rows run out).
        :raises ProgrammingError: if nothing has been executed yet.
        """
        self._check_executed()
        if size is None:
            size = self.arraysize
        rows = []
        for _ in range(size):
            row = self.read_next()
            if row is None:
                break
            rows.append(row)
            self.rownumber += 1
        return rows

    def scroll(self, value, mode="relative"):
        """Skip forward by reading and discarding rows.

        :param int value: Offset (``relative``) or target index (``absolute``).
        :param str mode: ``"relative"`` or ``"absolute"``.
        :raises NotSupportedError: if the move would go backwards.
        :raises ProgrammingError: if nothing was executed or *mode* is unknown.
        """
        self._check_executed()
        if mode == "relative":
            if value < 0:
                raise err.NotSupportedError(
                    "Backwards scrolling not supported by this cursor"
                )
            for _ in range(value):
                self.read_next()
            self.rownumber += value
        elif mode == "absolute":
            if value < self.rownumber:
                raise err.NotSupportedError(
                    "Backwards scrolling not supported by this cursor"
                )
            # Number of rows to discard to reach the requested position.
            end = value - self.rownumber
            for _ in range(end):
                self.read_next()
            self.rownumber = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)
class SSDictCursor(DictCursorMixin, SSCursor):
    """Unbuffered cursor whose rows are returned as dictionaries."""
|