123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- """
- The pyro wire protocol message.
- Pyro - Python Remote Objects. Copyright by Irmen de Jong (irmen@razorvine.net).
- """
- import hashlib
- import hmac
- import struct
- import logging
- import sys
- import zlib
- from Pyro4 import errors, constants
- from Pyro4.configuration import config
# Names exported by a star-import of this module.
__all__ = [
    "Message",
    "secure_compare",
]

# Logger shared by everything in this module.
log = logging.getLogger("Pyro4.message")
# Message type codes (the value of the header's 'message type' field).
MSG_CONNECT = 1
MSG_CONNECTOK = 2
MSG_CONNECTFAIL = 3
MSG_INVOKE = 4
MSG_RESULT = 5
MSG_PING = 6

# Bit masks that are combined into the header's 'message flags' field.
FLAGS_EXCEPTION = 0x01
FLAGS_COMPRESSED = 0x02         # payload is zlib-compressed
FLAGS_ONEWAY = 0x04
FLAGS_BATCH = 0x08
FLAGS_META_ON_CONNECT = 0x10
FLAGS_ITEMSTREAMRESULT = 0x20
FLAGS_KEEPSERIALIZED = 0x40
class Message(object):
    """
    Pyro wire protocol message.

    Wire messages consist of a fixed size header, an optional set of annotation chunks,
    and then the payload data. This class doesn't deal with the payload data:
    (de)serialization and handling of that data is done elsewhere.
    Annotation chunks are only parsed, except the 'HMAC' chunk: that is created
    and validated because it is used as a message digest.

    The header format is::

        4   id ('PYRO')
        2   protocol version
        2   message type
        2   message flags
        2   sequence number
        4   data length   (i.e. 2 Gb data size limitation)
        2   data serialization format (serializer id)
        2   annotations length (total of all chunks, 0 if no annotation chunks present)
        2   (reserved)
        2   checksum

    After the header, zero or more annotation chunks may follow, of the format::

        4   id (ASCII)
        2   chunk length
        x   annotation chunk databytes

    After that, the actual payload data bytes follow.

    The sequencenumber is used to check if response messages correspond to the
    actual request message. This prevents the situation where Pyro would perhaps return
    the response data from another remote call (which would not result in an error otherwise!)
    This could happen for instance if the socket data stream gets out of sync, perhaps due to
    some form of signal that interrupts I/O.

    The header checksum is a simple sum of the header fields to make reasonably sure
    that we are dealing with an actual correct PYRO protocol header and not some random
    data that happens to start with the 'PYRO' protocol identifier.

    Pyro now uses two annotation chunks that you should not touch yourself:
    'HMAC'  contains the hmac digest of the message data bytes and
    all of the annotation chunk data bytes (except those of the HMAC chunk itself).
    'CORR'  contains the correlation id (guid bytes)
    Other chunk names are free to use for custom purposes, but Pyro has the right
    to reserve more of them for internal use in the future.
    """
    __slots__ = ["type", "flags", "seq", "data", "data_size", "serializer_id", "annotations", "annotations_size", "hmac_key"]
    # struct layout of the fixed-size header described in the class docstring (network byte order)
    header_format = '!4sHHHHiHHHH'
    header_size = struct.calcsize(header_format)
    # constant that is added into the header checksum
    checksum_magic = 0x34E9

    def __init__(self, msgType, databytes, serializer_id, flags, seq, annotations=None, hmac_key=None):
        """
        Create a message from its header fields, payload bytes and optional annotations.

        :param msgType: message type code (stored in the header)
        :param databytes: the payload bytes; its length becomes ``data_size``
        :param serializer_id: id of the serializer that produced the payload
        :param flags: message flag bits
        :param seq: sequence number
        :param annotations: optional mapping of 4-character chunk id to chunk bytes (it is copied)
        :param hmac_key: optional key; when truthy an 'HMAC' annotation is computed and added
        :raises errors.MessageTooLargeError: when config.MAX_MESSAGE_SIZE is positive and the
            payload plus annotations are larger than that
        """
        self.type = msgType
        self.flags = flags
        self.seq = seq
        self.data = databytes
        self.data_size = len(self.data)
        self.serializer_id = serializer_id
        self.annotations = dict(annotations or {})
        self.hmac_key = hmac_key
        if self.hmac_key:
            self.annotations["HMAC"] = self.hmac()  # should be done last because it calculates hmac over other annotations
        # every chunk costs 6 bytes of chunk header (4 id + 2 length) plus its data
        self.annotations_size = sum(6 + len(v) for v in self.annotations.values())
        if 0 < config.MAX_MESSAGE_SIZE < (self.data_size + self.annotations_size):
            raise errors.MessageTooLargeError("max message size exceeded (%d where max=%d)" %
                                              (self.data_size + self.annotations_size, config.MAX_MESSAGE_SIZE))

    def __repr__(self):
        return "<%s.%s at %x; type=%d flags=%d seq=%d datasize=%d #ann=%d>" % \
               (self.__module__, self.__class__.__name__, id(self), self.type, self.flags, self.seq, self.data_size, len(self.annotations))

    def to_bytes(self):
        """creates a byte stream containing the header followed by annotations (if any) followed by the data"""
        return self.__header_bytes() + self.__annotations_bytes() + self.data

    def __header_bytes(self):
        """
        Pack the fixed-size header, including its checksum.

        :raises ValueError: when ``data_size`` is outside the range 0..0x7fffffff
        """
        if not (0 <= self.data_size <= 0x7fffffff):
            raise ValueError("invalid message size (outside range 0..2Gb)")
        # simple sum of the header fields plus the magic constant, truncated to 16 bits
        checksum = (self.type + constants.PROTOCOL_VERSION + self.data_size + self.annotations_size +
                    self.serializer_id + self.flags + self.seq + self.checksum_magic) & 0xffff
        return struct.pack(self.header_format, b"PYRO", constants.PROTOCOL_VERSION, self.type, self.flags,
                           self.seq, self.data_size, self.serializer_id, self.annotations_size, 0, checksum)

    def __annotations_bytes(self):
        """
        Encode all annotation chunks as: 4 byte id, 2 byte length, chunk data.

        :returns: the concatenated chunks, or empty bytes when there are no annotations
        :raises errors.ProtocolError: when an annotation key is not exactly 4 characters long
        """
        if self.annotations:
            a = []
            for k, v in self.annotations.items():
                if len(k) != 4:
                    raise errors.ProtocolError("annotation key must be of length 4")
                if sys.version_info >= (3, 0):
                    k = k.encode("ASCII")  # keys are text on Python 3, the wire format needs bytes
                a.append(struct.pack("!4sH", k, len(v)))
                a.append(v)
            return b"".join(a)
        return b""

    # Note: this 'chunked' way of sending is not used because it triggers Nagle's algorithm
    # on some systems (linux). This causes big delays, unless you change the socket option
    # TCP_NODELAY to disable the algorithm. What also works, is sending all the message bytes
    # in one go: connection.send(message.to_bytes()). This is what Pyro does.
    def send(self, connection):
        """send the message as bytes over the connection"""
        connection.send(self.__header_bytes())
        if self.annotations:
            connection.send(self.__annotations_bytes())
        connection.send(self.data)

    @classmethod
    def from_header(cls, headerData):
        """
        Parses a message header. Does not yet process the annotations chunks and message data.

        :param headerData: exactly ``header_size`` bytes
        :returns: a Message with empty data and annotations, but with ``data_size`` and
            ``annotations_size`` set to the values announced in the header
        :raises errors.ProtocolError: on a size mismatch, a wrong tag or protocol version,
            or a checksum mismatch
        """
        if not headerData or len(headerData) != cls.header_size:
            raise errors.ProtocolError("header data size mismatch")
        tag, ver, msg_type, flags, seq, data_size, serializer_id, anns_size, _, checksum = struct.unpack(cls.header_format, headerData)
        if tag != b"PYRO" or ver != constants.PROTOCOL_VERSION:
            raise errors.ProtocolError("invalid data or unsupported protocol version")
        if checksum != (msg_type + ver + data_size + anns_size + flags + serializer_id + seq + cls.checksum_magic) & 0xffff:
            raise errors.ProtocolError("header checksum mismatch")
        msg = Message(msg_type, b"", serializer_id, flags, seq)
        # the constructor derived these from the (empty) data; restore the sizes announced by the header
        msg.data_size = data_size
        msg.annotations_size = anns_size
        return msg

    @classmethod
    def recv(cls, connection, requiredMsgTypes=None, hmac_key=None):
        """
        Receives a pyro message from a given connection.
        Accepts the given message types (None=any, or pass a sequence).
        Also reads annotation chunks and the actual payload data.
        Validates a HMAC chunk if present.

        :raises errors.MessageTooLargeError: announced size exceeds config.MAX_MESSAGE_SIZE
            (the connection is closed first)
        :raises errors.ProtocolError: bad header, unexpected message type, or malformed annotation chunk
        :raises errors.SecurityError: hmac mismatch, or only one of message/receiver uses a hmac key

        Exceptions raised after the header was parsed carry the message as ``pyroMsg``.
        """
        msg = cls.from_header(connection.recv(cls.header_size))
        msg.hmac_key = hmac_key
        if 0 < config.MAX_MESSAGE_SIZE < (msg.data_size + msg.annotations_size):
            errorMsg = "max message size exceeded (%d where max=%d)" % (msg.data_size + msg.annotations_size, config.MAX_MESSAGE_SIZE)
            log.error("connection %s: %s", connection, errorMsg)
            connection.close()  # close the socket because at this point we can't return the correct seqnr for returning an errormsg
            exc = errors.MessageTooLargeError(errorMsg)
            exc.pyroMsg = msg
            raise exc
        if requiredMsgTypes and msg.type not in requiredMsgTypes:
            err = "invalid msg type %d received" % msg.type
            log.error(err)
            exc = errors.ProtocolError(err)
            exc.pyroMsg = msg
            raise exc
        if msg.annotations_size:
            # read annotation chunks
            annotations_data = connection.recv(msg.annotations_size)
            msg.annotations = {}
            i = 0
            while i < msg.annotations_size:
                chunk_header = annotations_data[i:i + 6]
                if len(chunk_header) != 6:
                    # fewer than 6 bytes left: report it as a protocol error instead of a raw struct.error
                    exc = errors.ProtocolError("truncated annotation chunk header")
                    exc.pyroMsg = msg
                    raise exc
                anno, length = struct.unpack("!4sH", chunk_header)
                if sys.version_info >= (3, 0):
                    anno = anno.decode("ASCII")
                value = annotations_data[i + 6:i + 6 + length]
                if len(value) != length:
                    # the chunk claims more bytes than are available; slicing would silently truncate it
                    exc = errors.ProtocolError("annotation chunk length exceeds annotations size")
                    exc.pyroMsg = msg
                    raise exc
                if sys.platform == "cli":
                    value = bytes(value)
                msg.annotations[anno] = value
                i += 6 + length
        # read data
        msg.data = connection.recv(msg.data_size)
        if "HMAC" in msg.annotations and hmac_key:
            if not secure_compare(msg.annotations["HMAC"], msg.hmac()):
                exc = errors.SecurityError("message hmac mismatch")
                exc.pyroMsg = msg
                raise exc
        elif ("HMAC" in msg.annotations) != bool(hmac_key):
            # Not allowed: message contains hmac but hmac_key is not set, or vice versa.
            err = "hmac key config not symmetric"
            log.warning(err)
            exc = errors.SecurityError(err)
            exc.pyroMsg = msg
            raise exc
        return msg

    def hmac(self):
        """returns the hmac (sha1) of the data and the annotation chunk values (except HMAC chunk itself)"""
        mac = hmac.new(self.hmac_key, self.data, digestmod=hashlib.sha1)
        for k, v in sorted(self.annotations.items()):  # note: sorted because we need fixed order to get the same hmac
            if k != "HMAC":
                mac.update(v)
        return mac.digest() if sys.platform != "cli" else bytes(mac.digest())

    @staticmethod
    def ping(pyroConnection, hmac_key=None):
        """
        Convenience method to send a 'ping' message and wait for the 'pong' response.

        The same hmac_key is used to sign the ping and to validate the reply.
        """
        ping = Message(MSG_PING, b"ping", 42, 0, 0, hmac_key=hmac_key)
        pyroConnection.send(ping.to_bytes())
        # the key must be passed on: without it recv() rejects a signed reply as 'hmac key config not symmetric'
        Message.recv(pyroConnection, [MSG_PING], hmac_key=hmac_key)

    def decompress_if_needed(self):
        """
        Decompress the message data if it is compressed.

        Clears the compressed flag and updates data_size afterwards. Returns the message itself.
        """
        if self.flags & FLAGS_COMPRESSED:
            self.data = zlib.decompress(self.data)
            self.flags &= ~FLAGS_COMPRESSED
            self.data_size = len(self.data)
        return self
try:
    # Preferred: the comparison function that ships with the standard library.
    from hmac import compare_digest as secure_compare
except ImportError:
    # This Python has no hmac.compare_digest: provide a pure-python substitute.
    import operator

    try:
        reduce  # builtin on Python 2
    except NameError:
        from functools import reduce

    def secure_compare(a, b):
        """
        Tell whether a and b are equal, examining every position instead of
        stopping at the first difference.

        Raises TypeError when the arguments are of different types.
        Sequences of different length are never equal.
        """
        if type(a) != type(b):
            raise TypeError("arguments must both be same type")
        if len(a) != len(b):
            return False
        position_matches = (operator.eq(left, right) for left, right in zip(a, b))
        return reduce(operator.and_, position_matches, True)
|