ssh.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import binascii
  5. import os
  6. import re
  7. import struct
  8. import typing
  9. from base64 import encodebytes as _base64_encode
  10. from cryptography import utils
  11. from cryptography.exceptions import UnsupportedAlgorithm
  12. from cryptography.hazmat.backends import _get_backend
  13. from cryptography.hazmat.backends.interfaces import Backend
  14. from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
  15. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  16. from cryptography.hazmat.primitives.serialization import (
  17. Encoding,
  18. NoEncryption,
  19. PrivateFormat,
  20. PublicFormat,
  21. )
  22. try:
  23. from bcrypt import kdf as _bcrypt_kdf
  24. _bcrypt_supported = True
  25. except ImportError:
  26. _bcrypt_supported = False
  27. def _bcrypt_kdf(
  28. password: bytes,
  29. salt: bytes,
  30. desired_key_bytes: int,
  31. rounds: int,
  32. ignore_few_rounds: bool = False,
  33. ) -> bytes:
  34. raise UnsupportedAlgorithm("Need bcrypt module")
# SSH key-type names as they appear on the wire and in one-line public keys.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key-type name to denote the certificate variant of that key.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# One-line public key: group 1 is the key type, group 2 the base64 body.
# Anything after the body (e.g. a comment) is ignored by match().
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")

# Magic prefix of the decoded private-key blob.
_SK_MAGIC = b"openssh-key-v1\0"
# PEM-style armor lines around the base64-encoded private key.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names stored in the private-key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# Longest password serialize_ssh_private_key accepts.
_MAX_PASSWORD = 72

# A regex is used so the armor can be located directly in any bytes-like
# input; DOTALL lets the captured body span multiple lines.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Deterministic padding bytes 1, 2, ..., 16, enough for the largest
# block length in _SSH_CIPHERS; slices of it are written and checked.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

# Ciphers supported for key wrapping:
#   name -> (algorithm class, key length, mode class, IV length).
# The last field also serves as the padding block length.
_SSH_CIPHERS = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# Map cryptography's curve name to the SSH ECDSA key-type name.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

