_auth.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
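"""
Implements the client side of the MySQL / MariaDB authentication methods:
mysql_native_password, MariaDB's ed25519 plugin, sha256_password and
caching_sha2_password.
"""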
  4. from .err import OperationalError
  5. try:
  6. from cryptography.hazmat.backends import default_backend
  7. from cryptography.hazmat.primitives import serialization, hashes
  8. from cryptography.hazmat.primitives.asymmetric import padding
  9. _have_cryptography = True
  10. except ImportError:
  11. _have_cryptography = False
  12. from functools import partial
  13. import hashlib
  14. DEBUG = False
  15. SCRAMBLE_LENGTH = 20
  16. sha1_new = partial(hashlib.new, "sha1")
  17. # mysql_native_password
  18. # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
  19. def scramble_native_password(password, message):
  20. """Scramble used for mysql_native_password"""
  21. if not password:
  22. return b""
  23. stage1 = sha1_new(password).digest()
  24. stage2 = sha1_new(stage1).digest()
  25. s = sha1_new()
  26. s.update(message[:SCRAMBLE_LENGTH])
  27. s.update(stage2)
  28. result = s.digest()
  29. return _my_crypt(result, stage1)
  30. def _my_crypt(message1, message2):
  31. result = bytearray(message1)
  32. for i in range(len(result)):
  33. result[i] ^= message2[i]
  34. return bytes(result)
  35. # MariaDB's client_ed25519-plugin
  36. # https://mariadb.com/kb/en/library/connection/#client_ed25519-plugin
  37. _nacl_bindings = False
  38. def _init_nacl():
  39. global _nacl_bindings
  40. try:
  41. from nacl import bindings
  42. _nacl_bindings = bindings
  43. except ImportError:
  44. raise RuntimeError(
  45. "'pynacl' package is required for ed25519_password auth method"
  46. )
  47. def _scalar_clamp(s32):
  48. ba = bytearray(s32)
  49. ba0 = bytes(bytearray([ba[0] & 248]))
  50. ba31 = bytes(bytearray([(ba[31] & 127) | 64]))
  51. return ba0 + bytes(s32[1:31]) + ba31
  52. def ed25519_password(password, scramble):
  53. """Sign a random scramble with elliptic curve Ed25519.
  54. Secret and public key are derived from password.
  55. """
  56. # variable names based on rfc8032 section-5.1.6
  57. #
  58. if not _nacl_bindings:
  59. _init_nacl()
  60. # h = SHA512(password)
  61. h = hashlib.sha512(password).digest()
  62. # s = prune(first_half(h))
  63. s = _scalar_clamp(h[:32])
  64. # r = SHA512(second_half(h) || M)
  65. r = hashlib.sha512(h[32:] + scramble).digest()
  66. # R = encoded point [r]B
  67. r = _nacl_bindings.crypto_core_ed25519_scalar_reduce(r)
  68. R = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(r)
  69. # A = encoded point [s]B
  70. A = _nacl_bindings.crypto_scalarmult_ed25519_base_noclamp(s)
  71. # k = SHA512(R || A || M)
  72. k = hashlib.sha512(R + A + scramble).digest()
  73. # S = (k * s + r) mod L
  74. k = _nacl_bindings.crypto_core_ed25519_scalar_reduce(k)
  75. ks = _nacl_bindings.crypto_core_ed25519_scalar_mul(k, s)
  76. S = _nacl_bindings.crypto_core_ed25519_scalar_add(ks, r)
  77. # signature = R || S
  78. return R + S
  79. # sha256_password
  80. def _roundtrip(conn, send_data):
  81. conn.write_packet(send_data)
  82. pkt = conn._read_packet()
  83. pkt.check_error()
  84. return pkt
  85. def _xor_password(password, salt):
  86. # Trailing NUL character will be added in Auth Switch Request.
  87. # See https://github.com/mysql/mysql-server/blob/7d10c82196c8e45554f27c00681474a9fb86d137/sql/auth/sha2_password.cc#L939-L945
  88. salt = salt[:SCRAMBLE_LENGTH]
  89. password_bytes = bytearray(password)
  90. # salt = bytearray(salt) # for PY2 compat.
  91. salt_len = len(salt)
  92. for i in range(len(password_bytes)):
  93. password_bytes[i] ^= salt[i % salt_len]
  94. return bytes(password_bytes)
  95. def sha2_rsa_encrypt(password, salt, public_key):
  96. """Encrypt password with salt and public_key.
  97. Used for sha256_password and caching_sha2_password.
  98. """
  99. if not _have_cryptography:
  100. raise RuntimeError(
  101. "'cryptography' package is required for sha256_password or caching_sha2_password auth methods"
  102. )
  103. message = _xor_password(password + b"\0", salt)
  104. rsa_key = serialization.load_pem_public_key(public_key, default_backend())
  105. return rsa_key.encrypt(
  106. message,
  107. padding.OAEP(
  108. mgf=padding.MGF1(algorithm=hashes.SHA1()),
  109. algorithm=hashes.SHA1(),
  110. label=None,
  111. ),
  112. )
  113. def sha256_password_auth(conn, pkt):
  114. if conn._secure:
  115. if DEBUG:
  116. print("sha256: Sending plain password")
  117. data = conn.password + b"\0"
  118. return _roundtrip(conn, data)
  119. if pkt.is_auth_switch_request():
  120. conn.salt = pkt.read_all()
  121. if not conn.server_public_key and conn.password:
  122. # Request server public key
  123. if DEBUG:
  124. print("sha256: Requesting server public key")
  125. pkt = _roundtrip(conn, b"\1")
  126. if pkt.is_extra_auth_data():
  127. conn.server_public_key = pkt._data[1:]
  128. if DEBUG:
  129. print("Received public key:\n", conn.server_public_key.decode("ascii"))
  130. if conn.password:
  131. if not conn.server_public_key:
  132. raise OperationalError("Couldn't receive server's public key")
  133. data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
  134. else:
  135. data = b""
  136. return _roundtrip(conn, data)
  137. def scramble_caching_sha2(password, nonce):
  138. # (bytes, bytes) -> bytes
  139. """Scramble algorithm used in cached_sha2_password fast path.
  140. XOR(SHA256(password), SHA256(SHA256(SHA256(password)), nonce))
  141. """
  142. if not password:
  143. return b""
  144. p1 = hashlib.sha256(password).digest()
  145. p2 = hashlib.sha256(p1).digest()
  146. p3 = hashlib.sha256(p2 + nonce).digest()
  147. res = bytearray(p1)
  148. for i in range(len(p3)):
  149. res[i] ^= p3[i]
  150. return bytes(res)
  151. def caching_sha2_password_auth(conn, pkt):
  152. # No password fast path
  153. if not conn.password:
  154. return _roundtrip(conn, b"")
  155. if pkt.is_auth_switch_request():
  156. # Try from fast auth
  157. if DEBUG:
  158. print("caching sha2: Trying fast path")
  159. conn.salt = pkt.read_all()
  160. scrambled = scramble_caching_sha2(conn.password, conn.salt)
  161. pkt = _roundtrip(conn, scrambled)
  162. # else: fast auth is tried in initial handshake
  163. if not pkt.is_extra_auth_data():
  164. raise OperationalError(
  165. "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]
  166. )
  167. # magic numbers:
  168. # 2 - request public key
  169. # 3 - fast auth succeeded
  170. # 4 - need full auth
  171. pkt.advance(1)
  172. n = pkt.read_uint8()
  173. if n == 3:
  174. if DEBUG:
  175. print("caching sha2: succeeded by fast path.")
  176. pkt = conn._read_packet()
  177. pkt.check_error() # pkt must be OK packet
  178. return pkt
  179. if n != 4:
  180. raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n)
  181. if DEBUG:
  182. print("caching sha2: Trying full auth...")
  183. if conn._secure:
  184. if DEBUG:
  185. print("caching sha2: Sending plain password via secure connection")
  186. return _roundtrip(conn, conn.password + b"\0")
  187. if not conn.server_public_key:
  188. pkt = _roundtrip(conn, b"\x02") # Request public key
  189. if not pkt.is_extra_auth_data():
  190. raise OperationalError(
  191. "caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
  192. )
  193. conn.server_public_key = pkt._data[1:]
  194. if DEBUG:
  195. print(conn.server_public_key.decode("ascii"))
  196. data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
  197. pkt = _roundtrip(conn, data)