message.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """
  2. The pyro wire protocol message.
  3. Pyro - Python Remote Objects. Copyright by Irmen de Jong (irmen@razorvine.net).
  4. """
  5. import hashlib
  6. import hmac
  7. import struct
  8. import logging
  9. import sys
  10. import zlib
  11. from Pyro4 import errors, constants
  12. from Pyro4.configuration import config
# Public names exported by a star-import of this module.
__all__ = ["Message", "secure_compare"]

# Module-level logger used for protocol errors and warnings.
log = logging.getLogger("Pyro4.message")
  15. MSG_CONNECT = 1
  16. MSG_CONNECTOK = 2
  17. MSG_CONNECTFAIL = 3
  18. MSG_INVOKE = 4
  19. MSG_RESULT = 5
  20. MSG_PING = 6
  21. FLAGS_EXCEPTION = 1 << 0
  22. FLAGS_COMPRESSED = 1 << 1
  23. FLAGS_ONEWAY = 1 << 2
  24. FLAGS_BATCH = 1 << 3
  25. FLAGS_META_ON_CONNECT = 1 << 4
  26. FLAGS_ITEMSTREAMRESULT = 1 << 5
  27. FLAGS_KEEPSERIALIZED = 1 << 6
  28. class Message(object):
  29. """
  30. Pyro write protocol message.
  31. Wire messages contains of a fixed size header, an optional set of annotation chunks,
  32. and then the payload data. This class doesn't deal with the payload data:
  33. (de)serialization and handling of that data is done elsewhere.
  34. Annotation chunks are only parsed, except the 'HMAC' chunk: that is created
  35. and validated because it is used as a message digest.
  36. The header format is::
  37. 4 id ('PYRO')
  38. 2 protocol version
  39. 2 message type
  40. 2 message flags
  41. 2 sequence number
  42. 4 data length (i.e. 2 Gb data size limitation)
  43. 2 data serialization format (serializer id)
  44. 2 annotations length (total of all chunks, 0 if no annotation chunks present)
  45. 2 (reserved)
  46. 2 checksum
  47. After the header, zero or more annotation chunks may follow, of the format::
  48. 4 id (ASCII)
  49. 2 chunk length
  50. x annotation chunk databytes
  51. After that, the actual payload data bytes follow.
  52. The sequencenumber is used to check if response messages correspond to the
  53. actual request message. This prevents the situation where Pyro would perhaps return
  54. the response data from another remote call (which would not result in an error otherwise!)
  55. This could happen for instance if the socket data stream gets out of sync, perhaps due To
  56. some form of signal that interrupts I/O.
  57. The header checksum is a simple sum of the header fields to make reasonably sure
  58. that we are dealing with an actual correct PYRO protocol header and not some random
  59. data that happens to start with the 'PYRO' protocol identifier.
  60. Pyro now uses two annotation chunks that you should not touch yourself:
  61. 'HMAC' contains the hmac digest of the message data bytes and
  62. all of the annotation chunk data bytes (except those of the HMAC chunk itself).
  63. 'CORR' contains the correlation id (guid bytes)
  64. Other chunk names are free to use for custom purposes, but Pyro has the right
  65. to reserve more of them for internal use in the future.
  66. """
  67. __slots__ = ["type", "flags", "seq", "data", "data_size", "serializer_id", "annotations", "annotations_size", "hmac_key"]
  68. header_format = '!4sHHHHiHHHH'
  69. header_size = struct.calcsize(header_format)
  70. checksum_magic = 0x34E9
  71. def __init__(self, msgType, databytes, serializer_id, flags, seq, annotations=None, hmac_key=None):
  72. self.type = msgType
  73. self.flags = flags
  74. self.seq = seq
  75. self.data = databytes
  76. self.data_size = len(self.data)
  77. self.serializer_id = serializer_id
  78. self.annotations = dict(annotations or {})
  79. self.hmac_key = hmac_key
  80. if self.hmac_key:
  81. self.annotations["HMAC"] = self.hmac() # should be done last because it calculates hmac over other annotations
  82. self.annotations_size = sum([6 + len(v) for v in self.annotations.values()])
  83. if 0 < config.MAX_MESSAGE_SIZE < (self.data_size + self.annotations_size):
  84. raise errors.MessageTooLargeError("max message size exceeded (%d where max=%d)" %
  85. (self.data_size + self.annotations_size, config.MAX_MESSAGE_SIZE))
  86. def __repr__(self):
  87. return "<%s.%s at %x; type=%d flags=%d seq=%d datasize=%d #ann=%d>" %\
  88. (self.__module__, self.__class__.__name__, id(self), self.type, self.flags, self.seq, self.data_size, len(self.annotations))
  89. def to_bytes(self):
  90. """creates a byte stream containing the header followed by annotations (if any) followed by the data"""
  91. return self.__header_bytes() + self.__annotations_bytes() + self.data
  92. def __header_bytes(self):
  93. if not (0 <= self.data_size <= 0x7fffffff):
  94. raise ValueError("invalid message size (outside range 0..2Gb)")
  95. checksum = (self.type + constants.PROTOCOL_VERSION + self.data_size + self.annotations_size +
  96. self.serializer_id + self.flags + self.seq + self.checksum_magic) & 0xffff
  97. return struct.pack(self.header_format, b"PYRO", constants.PROTOCOL_VERSION, self.type, self.flags,
  98. self.seq, self.data_size, self.serializer_id, self.annotations_size, 0, checksum)
  99. def __annotations_bytes(self):
  100. if self.annotations:
  101. a = []
  102. for k, v in self.annotations.items():
  103. if len(k) != 4:
  104. raise errors.ProtocolError("annotation key must be of length 4")
  105. if sys.version_info >= (3, 0):
  106. k = k.encode("ASCII")
  107. a.append(struct.pack("!4sH", k, len(v)))
  108. a.append(v)
  109. return b"".join(a)
  110. return b""
  111. # Note: this 'chunked' way of sending is not used because it triggers Nagle's algorithm
  112. # on some systems (linux). This causes big delays, unless you change the socket option
  113. # TCP_NODELAY to disable the algorithm. What also works, is sending all the message bytes
  114. # in one go: connection.send(message.to_bytes()). This is what Pyro does.
  115. def send(self, connection):
  116. """send the message as bytes over the connection"""
  117. connection.send(self.__header_bytes())
  118. if self.annotations:
  119. connection.send(self.__annotations_bytes())
  120. connection.send(self.data)
  121. @classmethod
  122. def from_header(cls, headerData):
  123. """Parses a message header. Does not yet process the annotations chunks and message data."""
  124. if not headerData or len(headerData) != cls.header_size:
  125. raise errors.ProtocolError("header data size mismatch")
  126. tag, ver, msg_type, flags, seq, data_size, serializer_id, anns_size, _, checksum = struct.unpack(cls.header_format, headerData)
  127. if tag != b"PYRO" or ver != constants.PROTOCOL_VERSION:
  128. raise errors.ProtocolError("invalid data or unsupported protocol version")
  129. if checksum != (msg_type + ver + data_size + anns_size + flags + serializer_id + seq + cls.checksum_magic) & 0xffff:
  130. raise errors.ProtocolError("header checksum mismatch")
  131. msg = Message(msg_type, b"", serializer_id, flags, seq)
  132. msg.data_size = data_size
  133. msg.annotations_size = anns_size
  134. return msg
  135. @classmethod
  136. def recv(cls, connection, requiredMsgTypes=None, hmac_key=None):
  137. """
  138. Receives a pyro message from a given connection.
  139. Accepts the given message types (None=any, or pass a sequence).
  140. Also reads annotation chunks and the actual payload data.
  141. Validates a HMAC chunk if present.
  142. """
  143. msg = cls.from_header(connection.recv(cls.header_size))
  144. msg.hmac_key = hmac_key
  145. if 0 < config.MAX_MESSAGE_SIZE < (msg.data_size + msg.annotations_size):
  146. errorMsg = "max message size exceeded (%d where max=%d)" % (msg.data_size + msg.annotations_size, config.MAX_MESSAGE_SIZE)
  147. log.error("connection " + str(connection) + ": " + errorMsg)
  148. connection.close() # close the socket because at this point we can't return the correct seqnr for returning an errormsg
  149. exc = errors.MessageTooLargeError(errorMsg)
  150. exc.pyroMsg = msg
  151. raise exc
  152. if requiredMsgTypes and msg.type not in requiredMsgTypes:
  153. err = "invalid msg type %d received" % msg.type
  154. log.error(err)
  155. exc = errors.ProtocolError(err)
  156. exc.pyroMsg = msg
  157. raise exc
  158. if msg.annotations_size:
  159. # read annotation chunks
  160. annotations_data = connection.recv(msg.annotations_size)
  161. msg.annotations = {}
  162. i = 0
  163. while i < msg.annotations_size:
  164. anno, length = struct.unpack("!4sH", annotations_data[i:i + 6])
  165. if sys.version_info >= (3, 0):
  166. anno = anno.decode("ASCII")
  167. msg.annotations[anno] = annotations_data[i + 6:i + 6 + length]
  168. if sys.platform == "cli":
  169. msg.annotations[anno] = bytes(msg.annotations[anno])
  170. i += 6 + length
  171. # read data
  172. msg.data = connection.recv(msg.data_size)
  173. if "HMAC" in msg.annotations and hmac_key:
  174. if not secure_compare(msg.annotations["HMAC"], msg.hmac()):
  175. exc = errors.SecurityError("message hmac mismatch")
  176. exc.pyroMsg = msg
  177. raise exc
  178. elif ("HMAC" in msg.annotations) != bool(hmac_key):
  179. # Not allowed: message contains hmac but hmac_key is not set, or vice versa.
  180. err = "hmac key config not symmetric"
  181. log.warning(err)
  182. exc = errors.SecurityError(err)
  183. exc.pyroMsg = msg
  184. raise exc
  185. return msg
  186. def hmac(self):
  187. """returns the hmac of the data and the annotation chunk values (except HMAC chunk itself)"""
  188. mac = hmac.new(self.hmac_key, self.data, digestmod=hashlib.sha1)
  189. for k, v in sorted(self.annotations.items()): # note: sorted because we need fixed order to get the same hmac
  190. if k != "HMAC":
  191. mac.update(v)
  192. return mac.digest() if sys.platform != "cli" else bytes(mac.digest())
  193. @staticmethod
  194. def ping(pyroConnection, hmac_key=None):
  195. """Convenience method to send a 'ping' message and wait for the 'pong' response"""
  196. ping = Message(MSG_PING, b"ping", 42, 0, 0, hmac_key=hmac_key)
  197. pyroConnection.send(ping.to_bytes())
  198. Message.recv(pyroConnection, [MSG_PING])
  199. def decompress_if_needed(self):
  200. """Decompress the message data if it is compressed."""
  201. if self.flags & FLAGS_COMPRESSED:
  202. self.data = zlib.decompress(self.data)
  203. self.flags &= ~FLAGS_COMPRESSED
  204. self.data_size = len(self.data)
  205. return self
  206. try:
  207. from hmac import compare_digest as secure_compare
  208. except ImportError:
  209. # Python version doesn't have it natively, use a python fallback implementation
  210. import operator
  211. try:
  212. reduce
  213. except NameError:
  214. from functools import reduce
  215. def secure_compare(a, b):
  216. if type(a) != type(b):
  217. raise TypeError("arguments must both be same type")
  218. if len(a) != len(b):
  219. return False
  220. return reduce(operator.and_, map(operator.eq, a, b), True)