# Big-endian unsigned 32-bit and 64-bit integer codecs.
_U32 = struct.Struct(b">I")
_U64 = struct.Struct(b">Q")
  68. def _ecdsa_key_type(public_key):
  69. """Return SSH key_type and curve_name for private key."""
  70. curve = public_key.curve
  71. if curve.name not in _ECDSA_KEY_TYPE:
  72. raise ValueError(
  73. "Unsupported curve for ssh private key: %r" % curve.name
  74. )
  75. return _ECDSA_KEY_TYPE[curve.name]
  76. def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
  77. return b"".join([prefix, _base64_encode(data), suffix])
  78. def _check_block_size(data, block_len):
  79. """Require data to be full blocks"""
  80. if not data or len(data) % block_len != 0:
  81. raise ValueError("Corrupt data: missing padding")
  82. def _check_empty(data):
  83. """All data should have been parsed."""
  84. if data:
  85. raise ValueError("Corrupt data: unparsed data")
  86. def _init_cipher(ciphername, password, salt, rounds, backend):
  87. """Generate key + iv and return cipher."""
  88. if not password:
  89. raise ValueError("Key is password-protected.")
  90. algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
  91. seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
  92. return Cipher(algo(seed[:key_len]), mode(seed[key_len:]), backend)
  93. def _get_u32(data):
  94. """Uint32"""
  95. if len(data) < 4:
  96. raise ValueError("Invalid data")
  97. return _U32.unpack(data[:4])[0], data[4:]
  98. def _get_u64(data):
  99. """Uint64"""
  100. if len(data) < 8:
  101. raise ValueError("Invalid data")
  102. return _U64.unpack(data[:8])[0], data[8:]
  103. def _get_sshstr(data):
  104. """Bytes with u32 length prefix"""
  105. n, data = _get_u32(data)
  106. if n > len(data):
  107. raise ValueError("Invalid data")
  108. return data[:n], data[n:]
  109. def _get_mpint(data):
  110. """Big integer."""
  111. val, data = _get_sshstr(data)
  112. if val and val[0] > 0x7F:
  113. raise ValueError("Invalid data")
  114. return int.from_bytes(val, "big"), data
  115. def _to_mpint(val):
  116. """Storage format for signed bigint."""
  117. if val < 0:
  118. raise ValueError("negative mpint not allowed")
  119. if not val:
  120. return b""
  121. nbytes = (val.bit_length() + 8) // 8
  122. return utils.int_to_bytes(val, nbytes)
  123. class _FragList(object):
  124. """Build recursive structure without data copy."""
  125. def __init__(self, init=None):
  126. self.flist = []
  127. if init:
  128. self.flist.extend(init)
  129. def put_raw(self, val):
  130. """Add plain bytes"""
  131. self.flist.append(val)
  132. def put_u32(self, val):
  133. """Big-endian uint32"""
  134. self.flist.append(_U32.pack(val))
  135. def put_sshstr(self, val):
  136. """Bytes prefixed with u32 length"""
  137. if isinstance(val, (bytes, memoryview, bytearray)):
  138. self.put_u32(len(val))
  139. self.flist.append(val)
  140. else:
  141. self.put_u32(val.size())
  142. self.flist.extend(val.flist)
  143. def put_mpint(self, val):
  144. """Big-endian bigint prefixed with u32 length"""
  145. self.put_sshstr(_to_mpint(val))
  146. def size(self):
  147. """Current number of bytes"""
  148. return sum(map(len, self.flist))
  149. def render(self, dstbuf, pos=0):
  150. """Write into bytearray"""
  151. for frag in self.flist:
  152. flen = len(frag)
  153. start, pos = pos, pos + flen
  154. dstbuf[start:pos] = frag
  155. return pos
  156. def tobytes(self):
  157. """Return as bytes"""
  158. buf = memoryview(bytearray(self.size()))
  159. self.render(buf)
  160. return buf.tobytes()
  161. class _SSHFormatRSA(object):
  162. """Format for RSA keys.
  163. Public:
  164. mpint e, n
  165. Private:
  166. mpint n, e, d, iqmp, p, q
  167. """
  168. def get_public(self, data):
  169. """RSA public fields"""
  170. e, data = _get_mpint(data)
  171. n, data = _get_mpint(data)
  172. return (e, n), data
  173. def load_public(self, key_type, data, backend):
  174. """Make RSA public key from data."""
  175. (e, n), data = self.get_public(data)
  176. public_numbers = rsa.RSAPublicNumbers(e, n)
  177. public_key = public_numbers.public_key(backend)
  178. return public_key, data
  179. def load_private(self, data, pubfields, backend):
  180. """Make RSA private key from data."""
  181. n, data = _get_mpint(data)
  182. e, data = _get_mpint(data)
  183. d, data = _get_mpint(data)
  184. iqmp, data = _get_mpint(data)
  185. p, data = _get_mpint(data)
  186. q, data = _get_mpint(data)
  187. if (e, n) != pubfields:
  188. raise ValueError("Corrupt data: rsa field mismatch")
  189. dmp1 = rsa.rsa_crt_dmp1(d, p)
  190. dmq1 = rsa.rsa_crt_dmq1(d, q)
  191. public_numbers = rsa.RSAPublicNumbers(e, n)
  192. private_numbers = rsa.RSAPrivateNumbers(
  193. p, q, d, dmp1, dmq1, iqmp, public_numbers
  194. )
  195. private_key = private_numbers.private_key(backend)
  196. return private_key, data
  197. def encode_public(self, public_key, f_pub):
  198. """Write RSA public key"""
  199. pubn = public_key.public_numbers()
  200. f_pub.put_mpint(pubn.e)
  201. f_pub.put_mpint(pubn.n)
  202. def encode_private(self, private_key, f_priv):
  203. """Write RSA private key"""
  204. private_numbers = private_key.private_numbers()
  205. public_numbers = private_numbers.public_numbers
  206. f_priv.put_mpint(public_numbers.n)
  207. f_priv.put_mpint(public_numbers.e)
  208. f_priv.put_mpint(private_numbers.d)
  209. f_priv.put_mpint(private_numbers.iqmp)
  210. f_priv.put_mpint(private_numbers.p)
  211. f_priv.put_mpint(private_numbers.q)
  212. class _SSHFormatDSA(object):
  213. """Format for DSA keys.
  214. Public:
  215. mpint p, q, g, y
  216. Private:
  217. mpint p, q, g, y, x
  218. """
  219. def get_public(self, data):
  220. """DSA public fields"""
  221. p, data = _get_mpint(data)
  222. q, data = _get_mpint(data)
  223. g, data = _get_mpint(data)
  224. y, data = _get_mpint(data)
  225. return (p, q, g, y), data
  226. def load_public(self, key_type, data, backend):
  227. """Make DSA public key from data."""
  228. (p, q, g, y), data = self.get_public(data)
  229. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  230. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  231. self._validate(public_numbers)
  232. public_key = public_numbers.public_key(backend)
  233. return public_key, data
  234. def load_private(self, data, pubfields, backend):
  235. """Make DSA private key from data."""
  236. (p, q, g, y), data = self.get_public(data)
  237. x, data = _get_mpint(data)
  238. if (p, q, g, y) != pubfields:
  239. raise ValueError("Corrupt data: dsa field mismatch")
  240. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  241. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  242. self._validate(public_numbers)
  243. private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
  244. private_key = private_numbers.private_key(backend)
  245. return private_key, data
  246. def encode_public(self, public_key, f_pub):
  247. """Write DSA public key"""
  248. public_numbers = public_key.public_numbers()
  249. parameter_numbers = public_numbers.parameter_numbers
  250. self._validate(public_numbers)
  251. f_pub.put_mpint(parameter_numbers.p)
  252. f_pub.put_mpint(parameter_numbers.q)
  253. f_pub.put_mpint(parameter_numbers.g)
  254. f_pub.put_mpint(public_numbers.y)
  255. def encode_private(self, private_key, f_priv):
  256. """Write DSA private key"""
  257. self.encode_public(private_key.public_key(), f_priv)
  258. f_priv.put_mpint(private_key.private_numbers().x)
  259. def _validate(self, public_numbers):
  260. parameter_numbers = public_numbers.parameter_numbers
  261. if parameter_numbers.p.bit_length() != 1024:
  262. raise ValueError("SSH supports only 1024 bit DSA keys")
  263. class _SSHFormatECDSA(object):
  264. """Format for ECDSA keys.
  265. Public:
  266. str curve
  267. bytes point
  268. Private:
  269. str curve
  270. bytes point
  271. mpint secret
  272. """
  273. def __init__(self, ssh_curve_name, curve):
  274. self.ssh_curve_name = ssh_curve_name
  275. self.curve = curve
  276. def get_public(self, data):
  277. """ECDSA public fields"""
  278. curve, data = _get_sshstr(data)
  279. point, data = _get_sshstr(data)
  280. if curve != self.ssh_curve_name:
  281. raise ValueError("Curve name mismatch")
  282. if point[0] != 4:
  283. raise NotImplementedError("Need uncompressed point")
  284. return (curve, point), data
  285. def load_public(self, key_type, data, backend):
  286. """Make ECDSA public key from data."""
  287. (curve_name, point), data = self.get_public(data)
  288. public_key = ec.EllipticCurvePublicKey.from_encoded_point(
  289. self.curve, point.tobytes()
  290. )
  291. return public_key, data
  292. def load_private(self, data, pubfields, backend):
  293. """Make ECDSA private key from data."""
  294. (curve_name, point), data = self.get_public(data)
  295. secret, data = _get_mpint(data)
  296. if (curve_name, point) != pubfields:
  297. raise ValueError("Corrupt data: ecdsa field mismatch")
  298. private_key = ec.derive_private_key(secret, self.curve, backend)
  299. return private_key, data
  300. def encode_public(self, public_key, f_pub):
  301. """Write ECDSA public key"""
  302. point = public_key.public_bytes(
  303. Encoding.X962, PublicFormat.UncompressedPoint
  304. )
  305. f_pub.put_sshstr(self.ssh_curve_name)
  306. f_pub.put_sshstr(point)
  307. def encode_private(self, private_key, f_priv):
  308. """Write ECDSA private key"""
  309. public_key = private_key.public_key()
  310. private_numbers = private_key.private_numbers()
  311. self.encode_public(public_key, f_priv)
  312. f_priv.put_mpint(private_numbers.private_value)
  313. class _SSHFormatEd25519(object):
  314. """Format for Ed25519 keys.
  315. Public:
  316. bytes point
  317. Private:
  318. bytes point
  319. bytes secret_and_point
  320. """
  321. def get_public(self, data):
  322. """Ed25519 public fields"""
  323. point, data = _get_sshstr(data)
  324. return (point,), data
  325. def load_public(self, key_type, data, backend):
  326. """Make Ed25519 public key from data."""
  327. (point,), data = self.get_public(data)
  328. public_key = ed25519.Ed25519PublicKey.from_public_bytes(
  329. point.tobytes()
  330. )
  331. return public_key, data
  332. def load_private(self, data, pubfields, backend):
  333. """Make Ed25519 private key from data."""
  334. (point,), data = self.get_public(data)
  335. keypair, data = _get_sshstr(data)
  336. secret = keypair[:32]
  337. point2 = keypair[32:]
  338. if point != point2 or (point,) != pubfields:
  339. raise ValueError("Corrupt data: ed25519 field mismatch")
  340. private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
  341. return private_key, data
  342. def encode_public(self, public_key, f_pub):
  343. """Write Ed25519 public key"""
  344. raw_public_key = public_key.public_bytes(
  345. Encoding.Raw, PublicFormat.Raw
  346. )
  347. f_pub.put_sshstr(raw_public_key)
  348. def encode_private(self, private_key, f_priv):
  349. """Write Ed25519 private key"""
  350. public_key = private_key.public_key()
  351. raw_private_key = private_key.private_bytes(
  352. Encoding.Raw, PrivateFormat.Raw, NoEncryption()
  353. )
  354. raw_public_key = public_key.public_bytes(
  355. Encoding.Raw, PublicFormat.Raw
  356. )
  357. f_keypair = _FragList([raw_private_key, raw_public_key])
  358. self.encode_public(public_key, f_priv)
  359. f_priv.put_sshstr(f_keypair)
