aead.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
  4. from cryptography.exceptions import InvalidTag
  5. _ENCRYPT = 1
  6. _DECRYPT = 0
  7. def _aead_cipher_name(cipher):
  8. from cryptography.hazmat.primitives.ciphers.aead import (
  9. AESCCM,
  10. AESGCM,
  11. ChaCha20Poly1305,
  12. )
  13. if isinstance(cipher, ChaCha20Poly1305):
  14. return b"chacha20-poly1305"
  15. elif isinstance(cipher, AESCCM):
  16. return "aes-{}-ccm".format(len(cipher._key) * 8).encode("ascii")
  17. else:
  18. assert isinstance(cipher, AESGCM)
  19. return "aes-{}-gcm".format(len(cipher._key) * 8).encode("ascii")
  20. def _aead_setup(backend, cipher_name, key, nonce, tag, tag_len, operation):
  21. evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
  22. backend.openssl_assert(evp_cipher != backend._ffi.NULL)
  23. ctx = backend._lib.EVP_CIPHER_CTX_new()
  24. ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
  25. res = backend._lib.EVP_CipherInit_ex(
  26. ctx,
  27. evp_cipher,
  28. backend._ffi.NULL,
  29. backend._ffi.NULL,
  30. backend._ffi.NULL,
  31. int(operation == _ENCRYPT),
  32. )
  33. backend.openssl_assert(res != 0)
  34. res = backend._lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
  35. backend.openssl_assert(res != 0)
  36. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  37. ctx,
  38. backend._lib.EVP_CTRL_AEAD_SET_IVLEN,
  39. len(nonce),
  40. backend._ffi.NULL,
  41. )
  42. backend.openssl_assert(res != 0)
  43. if operation == _DECRYPT:
  44. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  45. ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
  46. )
  47. backend.openssl_assert(res != 0)
  48. elif cipher_name.endswith(b"-ccm"):
  49. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  50. ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, backend._ffi.NULL
  51. )
  52. backend.openssl_assert(res != 0)
  53. nonce_ptr = backend._ffi.from_buffer(nonce)
  54. key_ptr = backend._ffi.from_buffer(key)
  55. res = backend._lib.EVP_CipherInit_ex(
  56. ctx,
  57. backend._ffi.NULL,
  58. backend._ffi.NULL,
  59. key_ptr,
  60. nonce_ptr,
  61. int(operation == _ENCRYPT),
  62. )
  63. backend.openssl_assert(res != 0)
  64. return ctx
  65. def _set_length(backend, ctx, data_len):
  66. intptr = backend._ffi.new("int *")
  67. res = backend._lib.EVP_CipherUpdate(
  68. ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len
  69. )
  70. backend.openssl_assert(res != 0)
  71. def _process_aad(backend, ctx, associated_data):
  72. outlen = backend._ffi.new("int *")
  73. res = backend._lib.EVP_CipherUpdate(
  74. ctx, backend._ffi.NULL, outlen, associated_data, len(associated_data)
  75. )
  76. backend.openssl_assert(res != 0)
  77. def _process_data(backend, ctx, data):
  78. outlen = backend._ffi.new("int *")
  79. buf = backend._ffi.new("unsigned char[]", len(data))
  80. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  81. backend.openssl_assert(res != 0)
  82. return backend._ffi.buffer(buf, outlen[0])[:]
  83. def _encrypt(backend, cipher, nonce, data, associated_data, tag_length):
  84. from cryptography.hazmat.primitives.ciphers.aead import AESCCM
  85. cipher_name = _aead_cipher_name(cipher)
  86. ctx = _aead_setup(
  87. backend, cipher_name, cipher._key, nonce, None, tag_length, _ENCRYPT
  88. )
  89. # CCM requires us to pass the length of the data before processing anything
  90. # However calling this with any other AEAD results in an error
  91. if isinstance(cipher, AESCCM):
  92. _set_length(backend, ctx, len(data))
  93. _process_aad(backend, ctx, associated_data)
  94. processed_data = _process_data(backend, ctx, data)
  95. outlen = backend._ffi.new("int *")
  96. res = backend._lib.EVP_CipherFinal_ex(ctx, backend._ffi.NULL, outlen)
  97. backend.openssl_assert(res != 0)
  98. backend.openssl_assert(outlen[0] == 0)
  99. tag_buf = backend._ffi.new("unsigned char[]", tag_length)
  100. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  101. ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
  102. )
  103. backend.openssl_assert(res != 0)
  104. tag = backend._ffi.buffer(tag_buf)[:]
  105. return processed_data + tag
  106. def _decrypt(backend, cipher, nonce, data, associated_data, tag_length):
  107. from cryptography.hazmat.primitives.ciphers.aead import AESCCM
  108. if len(data) < tag_length:
  109. raise InvalidTag
  110. tag = data[-tag_length:]
  111. data = data[:-tag_length]
  112. cipher_name = _aead_cipher_name(cipher)
  113. ctx = _aead_setup(
  114. backend, cipher_name, cipher._key, nonce, tag, tag_length, _DECRYPT
  115. )
  116. # CCM requires us to pass the length of the data before processing anything
  117. # However calling this with any other AEAD results in an error
  118. if isinstance(cipher, AESCCM):
  119. _set_length(backend, ctx, len(data))
  120. _process_aad(backend, ctx, associated_data)
  121. # CCM has a different error path if the tag doesn't match. Errors are
  122. # raised in Update and Final is irrelevant.
  123. if isinstance(cipher, AESCCM):
  124. outlen = backend._ffi.new("int *")
  125. buf = backend._ffi.new("unsigned char[]", len(data))
  126. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  127. if res != 1:
  128. backend._consume_errors()
  129. raise InvalidTag
  130. processed_data = backend._ffi.buffer(buf, outlen[0])[:]
  131. else:
  132. processed_data = _process_data(backend, ctx, data)
  133. outlen = backend._ffi.new("int *")
  134. res = backend._lib.EVP_CipherFinal_ex(ctx, backend._ffi.NULL, outlen)
  135. if res == 0:
  136. backend._consume_errors()
  137. raise InvalidTag
  138. return processed_data