rsa.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import abc
  5. import typing
  6. from math import gcd
  7. from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
  8. from cryptography.hazmat.backends import _get_backend
  9. from cryptography.hazmat.backends.interfaces import Backend, RSABackend
  10. from cryptography.hazmat.primitives import _serialization, hashes
  11. from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
  12. from cryptography.hazmat.primitives.asymmetric import (
  13. AsymmetricSignatureContext,
  14. AsymmetricVerificationContext,
  15. utils as asym_utils,
  16. )
  17. class RSAPrivateKey(metaclass=abc.ABCMeta):
  18. @abc.abstractmethod
  19. def signer(
  20. self, padding: AsymmetricPadding, algorithm: hashes.HashAlgorithm
  21. ) -> AsymmetricSignatureContext:
  22. """
  23. Returns an AsymmetricSignatureContext used for signing data.
  24. """
  25. @abc.abstractmethod
  26. def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
  27. """
  28. Decrypts the provided ciphertext.
  29. """
  30. @abc.abstractproperty
  31. def key_size(self) -> int:
  32. """
  33. The bit length of the public modulus.
  34. """
  35. @abc.abstractmethod
  36. def public_key(self) -> "RSAPublicKey":
  37. """
  38. The RSAPublicKey associated with this private key.
  39. """
  40. @abc.abstractmethod
  41. def sign(
  42. self,
  43. data: bytes,
  44. padding: AsymmetricPadding,
  45. algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
  46. ) -> bytes:
  47. """
  48. Signs the data.
  49. """
  50. @abc.abstractmethod
  51. def private_numbers(self) -> "RSAPrivateNumbers":
  52. """
  53. Returns an RSAPrivateNumbers.
  54. """
  55. @abc.abstractmethod
  56. def private_bytes(
  57. self,
  58. encoding: _serialization.Encoding,
  59. format: _serialization.PrivateFormat,
  60. encryption_algorithm: _serialization.KeySerializationEncryption,
  61. ) -> bytes:
  62. """
  63. Returns the key serialized as bytes.
  64. """
  65. RSAPrivateKeyWithSerialization = RSAPrivateKey
  66. class RSAPublicKey(metaclass=abc.ABCMeta):
  67. @abc.abstractmethod
  68. def verifier(
  69. self,
  70. signature: bytes,
  71. padding: AsymmetricPadding,
  72. algorithm: hashes.HashAlgorithm,
  73. ) -> AsymmetricVerificationContext:
  74. """
  75. Returns an AsymmetricVerificationContext used for verifying signatures.
  76. """
  77. @abc.abstractmethod
  78. def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
  79. """
  80. Encrypts the given plaintext.
  81. """
  82. @abc.abstractproperty
  83. def key_size(self) -> int:
  84. """
  85. The bit length of the public modulus.
  86. """
  87. @abc.abstractmethod
  88. def public_numbers(self) -> "RSAPublicNumbers":
  89. """
  90. Returns an RSAPublicNumbers
  91. """
  92. @abc.abstractmethod
  93. def public_bytes(
  94. self,
  95. encoding: _serialization.Encoding,
  96. format: _serialization.PublicFormat,
  97. ) -> bytes:
  98. """
  99. Returns the key serialized as bytes.
  100. """
  101. @abc.abstractmethod
  102. def verify(
  103. self,
  104. signature: bytes,
  105. data: bytes,
  106. padding: AsymmetricPadding,
  107. algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
  108. ) -> None:
  109. """
  110. Verifies the signature of the data.
  111. """
  112. @abc.abstractmethod
  113. def recover_data_from_signature(
  114. self,
  115. signature: bytes,
  116. padding: AsymmetricPadding,
  117. algorithm: typing.Optional[hashes.HashAlgorithm],
  118. ) -> bytes:
  119. """
  120. Recovers the original data from the signature.
  121. """
  122. RSAPublicKeyWithSerialization = RSAPublicKey
  123. def generate_private_key(
  124. public_exponent: int,
  125. key_size: int,
  126. backend: typing.Optional[Backend] = None,
  127. ) -> RSAPrivateKey:
  128. backend = _get_backend(backend)
  129. if not isinstance(backend, RSABackend):
  130. raise UnsupportedAlgorithm(
  131. "Backend object does not implement RSABackend.",
  132. _Reasons.BACKEND_MISSING_INTERFACE,
  133. )
  134. _verify_rsa_parameters(public_exponent, key_size)
  135. return backend.generate_rsa_private_key(public_exponent, key_size)
  136. def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
  137. if public_exponent not in (3, 65537):
  138. raise ValueError(
  139. "public_exponent must be either 3 (for legacy compatibility) or "
  140. "65537. Almost everyone should choose 65537 here!"
  141. )
  142. if key_size < 512:
  143. raise ValueError("key_size must be at least 512-bits.")
  144. def _check_private_key_components(
  145. p: int,
  146. q: int,
  147. private_exponent: int,
  148. dmp1: int,
  149. dmq1: int,
  150. iqmp: int,
  151. public_exponent: int,
  152. modulus: int,
  153. ) -> None:
  154. if modulus < 3:
  155. raise ValueError("modulus must be >= 3.")
  156. if p >= modulus:
  157. raise ValueError("p must be < modulus.")
  158. if q >= modulus:
  159. raise ValueError("q must be < modulus.")
  160. if dmp1 >= modulus:
  161. raise ValueError("dmp1 must be < modulus.")
  162. if dmq1 >= modulus:
  163. raise ValueError("dmq1 must be < modulus.")
  164. if iqmp >= modulus:
  165. raise ValueError("iqmp must be < modulus.")
  166. if private_exponent >= modulus:
  167. raise ValueError("private_exponent must be < modulus.")
  168. if public_exponent < 3 or public_exponent >= modulus:
  169. raise ValueError("public_exponent must be >= 3 and < modulus.")
  170. if public_exponent & 1 == 0:
  171. raise ValueError("public_exponent must be odd.")
  172. if dmp1 & 1 == 0:
  173. raise ValueError("dmp1 must be odd.")
  174. if dmq1 & 1 == 0:
  175. raise ValueError("dmq1 must be odd.")
  176. if p * q != modulus:
  177. raise ValueError("p*q must equal modulus.")
  178. def _check_public_key_components(e: int, n: int) -> None:
  179. if n < 3:
  180. raise ValueError("n must be >= 3.")
  181. if e < 3 or e >= n:
  182. raise ValueError("e must be >= 3 and < n.")
  183. if e & 1 == 0:
  184. raise ValueError("e must be odd.")
  185. def _modinv(e: int, m: int) -> int:
  186. """
  187. Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
  188. """
  189. x1, x2 = 1, 0
  190. a, b = e, m
  191. while b > 0:
  192. q, r = divmod(a, b)
  193. xn = x1 - q * x2
  194. a, b, x1, x2 = b, r, x2, xn
  195. return x1 % m
  196. def rsa_crt_iqmp(p: int, q: int) -> int:
  197. """
  198. Compute the CRT (q ** -1) % p value from RSA primes p and q.
  199. """
  200. return _modinv(q, p)
  201. def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
  202. """
  203. Compute the CRT private_exponent % (p - 1) value from the RSA
  204. private_exponent (d) and p.
  205. """
  206. return private_exponent % (p - 1)
  207. def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
  208. """
  209. Compute the CRT private_exponent % (q - 1) value from the RSA
  210. private_exponent (d) and q.
  211. """
  212. return private_exponent % (q - 1)
  213. # Controls the number of iterations rsa_recover_prime_factors will perform
  214. # to obtain the prime factors. Each iteration increments by 2 so the actual
  215. # maximum attempts is half this number.
  216. _MAX_RECOVERY_ATTEMPTS = 1000
  217. def rsa_recover_prime_factors(
  218. n: int, e: int, d: int
  219. ) -> typing.Tuple[int, int]:
  220. """
  221. Compute factors p and q from the private exponent d. We assume that n has
  222. no more than two factors. This function is adapted from code in PyCrypto.
  223. """
  224. # See 8.2.2(i) in Handbook of Applied Cryptography.
  225. ktot = d * e - 1
  226. # The quantity d*e-1 is a multiple of phi(n), even,
  227. # and can be represented as t*2^s.
  228. t = ktot
  229. while t % 2 == 0:
  230. t = t // 2
  231. # Cycle through all multiplicative inverses in Zn.
  232. # The algorithm is non-deterministic, but there is a 50% chance
  233. # any candidate a leads to successful factoring.
  234. # See "Digitalized Signatures and Public Key Functions as Intractable
  235. # as Factorization", M. Rabin, 1979
  236. spotted = False
  237. a = 2
  238. while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
  239. k = t
  240. # Cycle through all values a^{t*2^i}=a^k
  241. while k < ktot:
  242. cand = pow(a, k, n)
  243. # Check if a^k is a non-trivial root of unity (mod n)
  244. if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
  245. # We have found a number such that (cand-1)(cand+1)=0 (mod n).
  246. # Either of the terms divides n.
  247. p = gcd(cand + 1, n)
  248. spotted = True
  249. break
  250. k *= 2
  251. # This value was not any good... let's try another!
  252. a += 2
  253. if not spotted:
  254. raise ValueError("Unable to compute factors p and q from exponent d.")
  255. # Found !
  256. q, r = divmod(n, p)
  257. assert r == 0
  258. p, q = sorted((p, q), reverse=True)
  259. return (p, q)
  260. class RSAPrivateNumbers(object):
  261. def __init__(
  262. self,
  263. p: int,
  264. q: int,
  265. d: int,
  266. dmp1: int,
  267. dmq1: int,
  268. iqmp: int,
  269. public_numbers: "RSAPublicNumbers",
  270. ):
  271. if (
  272. not isinstance(p, int)
  273. or not isinstance(q, int)
  274. or not isinstance(d, int)
  275. or not isinstance(dmp1, int)
  276. or not isinstance(dmq1, int)
  277. or not isinstance(iqmp, int)
  278. ):
  279. raise TypeError(
  280. "RSAPrivateNumbers p, q, d, dmp1, dmq1, iqmp arguments must"
  281. " all be an integers."
  282. )
  283. if not isinstance(public_numbers, RSAPublicNumbers):
  284. raise TypeError(
  285. "RSAPrivateNumbers public_numbers must be an RSAPublicNumbers"
  286. " instance."
  287. )
  288. self._p = p
  289. self._q = q
  290. self._d = d
  291. self._dmp1 = dmp1
  292. self._dmq1 = dmq1
  293. self._iqmp = iqmp
  294. self._public_numbers = public_numbers
  295. p = property(lambda self: self._p)
  296. q = property(lambda self: self._q)
  297. d = property(lambda self: self._d)
  298. dmp1 = property(lambda self: self._dmp1)
  299. dmq1 = property(lambda self: self._dmq1)
  300. iqmp = property(lambda self: self._iqmp)
  301. public_numbers = property(lambda self: self._public_numbers)
  302. def private_key(
  303. self, backend: typing.Optional[Backend] = None
  304. ) -> RSAPrivateKey:
  305. backend = _get_backend(backend)
  306. return backend.load_rsa_private_numbers(self)
  307. def __eq__(self, other):
  308. if not isinstance(other, RSAPrivateNumbers):
  309. return NotImplemented
  310. return (
  311. self.p == other.p
  312. and self.q == other.q
  313. and self.d == other.d
  314. and self.dmp1 == other.dmp1
  315. and self.dmq1 == other.dmq1
  316. and self.iqmp == other.iqmp
  317. and self.public_numbers == other.public_numbers
  318. )
  319. def __ne__(self, other):
  320. return not self == other
  321. def __hash__(self):
  322. return hash(
  323. (
  324. self.p,
  325. self.q,
  326. self.d,
  327. self.dmp1,
  328. self.dmq1,
  329. self.iqmp,
  330. self.public_numbers,
  331. )
  332. )
  333. class RSAPublicNumbers(object):
  334. def __init__(self, e: int, n: int):
  335. if not isinstance(e, int) or not isinstance(n, int):
  336. raise TypeError("RSAPublicNumbers arguments must be integers.")
  337. self._e = e
  338. self._n = n
  339. e = property(lambda self: self._e)
  340. n = property(lambda self: self._n)
  341. def public_key(
  342. self, backend: typing.Optional[Backend] = None
  343. ) -> RSAPublicKey:
  344. backend = _get_backend(backend)
  345. return backend.load_rsa_public_numbers(self)
  346. def __repr__(self):
  347. return "<RSAPublicNumbers(e={0.e}, n={0.n})>".format(self)
  348. def __eq__(self, other):
  349. if not isinstance(other, RSAPublicNumbers):
  350. return NotImplemented
  351. return self.e == other.e and self.n == other.n
  352. def __ne__(self, other):
  353. return not self == other
  354. def __hash__(self):
  355. return hash((self.e, self.n))