# Dispatch table: SSH key-type name -> handler that parses and encodes that
# key type's fields. ECDSA handlers carry the in-blob curve identifier and
# the matching cryptography curve.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
  368. def _lookup_kformat(key_type):
  369. """Return valid format or throw error"""
  370. if not isinstance(key_type, bytes):
  371. key_type = memoryview(key_type).tobytes()
  372. if key_type in _KEY_FORMATS:
  373. return _KEY_FORMATS[key_type]
  374. raise UnsupportedAlgorithm("Unsupported key type: %r" % key_type)
# Private key classes this module can load and serialize in OpenSSH format.
_SSH_PRIVATE_KEY_TYPES = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
  381. def load_ssh_private_key(
  382. data: bytes,
  383. password: typing.Optional[bytes],
  384. backend: typing.Optional[Backend] = None,
  385. ) -> _SSH_PRIVATE_KEY_TYPES:
  386. """Load private key from OpenSSH custom encoding."""
  387. utils._check_byteslike("data", data)
  388. backend = _get_backend(backend)
  389. if password is not None:
  390. utils._check_bytes("password", password)
  391. m = _PEM_RC.search(data)
  392. if not m:
  393. raise ValueError("Not OpenSSH private key format")
  394. p1 = m.start(1)
  395. p2 = m.end(1)
  396. data = binascii.a2b_base64(memoryview(data)[p1:p2])
  397. if not data.startswith(_SK_MAGIC):
  398. raise ValueError("Not OpenSSH private key format")
  399. data = memoryview(data)[len(_SK_MAGIC) :]
  400. # parse header
  401. ciphername, data = _get_sshstr(data)
  402. kdfname, data = _get_sshstr(data)
  403. kdfoptions, data = _get_sshstr(data)
  404. nkeys, data = _get_u32(data)
  405. if nkeys != 1:
  406. raise ValueError("Only one key supported")
  407. # load public key data
  408. pubdata, data = _get_sshstr(data)
  409. pub_key_type, pubdata = _get_sshstr(pubdata)
  410. kformat = _lookup_kformat(pub_key_type)
  411. pubfields, pubdata = kformat.get_public(pubdata)
  412. _check_empty(pubdata)
  413. # load secret data
  414. edata, data = _get_sshstr(data)
  415. _check_empty(data)
  416. if (ciphername, kdfname) != (_NONE, _NONE):
  417. ciphername = ciphername.tobytes()
  418. if ciphername not in _SSH_CIPHERS:
  419. raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
  420. if kdfname != _BCRYPT:
  421. raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
  422. blklen = _SSH_CIPHERS[ciphername][3]
  423. _check_block_size(edata, blklen)
  424. salt, kbuf = _get_sshstr(kdfoptions)
  425. rounds, kbuf = _get_u32(kbuf)
  426. _check_empty(kbuf)
  427. ciph = _init_cipher(
  428. ciphername, password, salt.tobytes(), rounds, backend
  429. )
  430. edata = memoryview(ciph.decryptor().update(edata))
  431. else:
  432. blklen = 8
  433. _check_block_size(edata, blklen)
  434. ck1, edata = _get_u32(edata)
  435. ck2, edata = _get_u32(edata)
  436. if ck1 != ck2:
  437. raise ValueError("Corrupt data: broken checksum")
  438. # load per-key struct
  439. key_type, edata = _get_sshstr(edata)
  440. if key_type != pub_key_type:
  441. raise ValueError("Corrupt data: key type mismatch")
  442. private_key, edata = kformat.load_private(edata, pubfields, backend)
  443. comment, edata = _get_sshstr(edata)
  444. # yes, SSH does padding check *after* all other parsing is done.
  445. # need to follow as it writes zero-byte padding too.
  446. if edata != _PADDING[: len(edata)]:
  447. raise ValueError("Corrupt data: invalid padding")
  448. return private_key
