concatkdf.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import struct
  5. import typing
  6. from cryptography import utils
  7. from cryptography.exceptions import (
  8. AlreadyFinalized,
  9. InvalidKey,
  10. UnsupportedAlgorithm,
  11. _Reasons,
  12. )
  13. from cryptography.hazmat.backends import _get_backend
  14. from cryptography.hazmat.backends.interfaces import (
  15. Backend,
  16. HMACBackend,
  17. HashBackend,
  18. )
  19. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  20. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  21. def _int_to_u32be(n: int) -> bytes:
  22. return struct.pack(">I", n)
  23. def _common_args_checks(
  24. algorithm: hashes.HashAlgorithm,
  25. length: int,
  26. otherinfo: typing.Optional[bytes],
  27. ) -> None:
  28. max_length = algorithm.digest_size * (2 ** 32 - 1)
  29. if length > max_length:
  30. raise ValueError(
  31. "Cannot derive keys larger than {} bits.".format(max_length)
  32. )
  33. if otherinfo is not None:
  34. utils._check_bytes("otherinfo", otherinfo)
  35. def _concatkdf_derive(
  36. key_material: bytes,
  37. length: int,
  38. auxfn: typing.Callable[[], hashes.HashContext],
  39. otherinfo: bytes,
  40. ) -> bytes:
  41. utils._check_byteslike("key_material", key_material)
  42. output = [b""]
  43. outlen = 0
  44. counter = 1
  45. while length > outlen:
  46. h = auxfn()
  47. h.update(_int_to_u32be(counter))
  48. h.update(key_material)
  49. h.update(otherinfo)
  50. output.append(h.finalize())
  51. outlen += len(output[-1])
  52. counter += 1
  53. return b"".join(output)[:length]
  54. class ConcatKDFHash(KeyDerivationFunction):
  55. def __init__(
  56. self,
  57. algorithm: hashes.HashAlgorithm,
  58. length: int,
  59. otherinfo: typing.Optional[bytes],
  60. backend: typing.Optional[Backend] = None,
  61. ):
  62. backend = _get_backend(backend)
  63. _common_args_checks(algorithm, length, otherinfo)
  64. self._algorithm = algorithm
  65. self._length = length
  66. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  67. if not isinstance(backend, HashBackend):
  68. raise UnsupportedAlgorithm(
  69. "Backend object does not implement HashBackend.",
  70. _Reasons.BACKEND_MISSING_INTERFACE,
  71. )
  72. self._backend = backend
  73. self._used = False
  74. def _hash(self) -> hashes.Hash:
  75. return hashes.Hash(self._algorithm, self._backend)
  76. def derive(self, key_material: bytes) -> bytes:
  77. if self._used:
  78. raise AlreadyFinalized
  79. self._used = True
  80. return _concatkdf_derive(
  81. key_material, self._length, self._hash, self._otherinfo
  82. )
  83. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  84. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  85. raise InvalidKey
  86. class ConcatKDFHMAC(KeyDerivationFunction):
  87. def __init__(
  88. self,
  89. algorithm: hashes.HashAlgorithm,
  90. length: int,
  91. salt: typing.Optional[bytes],
  92. otherinfo: typing.Optional[bytes],
  93. backend: typing.Optional[Backend] = None,
  94. ):
  95. backend = _get_backend(backend)
  96. _common_args_checks(algorithm, length, otherinfo)
  97. self._algorithm = algorithm
  98. self._length = length
  99. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  100. if algorithm.block_size is None:
  101. raise TypeError(
  102. "{} is unsupported for ConcatKDF".format(algorithm.name)
  103. )
  104. if salt is None:
  105. salt = b"\x00" * algorithm.block_size
  106. else:
  107. utils._check_bytes("salt", salt)
  108. self._salt = salt
  109. if not isinstance(backend, HMACBackend):
  110. raise UnsupportedAlgorithm(
  111. "Backend object does not implement HMACBackend.",
  112. _Reasons.BACKEND_MISSING_INTERFACE,
  113. )
  114. self._backend = backend
  115. self._used = False
  116. def _hmac(self) -> hmac.HMAC:
  117. return hmac.HMAC(self._salt, self._algorithm, self._backend)
  118. def derive(self, key_material: bytes) -> bytes:
  119. if self._used:
  120. raise AlreadyFinalized
  121. self._used = True
  122. return _concatkdf_derive(
  123. key_material, self._length, self._hmac, self._otherinfo
  124. )
  125. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  126. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  127. raise InvalidKey