def serialize_ssh_private_key(
    private_key: _SSH_PRIVATE_KEY_TYPES,
    password: typing.Optional[bytes] = None,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Args:
        private_key: RSA, DSA (1024-bit only), ECDSA (P-256/384/521) or
            Ed25519 private key.
        password: optional passphrase. A non-empty password encrypts the
            key with _DEFAULT_CIPHER using bcrypt-derived key material;
            None or b"" writes an unencrypted key.

    Returns:
        The armored "-----BEGIN OPENSSH PRIVATE KEY-----" block as bytes.

    Raises:
        ValueError: password longer than _MAX_PASSWORD bytes, unsupported
            key class or EC curve, or DSA key whose p is not 1024 bits.
        UnsupportedAlgorithm: a password was given but bcrypt is missing.
    """
    if password is not None:
        utils._check_bytes("password", password)
    if password and len(password) > _MAX_PASSWORD:
        raise ValueError(
            "Passwords longer than 72 bytes are not supported by "
            "OpenSSH private key format"
        )

    # Pick the SSH key-type name for the key class.
    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        # bcrypt KDF options: salt string followed by a u32 round count.
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        backend: Backend = _get_backend(None)
        ciph = _init_cipher(ciphername, password, salt, rounds, backend)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice; the loader compares the two copies.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad to a multiple of blklen with bytes 1, 2, 3, ...; always adds
    # between 1 and blklen bytes (a full block when already aligned).
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # The extra blklen bytes presumably give update_into below the output
    # headroom it requires -- NOTE(review): confirm against Cipher docs.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secret section is the last field written, so it occupies the
    # final slen bytes of the rendered structure.
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    txt = _ssh_pem_encode(buf[:mlen])
    # Scrub the secret section of the scratch buffer now that it has been
    # encoded into txt.
    # Ignore the following type because mypy wants
    # Sequence[bytes] but what we're passing is fine.
    # https://github.com/python/mypy/issues/9999
    buf[ofs:mlen] = bytearray(slen)  # type: ignore
    return txt
# Public key classes this module can load and serialize in OpenSSH format.
_SSH_PUBLIC_KEY_TYPES = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]
  530. def load_ssh_public_key(
  531. data: bytes, backend: typing.Optional[Backend] = None
  532. ) -> _SSH_PUBLIC_KEY_TYPES:
  533. """Load public key from OpenSSH one-line format."""
  534. backend = _get_backend(backend)
  535. utils._check_byteslike("data", data)
  536. m = _SSH_PUBKEY_RC.match(data)
  537. if not m:
  538. raise ValueError("Invalid line format")
  539. key_type = orig_key_type = m.group(1)
  540. key_body = m.group(2)
  541. with_cert = False
  542. if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
  543. with_cert = True
  544. key_type = key_type[: -len(_CERT_SUFFIX)]
  545. kformat = _lookup_kformat(key_type)
  546. try:
  547. data = memoryview(binascii.a2b_base64(key_body))
  548. except (TypeError, binascii.Error):
  549. raise ValueError("Invalid key format")
  550. inner_key_type, data = _get_sshstr(data)
  551. if inner_key_type != orig_key_type:
  552. raise ValueError("Invalid key format")
  553. if with_cert:
  554. nonce, data = _get_sshstr(data)
  555. public_key, data = kformat.load_public(key_type, data, backend)
  556. if with_cert:
  557. serial, data = _get_u64(data)
  558. cctype, data = _get_u32(data)
  559. key_id, data = _get_sshstr(data)
  560. principals, data = _get_sshstr(data)
  561. valid_after, data = _get_u64(data)
  562. valid_before, data = _get_u64(data)
  563. crit_options, data = _get_sshstr(data)
  564. extensions, data = _get_sshstr(data)
  565. reserved, data = _get_sshstr(data)
  566. sig_key, data = _get_sshstr(data)
  567. signature, data = _get_sshstr(data)
  568. _check_empty(data)
  569. return public_key
  570. def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
  571. """One-line public key format for OpenSSH"""
  572. if isinstance(public_key, ec.EllipticCurvePublicKey):
  573. key_type = _ecdsa_key_type(public_key)
  574. elif isinstance(public_key, rsa.RSAPublicKey):
  575. key_type = _SSH_RSA
  576. elif isinstance(public_key, dsa.DSAPublicKey):
  577. key_type = _SSH_DSA
  578. elif isinstance(public_key, ed25519.Ed25519PublicKey):
  579. key_type = _SSH_ED25519
  580. else:
  581. raise ValueError("Unsupported key type")
  582. kformat = _lookup_kformat(key_type)
  583. f_pub = _FragList()
  584. f_pub.put_sshstr(key_type)
  585. kformat.encode_public(public_key, f_pub)
  586. pub = binascii.b2a_base64(f_pub.tobytes()).strip()
  587. return b"".join([key_type, b" ", pub])