diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 66bf7c7049..7683676a43 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4714,7 +4714,7 @@ def _set_result(self, host, connection, pool, response): protocol = self.session.cluster.protocol_version info = self._custom_payload.get('tablets-routing-v1') ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') - tablet_routing_info = ctype.from_binary(info, protocol) + tablet_routing_info = ctype(protocol).from_binary(info) first_token = tablet_routing_info[0] last_token = tablet_routing_info[1] tablet_replicas = tablet_routing_info[2] diff --git a/cassandra/column_encryption/_policies.py b/cassandra/column_encryption/_policies.py index ef8097bfbd..7170ee9675 100644 --- a/cassandra/column_encryption/_policies.py +++ b/cassandra/column_encryption/_policies.py @@ -114,7 +114,7 @@ def encode_and_encrypt(self, coldesc, obj): coldata = self.coldata.get(coldesc) if not coldata: raise ValueError("Could not find ColData for ColDesc %s".format(coldesc)) - return self.encrypt(coldesc, coldata.type.serialize(obj, None)) + return self.encrypt(coldesc, coldata.type(None).serialize(obj)) def cache_info(self): return AES256ColumnEncryptionPolicy._build_cipher.cache_info() diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index bc00001666..2335c5c69e 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -948,7 +948,7 @@ def _transform_column(col_name, col_obj): key_cols = [c for c in partition_keys.values()] partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) key_cql_types = [c.cql_type for c in key_cols] - key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + key_serializer = staticmethod(lambda parts, proto_version: [t(proto_version).to_binary(p) for t, p in zip(key_cql_types, parts)]) else: partition_key_index = {} key_serializer = staticmethod(lambda parts, proto_version: None) diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index e36c48563c..5e1ff85af8 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -287,11 +287,13 @@ class _CassandraType(object, metaclass=CassandraTypeType): of EmptyValue) will be returned. """ + def __init__(self, protocol_version=None): + self.protocol_version = protocol_version + def __repr__(self): return '<%s>' % (self.cql_parameterized_type()) - @classmethod - def from_binary(cls, byts, protocol_version): + def from_binary(self, byts): """ Deserialize a bytestring into a value. See the deserialize() method for more information. This method differs in that if None or the empty @@ -299,21 +301,19 @@ def from_binary(cls, byts, protocol_version): """ if byts is None: return None - elif len(byts) == 0 and not cls.empty_binary_ok: - return EMPTY if cls.support_empty_values else None - return cls.deserialize(byts, protocol_version) + elif len(byts) == 0 and not self.empty_binary_ok: + return EMPTY if self.support_empty_values else None + return self.deserialize(byts) - @classmethod - def to_binary(cls, val, protocol_version): + def to_binary(self, val): """ Serialize a value into a bytestring. See the serialize() method for more information. This method differs in that if None is passed in, the result is the empty string. """ - return b'' if val is None else cls.serialize(val, protocol_version) + return b'' if val is None else self.serialize(val) - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): """ Given a bytestring, deserialize into a value according to the protocol for this type. Note that this does not create a new instance of this @@ -322,8 +322,7 @@ def deserialize(byts, protocol_version): """ return byts - @staticmethod - def serialize(val, protocol_version): + def serialize(self, val): """ Given a value appropriate for this class, serialize it according to the protocol for this type and return the corresponding bytestring. @@ -415,22 +414,19 @@ class BytesType(_CassandraType): typename = 'blob' empty_binary_ok = True - @staticmethod - def serialize(val, protocol_version): + def serialize(self, val): return bytes(val) class DecimalType(_CassandraType): typename = 'decimal' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) return Decimal('%de%d' % (unscaled, -scale)) - @staticmethod - def serialize(dec, protocol_version): + def serialize(self, dec): try: sign, digits, exponent = dec.as_tuple() except AttributeError: @@ -449,12 +445,10 @@ def serialize(dec, protocol_version): class UUIDType(_CassandraType): typename = 'uuid' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return UUID(bytes=byts) - @staticmethod - def serialize(uuid, protocol_version): + def serialize(self, uuid): try: return uuid.bytes except AttributeError: @@ -467,12 +461,10 @@ def serial_size(cls): class BooleanType(_CassandraType): typename = 'boolean' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return bool(int8_unpack(byts)) - @staticmethod - def serialize(truth, protocol_version): + def serialize(self, truth): return int8_pack(truth) @classmethod @@ -482,12 +474,10 @@ def serial_size(cls): class ByteType(_CassandraType): typename = 'tinyint' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return int8_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return int8_pack(byts) @@ -495,12 +485,10 @@ class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return byts.decode('ascii') - @staticmethod - def serialize(var, protocol_version): + def serialize(self, var): try: return var.encode('ascii') except UnicodeDecodeError: @@ -510,12 +498,10 @@ def serialize(var, protocol_version): class FloatType(_CassandraType): typename = 'float' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return float_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return float_pack(byts) @classmethod @@ -525,12 +511,10 @@ def serial_size(cls): class DoubleType(_CassandraType): typename = 'double' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return double_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return double_pack(byts) @classmethod @@ -540,12 +524,10 @@ def serial_size(cls): class LongType(_CassandraType): typename = 'bigint' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return int64_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return int64_pack(byts) @classmethod @@ -555,12 +537,10 @@ def serial_size(cls): class Int32Type(_CassandraType): typename = 'int' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return int32_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return int32_pack(byts) @classmethod @@ -570,20 +550,17 @@ def serial_size(cls): class IntegerType(_CassandraType): typename = 'varint' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return varint_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return varint_pack(byts) class InetAddressType(_CassandraType): typename = 'inet' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): if len(byts) == 16: return util.inet_ntop(socket.AF_INET6, byts) else: @@ -591,8 +568,7 @@ def deserialize(byts, protocol_version): # since we've already determined the AF return socket.inet_ntoa(byts) - @staticmethod - def serialize(addr, protocol_version): + def serialize(self, addr): try: if ':' in addr: return util.inet_pton(socket.AF_INET6, addr) @@ -640,13 +616,11 @@ def interpret_datestring(val): else: raise ValueError("can't interpret %r as a date" % (val,)) - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): timestamp = int64_unpack(byts) / 1000.0 return util.datetime_from_timestamp(timestamp) - @staticmethod - def serialize(v, protocol_version): + def serialize(self, v): try: # v is datetime timestamp_seconds = calendar.timegm(v.utctimetuple()) @@ -676,12 +650,10 @@ class TimeUUIDType(DateType): def my_timestamp(self): return util.unix_time_from_uuid1(self.val) - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return UUID(bytes=byts) - @staticmethod - def serialize(timeuuid, protocol_version): + def serialize(self, timeuuid): try: return timeuuid.bytes except AttributeError: @@ -700,13 +672,11 @@ class SimpleDateType(_CassandraType): # range (2^31). EPOCH_OFFSET_DAYS = 2 ** 31 - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS return util.Date(days) - @staticmethod - def serialize(val, protocol_version): + def serialize(self, val): try: days = val.days_from_epoch except AttributeError: @@ -722,12 +692,10 @@ def serialize(val, protocol_version): class ShortType(_CassandraType): typename = 'smallint' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return int16_unpack(byts) - @staticmethod - def serialize(byts, protocol_version): + def serialize(self, byts): return int16_pack(byts) class TimeType(_CassandraType): @@ -739,12 +707,10 @@ class TimeType(_CassandraType): #def serial_size(cls): # return 8 - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return util.Time(int64_unpack(byts)) - @staticmethod - def serialize(val, protocol_version): + def serialize(self, val): try: nano = val.nanosecond_time except AttributeError: @@ -755,13 +721,11 @@ def serialize(val, protocol_version): class DurationType(_CassandraType): typename = 'duration' - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): months, days, nanoseconds = vints_unpack(byts) return util.Duration(months, days, nanoseconds) - @staticmethod - def serialize(duration, protocol_version): + def serialize(self, duration): try: m, d, n = duration.months, duration.days, duration.nanoseconds except AttributeError: @@ -773,12 +737,10 @@ class UTF8Type(_CassandraType): typename = 'text' empty_binary_ok = True - @staticmethod - def deserialize(byts, protocol_version): + def deserialize(self, byts): return byts.decode('utf8') - @staticmethod - def serialize(ustr, protocol_version): + def serialize(self, ustr): try: return ustr.encode('utf-8') except UnicodeDecodeError: @@ -793,30 +755,31 @@ class VarcharType(UTF8Type): class _ParameterizedType(_CassandraType): num_subtypes = 'UNKNOWN' - @classmethod - def deserialize(cls, byts, protocol_version): - if not cls.subtypes: + def __init__(self, protocol_version=None): + super(_ParameterizedType, self).__init__(protocol_version) + inner_proto = max(3, protocol_version) if protocol_version is not None else None + self.subtype_instances = [s(inner_proto) for s in self.subtypes] + + def deserialize(self, byts): + if not self.subtypes: raise NotImplementedError("can't deserialize unparameterized %s" - % cls.typename) - return cls.deserialize_safe(byts, protocol_version) + % self.typename) + return self.deserialize_safe(byts) - @classmethod - def serialize(cls, val, protocol_version): - if not cls.subtypes: + def serialize(self, val): + if not self.subtypes: raise NotImplementedError("can't serialize unparameterized %s" - % cls.typename) - return cls.serialize_safe(val, protocol_version) + % self.typename) + return self.serialize_safe(val) class _SimpleParameterizedType(_ParameterizedType): - @classmethod - def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + def deserialize_safe(self, byts): + subtype, = self.subtype_instances length = 4 numelements = int32_unpack(byts[:length]) p = length result = [] - inner_proto = max(3, protocol_version) for _ in range(numelements): itemlen = int32_unpack(byts[p:p + length]) p += length @@ -825,23 +788,21 @@ def deserialize_safe(cls, byts, protocol_version): else: item = byts[p:p + itemlen] p += itemlen - result.append(subtype.from_binary(item, inner_proto)) - return cls.adapter(result) + result.append(subtype.from_binary(item)) + return self.adapter(result) - @classmethod - def serialize_safe(cls, items, protocol_version): + def serialize_safe(self, items): if isinstance(items, str): raise TypeError("Received a string for a type that expects a sequence") - subtype, = cls.subtypes + subtype, = self.subtype_instances buf = io.BytesIO() buf.write(int32_pack(len(items))) - inner_proto = max(3, protocol_version) for item in items: if item is None: buf.write(int32_pack(-1)) else: - itembytes = subtype.to_binary(item, inner_proto) + itembytes = subtype.to_binary(item) buf.write(int32_pack(len(itembytes))) buf.write(itembytes) return buf.getvalue() @@ -863,14 +824,12 @@ class MapType(_ParameterizedType): typename = 'map' num_subtypes = 2 - @classmethod - def deserialize_safe(cls, byts, protocol_version): - key_type, value_type = cls.subtypes + def deserialize_safe(self, byts): + key_type, value_type = self.subtype_instances length = 4 numelements = int32_unpack(byts[:length]) p = length - themap = util.OrderedMapSerializedKey(key_type, protocol_version) - inner_proto = max(3, protocol_version) + themap = util.OrderedMapSerializedKey(key_type) for _ in range(numelements): key_len = int32_unpack(byts[p:p + length]) p += length @@ -880,7 +839,7 @@ def deserialize_safe(cls, byts, protocol_version): else: keybytes = byts[p:p + key_len] p += key_len - key = key_type.from_binary(keybytes, inner_proto) + key = key_type.from_binary(keybytes) val_len = int32_unpack(byts[p:p + length]) p += length @@ -889,30 +848,28 @@ def deserialize_safe(cls, byts, protocol_version): else: valbytes = byts[p:p + val_len] p += val_len - val = value_type.from_binary(valbytes, inner_proto) + val = value_type.from_binary(valbytes) themap._insert_unchecked(key, keybytes, val) return themap - @classmethod - def serialize_safe(cls, themap, protocol_version): - key_type, value_type = cls.subtypes + def serialize_safe(self, themap): + key_type, value_type = self.subtype_instances buf = io.BytesIO() buf.write(int32_pack(len(themap))) try: items = themap.items() except AttributeError: raise TypeError("Got a non-map object for a map value") - inner_proto = max(3, protocol_version) for key, val in items: if key is not None: - keybytes = key_type.to_binary(key, inner_proto) + keybytes = key_type.to_binary(key) buf.write(int32_pack(len(keybytes))) buf.write(keybytes) else: buf.write(int32_pack(-1)) if val is not None: - valbytes = value_type.to_binary(val, inner_proto) + valbytes = value_type.to_binary(val) buf.write(int32_pack(len(valbytes))) buf.write(valbytes) else: @@ -923,12 +880,10 @@ def serialize_safe(cls, themap, protocol_version): class TupleType(_ParameterizedType): typename = 'tuple' - @classmethod - def deserialize_safe(cls, byts, protocol_version): - proto_version = max(3, protocol_version) + def deserialize_safe(self, byts): p = 0 values = [] - for col_type in cls.subtypes: + for col_type in self.subtype_instances: if p == len(byts): break itemlen = int32_unpack(byts[p:p + 4]) @@ -940,25 +895,23 @@ def deserialize_safe(cls, byts, protocol_version): item = None # collections inside UDTs are always encoded with at least the # version 3 format - values.append(col_type.from_binary(item, proto_version)) + values.append(col_type.from_binary(item)) - if len(values) < len(cls.subtypes): - nones = [None] * (len(cls.subtypes) - len(values)) + if len(values) < len(self.subtypes): + nones = [None] * (len(self.subtypes) - len(values)) values = values + nones return tuple(values) - @classmethod - def serialize_safe(cls, val, protocol_version): - if len(val) > len(cls.subtypes): + def serialize_safe(self, val): + if len(val) > len(self.subtypes): raise ValueError("Expected %d items in a tuple, but got %d: %s" % - (len(cls.subtypes), len(val), val)) + (len(self.subtypes), len(val), val)) - proto_version = max(3, protocol_version) buf = io.BytesIO() - for item, subtype in zip(val, cls.subtypes): + for item, subtype in zip(val, self.subtype_instances): if item is not None: - packed_item = subtype.to_binary(item, proto_version) + packed_item = subtype.to_binary(item) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: @@ -1011,31 +964,28 @@ def apply_parameters(cls, subtypes, names): def cql_parameterized_type(cls): return "frozen<%s>" % (cls.typename,) - @classmethod - def deserialize_safe(cls, byts, protocol_version): - values = super(UserType, cls).deserialize_safe(byts, protocol_version) - if cls.mapped_class: - return cls.mapped_class(**dict(zip(cls.fieldnames, values))) - elif cls.tuple_type: - return cls.tuple_type(*values) + def deserialize_safe(self, byts): + values = super(UserType, self).deserialize_safe(byts) + if self.mapped_class: + return self.mapped_class(**dict(zip(self.fieldnames, values))) + elif self.tuple_type: + return self.tuple_type(*values) else: return tuple(values) - @classmethod - def serialize_safe(cls, val, protocol_version): - proto_version = max(3, protocol_version) + def serialize_safe(self, val): buf = io.BytesIO() - for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)): + for i, (fieldname, subtype) in enumerate(zip(self.fieldnames, self.subtype_instances)): # first treat as a tuple, else by custom type try: item = val[i] except TypeError: item = getattr(val, fieldname, None) if item is None and not hasattr(val, fieldname): - log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") + log.warning(f"field {fieldname} is part of the UDT {self.typename} but is not present in the value {val}") if item is not None: - packed_item = subtype.to_binary(item, proto_version) + packed_item = subtype.to_binary(item) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: @@ -1084,10 +1034,9 @@ def cql_parameterized_type(cls): typestring = cls.cass_parameterized_type(full=True) return "'%s'" % (typestring,) - @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(self, byts): result = [] - for subtype in cls.subtypes: + for subtype in self.subtype_instances: if not byts: # CompositeType can have missing elements at the end break @@ -1097,7 +1046,7 @@ def deserialize_safe(cls, byts, protocol_version): # skip element length, element, and the EOC (one byte) byts = byts[2 + element_length + 1:] - result.append(subtype.from_binary(element, protocol_version)) + result.append(subtype.from_binary(element)) return tuple(result) @@ -1124,30 +1073,26 @@ class ReversedType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.ReversedType" num_subtypes = 1 - @classmethod - def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes - return subtype.from_binary(byts, protocol_version) + def deserialize_safe(self, byts): + subtype, = self.subtype_instances + return subtype.from_binary(byts) - @classmethod - def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes - return subtype.to_binary(val, protocol_version) + def serialize_safe(self, val): + subtype, = self.subtype_instances + return subtype.to_binary(val) class FrozenType(_ParameterizedType): typename = "frozen" num_subtypes = 1 - @classmethod - def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes - return subtype.from_binary(byts, protocol_version) + def deserialize_safe(self, byts): + subtype, = self.subtype_instances + return subtype.from_binary(byts) - @classmethod - def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes - return subtype.to_binary(val, protocol_version) + def serialize_safe(self, val): + subtype, = self.subtype_instances + return subtype.to_binary(val) def is_counter_type(t): @@ -1181,12 +1126,10 @@ class PointType(CassandraType): _type = struct.pack('[[]] type_ = int8_unpack(byts[0:1]) @@ -1349,12 +1287,12 @@ def deserialize(cls, byts, protocol_version): if time0 is not None: date_range_bound0 = util.DateRangeBound( time0, - cls._decode_precision(precision0) + self._decode_precision(precision0) ) if time1 is not None: date_range_bound1 = util.DateRangeBound( time1, - cls._decode_precision(precision1) + self._decode_precision(precision1) ) if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE): @@ -1375,8 +1313,7 @@ def deserialize(cls, byts, protocol_version): return util.DateRange(value=util.OPEN_BOUND) raise ValueError('Could not deserialize %r' % (byts,)) - @classmethod - def serialize(cls, v, protocol_version): + def serialize(self, v): buf = io.BytesIO() bound_kind, bounds = None, () @@ -1385,7 +1322,7 @@ def serialize(cls, v, protocol_version): except AttributeError: raise ValueError( '%s.serialize expects an object with a value attribute; got' - '%r' % (cls.__name__, v) + '%r' % (self.__class__.__name__, v) ) if value is None: @@ -1394,7 +1331,7 @@ def serialize(cls, v, protocol_version): except AttributeError: raise ValueError( '%s.serialize expects an object with lower_bound and ' - 'upper_bound attributes; got %r' % (cls.__name__, v) + 'upper_bound attributes; got %r' % (self.__class__.__name__, v) ) if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND: bound_kind = BoundKind.BOTH_OPEN_RANGE @@ -1422,7 +1359,7 @@ def serialize(cls, v, protocol_version): buf.write(int8_pack(BoundKind.to_int(bound_kind))) for bound in bounds: buf.write(int64_pack(bound.milliseconds)) - buf.write(int8_pack(cls._encode_precision(bound.precision))) + buf.write(int8_pack(self._encode_precision(bound.precision))) return buf.getvalue() @@ -1431,6 +1368,10 @@ class VectorType(_CassandraType): vector_size = 0 subtype = None + def __init__(self, protocol_version=None): + super(VectorType, self).__init__(protocol_version) + self.subtype_instance = self.subtype(protocol_version) + @classmethod def serial_size(cls): serialized_size = cls.subtype.serial_size() @@ -1443,25 +1384,24 @@ def apply_parameters(cls, params, names): vsize = params[1] return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) - @classmethod - def deserialize(cls, byts, protocol_version): - serialized_size = cls.subtype.serial_size() + def deserialize(self, byts): + serialized_size = self.subtype.serial_size() if serialized_size is not None: - expected_byte_size = serialized_size * cls.vector_size + expected_byte_size = serialized_size * self.vector_size if len(byts) != expected_byte_size: raise ValueError( "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ - .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) - indexes = (serialized_size * x for x in range(0, cls.vector_size)) - return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + .format(self.subtype.typename, self.vector_size, expected_byte_size, len(byts))) + indexes = (serialized_size * x for x in range(0, self.vector_size)) + return [self.subtype_instance.deserialize(byts[idx:idx + serialized_size]) for idx in indexes] idx = 0 rv = [] - while (len(rv) < cls.vector_size): + while (len(rv) < self.vector_size): try: size, bytes_read = uvint_unpack(byts[idx:]) idx += bytes_read - rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + rv.append(self.subtype_instance.deserialize(byts[idx:idx + size])) idx += size except: raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ @@ -1472,18 +1412,17 @@ def deserialize(cls, byts, protocol_version): raise ValueError("Additional bytes remaining after vector deserialization completed") return rv - @classmethod - def serialize(cls, v, protocol_version): + def serialize(self, v): v_length = len(v) - if cls.vector_size != v_length: + if self.vector_size != v_length: raise ValueError( "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\ - .format(cls.vector_size, cls.subtype.typename, v_length)) + .format(self.vector_size, self.subtype.typename, v_length)) - serialized_size = cls.subtype.serial_size() + serialized_size = self.subtype.serial_size() buf = io.BytesIO() for item in v: - item_bytes = cls.subtype.serialize(item, protocol_version) + item_bytes = self.subtype_instance.serialize(item) if serialized_size is None: buf.write(uvint_pack(len(item_bytes))) buf.write(item_bytes) diff --git a/cassandra/deserializers.pxd b/cassandra/deserializers.pxd index 7b307226ad..ccc373ad57 100644 --- a/cassandra/deserializers.pxd +++ b/cassandra/deserializers.pxd @@ -17,6 +17,7 @@ from cassandra.buffer cimport Buffer cdef class Deserializer: # The cqltypes._CassandraType corresponding to this deserializer cdef object cqltype + cdef int protocol_version # String may be empty, whereas other values may not be. # Other values may be NULL, in which case the integer length @@ -26,18 +27,17 @@ cdef class Deserializer: # paragraph 6) cdef bint empty_binary_ok - cdef deserialize(self, Buffer *buf, int protocol_version) - # cdef deserialize(self, CString byts, protocol_version) + cdef deserialize(self, Buffer *buf) + # cdef deserialize(self, CString byts) cdef inline object from_binary(Deserializer deserializer, - Buffer *buf, - int protocol_version): + Buffer *buf): if buf.size < 0: return None elif buf.size == 0 and not deserializer.empty_binary_ok: return _ret_empty(deserializer, buf.size) else: - return deserializer.deserialize(buf, protocol_version) + return deserializer.deserialize(buf) cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 7c256674b0..bd3f2879ca 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -32,16 +32,17 @@ from cassandra import util cdef class Deserializer: """Cython-based deserializer class for a cqltype""" - def __init__(self, cqltype): + def __init__(self, cqltype, protocol_version): self.cqltype = cqltype + self.protocol_version = protocol_version self.empty_binary_ok = cqltype.empty_binary_ok - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): raise NotImplementedError cdef class DesBytesType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return b"" return to_bytes(buf) @@ -50,14 +51,14 @@ cdef class DesBytesType(Deserializer): # It is switched in by simply overwriting DesBytesType: # deserializers.DesBytesType = deserializers.DesBytesTypeByteArray cdef class DesBytesTypeByteArray(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return bytearray() return bytearray(buf.ptr[:buf.size]) # TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html cdef class DesDecimalType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Buffer varint_buf slice_buffer(buf, &varint_buf, 4, buf.size - 4) @@ -68,56 +69,56 @@ cdef class DesDecimalType(Deserializer): cdef class DesUUIDType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return UUID(bytes=to_bytes(buf)) cdef class DesBooleanType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if unpack_num[int8_t](buf): return True return False cdef class DesByteType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int8_t](buf) cdef class DesAsciiType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return "" return to_bytes(buf).decode('ascii') cdef class DesFloatType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[float](buf) cdef class DesDoubleType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[double](buf) cdef class DesLongType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int64_t](buf) cdef class DesInt32Type(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int32_t](buf) cdef class DesIntegerType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return varint_unpack(buf) cdef class DesInetAddressType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef bytes byts = to_bytes(buf) # TODO: optimize inet_ntop, inet_ntoa @@ -134,7 +135,7 @@ cdef class DesCounterColumnType(DesLongType): cdef class DesDateType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef double timestamp = unpack_num[int64_t](buf) / 1000.0 return datetime_from_timestamp(timestamp) @@ -144,7 +145,7 @@ cdef class TimestampType(DesDateType): cdef class TimeUUIDType(DesDateType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return UUID(bytes=to_bytes(buf)) @@ -154,23 +155,23 @@ cdef class TimeUUIDType(DesDateType): EPOCH_OFFSET_DAYS = 2 ** 31 cdef class DesSimpleDateType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): days = unpack_num[uint32_t](buf) - EPOCH_OFFSET_DAYS return util.Date(days) cdef class DesShortType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int16_t](buf) cdef class DesTimeType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return util.Time(unpack_num[int64_t](buf)) cdef class DesUTF8Type(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return "" cdef val = to_bytes(buf) @@ -187,19 +188,19 @@ cdef class _DesParameterizedType(Deserializer): cdef Deserializer[::1] deserializers cdef Py_ssize_t subtypes_len - def __init__(self, cqltype): - super().__init__(cqltype) + def __init__(self, cqltype, protocol_version): + super().__init__(cqltype, protocol_version) self.subtypes = cqltype.subtypes - self.deserializers = make_deserializers(cqltype.subtypes) + self.deserializers = make_deserializers(cqltype.subtypes, protocol_version) self.subtypes_len = len(self.subtypes) cdef class _DesSingleParamType(_DesParameterizedType): cdef Deserializer deserializer - def __init__(self, cqltype): + def __init__(self, cqltype, protocol_version): assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes - super().__init__(cqltype) + super().__init__(cqltype, protocol_version) self.deserializer = self.deserializers[0] @@ -207,78 +208,58 @@ cdef class _DesSingleParamType(_DesParameterizedType): # List and set deserialization cdef class DesListType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): - cdef uint16_t v2_and_below = 2 - cdef int32_t v3_and_above = 3 - - if protocol_version >= 3: - result = _deserialize_list_or_set[int32_t]( - v3_and_above, buf, protocol_version, self.deserializer) - else: - result = _deserialize_list_or_set[uint16_t]( - v2_and_below, buf, protocol_version, self.deserializer) + cdef deserialize(self, Buffer *buf): + result = _deserialize_list_or_set(buf, self.deserializer) return result cdef class DesSetType(DesListType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return util.sortedset(DesListType.deserialize(self, buf, protocol_version)) + cdef deserialize(self, Buffer *buf): + return util.sortedset(DesListType.deserialize(self, buf)) -ctypedef fused itemlen_t: - uint16_t # protocol <= v2 - int32_t # protocol >= v3 - -cdef list _deserialize_list_or_set(itemlen_t dummy_version, - Buffer *buf, int protocol_version, +cdef list _deserialize_list_or_set(Buffer *buf, Deserializer deserializer): """ Deserialize a list or set. - - The 'dummy' parameter is needed to make fused types work, so that - we can specialize on the protocol version. """ cdef Buffer itemlen_buf cdef Buffer elem_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) - protocol_version = max(3, protocol_version) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) for _ in range(numelements): - subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version) - result.append(from_binary(deserializer, &elem_buf, protocol_version)) + subelem(buf, &elem_buf, &offset) + result.append(from_binary(deserializer, &elem_buf)) return result cdef inline int subelem( - Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1: + Buffer *buf, Buffer *elem_buf, int* offset) except -1: """ Read the next element from the buffer: first read the size (in bytes) of the element, then fill elem_buf with a newly sliced buffer of this size (and the right offset). """ - cdef itemlen_t elemlen + cdef int32_t elemlen - _unpack_len[itemlen_t](buf, offset[0], &elemlen) - offset[0] += sizeof(itemlen_t) + _unpack_len(buf, offset[0], &elemlen) + offset[0] += sizeof(int32_t) slice_buffer(buf, elem_buf, offset[0], elemlen) offset[0] += elemlen return 0 -cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1: +cdef int _unpack_len(Buffer *buf, int offset, int32_t *output) except -1: cdef Buffer itemlen_buf - slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t)) + slice_buffer(buf, &itemlen_buf, offset, sizeof(int32_t)) - if itemlen_t is uint16_t: - output[0] = unpack_num[uint16_t](&itemlen_buf) - else: - output[0] = unpack_num[int32_t](&itemlen_buf) + output[0] = unpack_num[int32_t](&itemlen_buf) return 0 @@ -289,50 +270,40 @@ cdef class DesMapType(_DesParameterizedType): cdef Deserializer key_deserializer, val_deserializer - def __init__(self, cqltype): - super().__init__(cqltype) + def __init__(self, cqltype, protocol_version): + super().__init__(cqltype, protocol_version) self.key_deserializer = self.deserializers[0] self.val_deserializer = self.deserializers[1] - cdef deserialize(self, Buffer *buf, int protocol_version): - cdef uint16_t v2_and_below = 0 - cdef int32_t v3_and_above = 0 + cdef deserialize(self, Buffer *buf): key_type, val_type = self.cqltype.subtypes - if protocol_version >= 3: - result = _deserialize_map[int32_t]( - v3_and_above, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) - else: - result = _deserialize_map[uint16_t]( - v2_and_below, buf, protocol_version, - self.key_deserializer, self.val_deserializer, - key_type, val_type) + result = _deserialize_map( + buf, + self.key_deserializer, self.val_deserializer, + key_type(self.key_deserializer.protocol_version), val_type) return result -cdef _deserialize_map(itemlen_t dummy_version, - Buffer *buf, int protocol_version, +cdef _deserialize_map(Buffer *buf, Deserializer key_deserializer, Deserializer val_deserializer, - key_type, val_type): + key_type_instance, val_type): cdef Buffer key_buf, val_buf cdef Buffer itemlen_buf - cdef itemlen_t numelements + cdef int32_t numelements cdef int offset cdef list result = [] - _unpack_len[itemlen_t](buf, 0, &numelements) - offset = sizeof(itemlen_t) - themap = util.OrderedMapSerializedKey(key_type, protocol_version) - protocol_version = max(3, protocol_version) + _unpack_len(buf, 0, &numelements) + offset = sizeof(int32_t) + themap = util.OrderedMapSerializedKey(key_type_instance) for _ in range(numelements): - subelem[itemlen_t](buf, &key_buf, &offset, dummy_version) - subelem[itemlen_t](buf, &val_buf, &offset, numelements) - key = from_binary(key_deserializer, &key_buf, protocol_version) - val = from_binary(val_deserializer, &val_buf, protocol_version) + subelem(buf, &key_buf, &offset) + subelem(buf, &val_buf, &offset) + key = from_binary(key_deserializer, &key_buf) + val = from_binary(val_deserializer, &val_buf) themap._insert_unchecked(key, to_bytes(&key_buf), val) return themap @@ -343,7 +314,7 @@ cdef class DesTupleType(_DesParameterizedType): # TODO: Use TupleRowParser to parse these tuples - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Py_ssize_t i, p cdef int32_t itemlen cdef tuple res = tuple_new(self.subtypes_len) @@ -351,10 +322,6 @@ cdef class DesTupleType(_DesParameterizedType): cdef Buffer itemlen_buf cdef Deserializer deserializer - # collections inside UDTs are always encoded with at least the - # version 3 format - protocol_version = max(3, protocol_version) - p = 0 values = [] for i in range(self.subtypes_len): @@ -368,7 +335,7 @@ cdef class DesTupleType(_DesParameterizedType): p += itemlen deserializer = self.deserializers[i] - item = from_binary(deserializer, &item_buf, protocol_version) + item = from_binary(deserializer, &item_buf) tuple_set(res, i, item) @@ -376,9 +343,9 @@ cdef class DesTupleType(_DesParameterizedType): cdef class DesUserType(DesTupleType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): typ = self.cqltype - values = DesTupleType.deserialize(self, buf, protocol_version) + values = DesTupleType.deserialize(self, buf) if typ.mapped_class: return typ.mapped_class(**dict(zip(typ.fieldnames, values))) elif typ.tuple_type: @@ -388,7 +355,7 @@ cdef class DesUserType(DesTupleType): cdef class DesCompositeType(_DesParameterizedType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Py_ssize_t i, idx, start cdef Buffer elem_buf cdef int16_t element_length @@ -413,7 +380,7 @@ cdef class DesCompositeType(_DesParameterizedType): slice_buffer(buf, &elem_buf, 2, element_length) deserializer = self.deserializers[i] - item = from_binary(deserializer, &elem_buf, protocol_version) + item = from_binary(deserializer, &elem_buf) tuple_set(res, i, item) # skip element length, element, and the EOC (one byte) @@ -427,13 +394,13 @@ DesDynamicCompositeType = DesCompositeType cdef class DesReversedType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return from_binary(self.deserializer, buf, protocol_version) + cdef deserialize(self, Buffer *buf): + return from_binary(self.deserializer, buf) cdef class DesFrozenType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return from_binary(self.deserializer, buf, protocol_version) + cdef deserialize(self, Buffer *buf): + return from_binary(self.deserializer, buf) #-------------------------------------------------------------------------- @@ -456,9 +423,14 @@ cdef class GenericDeserializer(Deserializer): """ Wrap a generic datatype for deserialization """ + cdef object cqltype_instance + + def __init__(self, cqltype, protocol_version): + super().__init__(cqltype, protocol_version) + self.cqltype_instance = cqltype(protocol_version) - cdef deserialize(self, Buffer *buf, int protocol_version): - return self.cqltype.deserialize(to_bytes(buf), protocol_version) + cdef deserialize(self, Buffer *buf): + return self.cqltype_instance.deserialize(to_bytes(buf)) def __repr__(self): return "GenericDeserializer(%s)" % (self.cqltype,) @@ -466,15 +438,15 @@ cdef class GenericDeserializer(Deserializer): #-------------------------------------------------------------------------- # Helper utilities -def make_deserializers(cqltypes): +def make_deserializers(cqltypes, protocol_version): """Create an array of Deserializers for each given cqltype in cqltypes""" cdef Deserializer[::1] deserializers - return obj_array([find_deserializer(ct) for ct in cqltypes]) + return obj_array([find_deserializer(ct, protocol_version) for ct in cqltypes]) cdef dict classes = globals() -cpdef Deserializer find_deserializer(cqltype): +cpdef Deserializer find_deserializer(cqltype, protocol_version): """Find a deserializer for a cqltype""" name = 'Des' + cqltype.__name__ @@ -503,7 +475,7 @@ cpdef Deserializer find_deserializer(cqltype): else: cls = GenericDeserializer - return cls(cqltype) + return cls(cqltype, protocol_version) def obj_array(list objs): diff --git a/cassandra/metadata.py b/cassandra/metadata.py index bbfaf2605b..9e55bd0bc5 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2204,7 +2204,7 @@ def _build_aggregate(cls, aggregate_row): cass_state_type = types.lookup_casstype(aggregate_row['state_type']) initial_condition = aggregate_row['initcond'] if initial_condition is not None: - initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) + initial_condition = _encoder.cql_encode_all_types(cass_state_type(3).deserialize(initial_condition)) state_type = _cql_from_cass_type(cass_state_type) return_type = cls._schema_type_to_cql(aggregate_row['return_type']) return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], @@ -3475,7 +3475,7 @@ def group_keys_by_replica(session, keyspace, table, keys): distance = cluster._default_load_balancing_policy.distance for key in keys: - serialized_key = [serializer.serialize(pk, cluster.protocol_version) + serialized_key = [serializer(cluster.protocol_version).serialize(pk) for serializer, pk in zip(serializers, key)] if len(serialized_key) == 1: routing_key = serialized_key[0] diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 030c2c65c7..3ce3809285 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -156,7 +156,7 @@ cdef inline int unpack_row( if arr.is_object: deserializer = desc.deserializers[i] - val = from_binary(deserializer, &buf, desc.protocol_version) + val = from_binary(deserializer, &buf) Py_INCREF(val) ( arr.buf_ptr)[0] = val elif buf.size >= 0: diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index cf43771dd7..5fec91bc2b 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -80,10 +80,10 @@ cdef class TupleRowParser(RowParser): col_type = ce_policy.column_type(coldesc) decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size) - deserializer = find_deserializer(ce_policy.column_type(coldesc)) - val = from_binary(deserializer, &newbuf, desc.protocol_version) + deserializer = find_deserializer(ce_policy.column_type(coldesc), deserializer.protocol_version) + val = from_binary(deserializer, &newbuf) else: - val = from_binary(deserializer, &buf, desc.protocol_version) + val = from_binary(deserializer, &buf) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], desc.coltypes[i].cql_parameterized_type(), diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd index 27dc368b07..8fcee2c45e 100644 --- a/cassandra/parsing.pxd +++ b/cassandra/parsing.pxd @@ -21,7 +21,6 @@ cdef class ParseDesc: cdef public object column_encryption_policy cdef public list coldescs cdef Deserializer[::1] deserializers - cdef public int protocol_version cdef Py_ssize_t rowsize cdef class ColumnParser: diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx index 954767d227..c67f18fd88 100644 --- a/cassandra/parsing.pyx +++ b/cassandra/parsing.pyx @@ -19,13 +19,12 @@ Module containing the definitions and declarations (parsing.pxd) for parsers. cdef class ParseDesc: """Description of what structure to parse""" - def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version): + def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers): self.colnames = colnames self.coltypes = coltypes self.column_encryption_policy = column_encryption_policy self.coldescs = coldescs self.deserializers = deserializers - self.protocol_version = protocol_version self.rowsize = len(colnames) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e574965de8..eac91085c4 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -89,6 +89,7 @@ class _MessageType(object, metaclass=_RegisterMessageType): tracing = False custom_payload = None warnings = None + protocol_version = None def update_custom_payload(self, other): if other: @@ -126,13 +127,15 @@ def __init__(self, code, message, info): @classmethod def recv_body(cls, f, protocol_version, protocol_features, *args): code = read_int(f) - msg = read_string(f) + msg_text = read_string(f) if code == protocol_features.rate_limit_error: subcls = RateLimitReachedException else: subcls = error_classes.get(code, cls) - extra_info = subcls.recv_error_info(f, protocol_version) - return subcls(code=code, message=msg, info=extra_info) + msg = subcls(code=code, message=msg_text, info=None) + msg.protocol_version = protocol_version + msg.info = msg.recv_error_info(f) + return msg def summary_msg(self): msg = 'Error from server: code=%04x [%s] message="%s"' \ @@ -143,8 +146,7 @@ def __str__(self): return '<%s>' % self.summary_msg() __repr__ = __str__ - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): pass def to_exception(self): @@ -192,8 +194,7 @@ class UnavailableErrorMessage(RequestExecutionException): summary = 'Unavailable exception' error_code = 0x1000 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'consistency': read_consistency_level(f), 'required_replicas': read_int(f), @@ -223,8 +224,7 @@ class WriteTimeoutErrorMessage(RequestExecutionException): summary = "Coordinator node timed out waiting for replica nodes' responses" error_code = 0x1100 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'consistency': read_consistency_level(f), 'received_responses': read_int(f), @@ -240,8 +240,7 @@ class ReadTimeoutErrorMessage(RequestExecutionException): summary = "Coordinator node timed out waiting for replica nodes' responses" error_code = 0x1200 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'consistency': read_consistency_level(f), 'received_responses': read_int(f), @@ -257,13 +256,12 @@ class ReadFailureMessage(RequestExecutionException): summary = "Replica(s) failed to execute read" error_code = 0x1300 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): consistency = read_consistency_level(f) received_responses = read_int(f) required_responses = read_int(f) - if ProtocolVersion.uses_error_code_map(protocol_version): + if ProtocolVersion.uses_error_code_map(self.protocol_version): error_code_map = read_error_code_map(f) failures = len(error_code_map) else: @@ -289,8 +287,7 @@ class FunctionFailureMessage(RequestExecutionException): summary = "User Defined Function failure" error_code = 0x1400 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'keyspace': read_string(f), 'function': read_string(f), @@ -305,13 +302,12 @@ class WriteFailureMessage(RequestExecutionException): summary = "Replica(s) failed to execute write" error_code = 0x1500 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): consistency = read_consistency_level(f) received_responses = read_int(f) required_responses = read_int(f) - if ProtocolVersion.uses_error_code_map(protocol_version): + if ProtocolVersion.uses_error_code_map(self.protocol_version): error_code_map = read_error_code_map(f) failures = len(error_code_map) else: @@ -368,8 +364,7 @@ class PreparedQueryNotFound(RequestValidationException): summary = 'Matching prepared statement not found on this node' error_code = 0x2500 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): # return the query ID return read_binary_string(f) @@ -378,8 +373,7 @@ class AlreadyExistsException(ConfigurationException): summary = 'Item already exists' error_code = 0x2400 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'keyspace': read_string(f), 'table': read_string(f), @@ -392,8 +386,7 @@ class RateLimitReachedException(ConfigurationException): summary= 'Rate limit was exceeded for a partition affected by the request' error_code = 0x4321 - @staticmethod - def recv_error_info(f, protocol_version): + def recv_error_info(self, f): return { 'op_type': OperationType(read_byte(f)), 'rejected_by_coordinator': read_byte(f) != 0 @@ -601,7 +594,7 @@ def _write_query_params(self, f, protocol_version): if self.keyspace is not None: write_string(f, self.keyspace) - def _write_paging_options(self, f, paging_options, protocol_version): + def _write_paging_options(self, f, paging_options): write_int(f, paging_options.max_pages) write_int(f, paging_options.max_pages_per_second) @@ -691,17 +684,17 @@ class ResultMessage(_MessageType): def __init__(self, kind): self.kind = kind - def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): + def recv(self, f, protocol_features, user_type_map, result_metadata, column_encryption_policy): if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: - self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + self.recv_results_rows(f, self.protocol_version, user_type_map, result_metadata, column_encryption_policy) elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: - self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map) + self.recv_results_prepared(f, protocol_features, user_type_map) elif self.kind == RESULT_KIND_SCHEMA_CHANGE: - self.recv_results_schema_change(f, protocol_version) + self.recv_results_schema_change(f) else: raise DriverException("Unknown RESULT kind: %d" % self.kind) @@ -709,7 +702,8 @@ def recv(self, f, protocol_version, protocol_features, user_type_map, result_met def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): kind = read_int(f) msg = cls(kind) - msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) + msg.protocol_version = protocol_version + msg.recv(f, protocol_features, user_type_map, result_metadata, column_encryption_policy) return msg def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): @@ -725,7 +719,7 @@ def decode_val(val, col_md, col_desc): uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) + return col_type(protocol_version).from_binary(raw_bytes) def decode_row(row): return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) @@ -742,13 +736,13 @@ def decode_row(row): col_md[3].cql_parameterized_type(), str(e))) - def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): + def recv_results_prepared(self, f, protocol_features, user_type_map): self.query_id = read_binary_string(f) - if ProtocolVersion.uses_prepared_metadata(protocol_version): + if ProtocolVersion.uses_prepared_metadata(self.protocol_version): self.result_metadata_id = read_binary_string(f) else: self.result_metadata_id = None - self.recv_prepared_metadata(f, protocol_version, protocol_features, user_type_map) + self.recv_prepared_metadata(f, protocol_features, user_type_map) def recv_results_metadata(self, f, user_type_map): flags = read_int(f) @@ -786,12 +780,12 @@ def recv_results_metadata(self, f, user_type_map): self.column_metadata = column_metadata - def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_type_map): + def recv_prepared_metadata(self, f, protocol_features, user_type_map): flags = read_int(f) self.is_lwt = protocol_features.lwt_info.get_lwt_flag(flags) if protocol_features.lwt_info is not None else False colcount = read_int(f) pk_indexes = None - if protocol_version >= 4: + if self.protocol_version >= 4: num_pk_indexes = read_int(f) pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] @@ -816,8 +810,8 @@ def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_ty self.bind_metadata = bind_metadata self.pk_indexes = pk_indexes - def recv_results_schema_change(self, f, protocol_version): - self.schema_change_event = EventMessage.recv_schema_change(f, protocol_version) + def recv_results_schema_change(self, f): + self.schema_change_event = EventMessage.recv_schema_change(f) @classmethod def read_type(cls, f, user_type_map): @@ -985,11 +979,13 @@ def recv_body(cls, f, protocol_version, *args): event_type = read_string(f).upper() if event_type in known_event_types: read_method = getattr(cls, 'recv_' + event_type.lower()) - return cls(event_type=event_type, event_args=read_method(f, protocol_version)) + msg = cls(event_type=event_type, event_args=read_method(f)) + msg.protocol_version = protocol_version + return msg raise NotSupportedError('Unknown event type %r' % event_type) @classmethod - def recv_client_routes_change(cls, f, protocol_version): + def recv_client_routes_change(cls, f): # "UPDATE_NODES" change_type = read_string(f) connection_ids = read_stringlist(f) @@ -997,21 +993,21 @@ def recv_client_routes_change(cls, f, protocol_version): return dict(change_type=change_type, connection_ids=connection_ids, host_ids=host_ids) @classmethod - def recv_topology_change(cls, f, protocol_version): + def recv_topology_change(cls, f): # "NEW_NODE" or "REMOVED_NODE" change_type = read_string(f) address = read_inet(f) return dict(change_type=change_type, address=address) @classmethod - def recv_status_change(cls, f, protocol_version): + def recv_status_change(cls, f): # "UP" or "DOWN" change_type = read_string(f) address = read_inet(f) return dict(change_type=change_type, address=address) @classmethod - def recv_schema_change(cls, f, protocol_version): + def recv_schema_change(cls, f): # "CREATED", "DROPPED", or "UPDATED" change_type = read_string(f) target = read_string(f) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..bb379bc3bb 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -649,7 +649,7 @@ def bind(self, values): col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) uses_ce = ce_policy and ce_policy.contains_column(col_desc) col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) + col_bytes = col_type(proto_version).serialize(value) if uses_ce: col_bytes = ce_policy.encrypt(col_desc, col_bytes) self.values.append(col_bytes) diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..c2b1e95873 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -34,7 +34,7 @@ def make_recv_results_rows(ColumnParser colparser): desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, [ColDesc(md[0], md[1], md[2]) for md in column_metadata], - make_deserializers(self.column_types), protocol_version) + make_deserializers(self.column_types, protocol_version)) reader = BytesIOReader(f.read()) try: self.parsed_rows = colparser.parse_rows(reader, desc) diff --git a/cassandra/util.py b/cassandra/util.py index 12886d05ab..712254de12 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -767,17 +767,16 @@ def _serialize_key(self, key): class OrderedMapSerializedKey(OrderedMap): - def __init__(self, cass_type, protocol_version): + def __init__(self, cass_type_instance): super(OrderedMapSerializedKey, self).__init__() - self.cass_key_type = cass_type - self.protocol_version = protocol_version + self.cass_key_type = cass_type_instance def _insert_unchecked(self, key, flat_key, value): self._items.append((key, value)) self._index[flat_key] = len(self._items) - 1 def _serialize_key(self, key): - return self.cass_key_type.serialize(key, self.protocol_version) + return self.cass_key_type.serialize(key) @total_ordering diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py index fd8edde14c..ec29129665 100644 --- a/tests/integration/long/test_loadbalancingpolicies.py +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -633,7 +633,7 @@ def test_token_aware_with_transient_replication(self): query = session.prepare("SELECT * FROM test_tr.users WHERE id = ?") for i in range(100): f = session.execute_async(query, (i,), trace=True) - full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type.serialize(i, cluster.protocol_version)) + full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type(cluster.protocol_version).serialize(i)) if h.datacenter == 'dc1'] assert len(full_dc1_replicas) == 2 diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 239f7e7336..947879410d 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -241,7 +241,7 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, self.column_types = [c[3] for c in column_metadata] self.checked_rev_row_set.update(self.column_types) self.parsed_rows = [ - tuple(ctype.from_binary(val, protocol_version) + tuple(ctype(protocol_version).from_binary(val) for ctype, val in zip(self.column_types, row)) for row in rows] diff --git a/tests/unit/advanced/test_geometry.py b/tests/unit/advanced/test_geometry.py index 1927b51da7..b7807b6740 100644 --- a/tests/unit/advanced/test_geometry.py +++ b/tests/unit/advanced/test_geometry.py @@ -35,13 +35,14 @@ class GeoTypes(unittest.TestCase): def test_marshal_platform(self): for proto_ver in protocol_versions: for geo in self.samples: - cql_type = lookup_casstype(geo.__class__.__name__ + 'Type') - assert cql_type.from_binary(cql_type.to_binary(geo, proto_ver), proto_ver) == geo + cql_type_class = lookup_casstype(geo.__class__.__name__ + 'Type') + cql_type = cql_type_class(proto_ver) + assert cql_type.from_binary(cql_type.to_binary(geo)) == geo def _verify_both_endian(self, typ, body_fmt, params, expected): for proto_ver in protocol_versions: - assert typ.from_binary(struct.pack(">BI" + body_fmt, wkb_be, *params), proto_ver) == expected - assert typ.from_binary(struct.pack("BI" + body_fmt, wkb_be, *params)) == expected + assert typ(proto_ver).from_binary(struct.pack(" 0 diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index e4b415ac69..b08f4b221b 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -75,7 +75,7 @@ (b'', 'MapType(AsciiType, BooleanType)', None), (b'', 'ListType(FloatType)', None), (b'', 'SetType(LongType)', None), - (b'\x00\x00\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 3)), + (b'\x00\x00\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType(3))), (b'\x00\x00\x00\x00', 'ListType(FloatType)', []), (b'\x00\x00\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x00\x00\x01\x00\x00\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), @@ -88,7 +88,7 @@ (b'\x80\x00', 'ShortType', -32768) ) -ordered_map_value = OrderedMapSerializedKey(UTF8Type, 3) +ordered_map_value = OrderedMapSerializedKey(UTF8Type(3)) ordered_map_value._insert(u'\u307fbob', 199) ordered_map_value._insert(u'', -1) ordered_map_value._insert(u'\\', 0) @@ -111,20 +111,20 @@ class UnmarshalTest(unittest.TestCase): def test_unmarshalling(self): for serializedval, valtype, nativeval in marshalled_value_pairs: unmarshaller = lookup_casstype(valtype) - whatwegot = unmarshaller.from_binary(serializedval, 3) + whatwegot = unmarshaller(3).from_binary(serializedval) assert whatwegot == nativeval, 'Unmarshaller for %s (%s) failed: unmarshal(%r) got %r instead of %r' % (valtype, unmarshaller, serializedval, whatwegot, nativeval) assert type(whatwegot) == type(nativeval), 'Unmarshaller for %s (%s) gave wrong type (%s instead of %s)' % (valtype, unmarshaller, type(whatwegot), type(nativeval)) def test_marshalling(self): for serializedval, valtype, nativeval in marshalled_value_pairs: marshaller = lookup_casstype(valtype) - whatwegot = marshaller.to_binary(nativeval, 3) + whatwegot = marshaller(3).to_binary(nativeval) assert whatwegot == serializedval, 'Marshaller for %s (%s) failed: marshal(%r) got %r instead of %r' % (valtype, marshaller, nativeval, whatwegot, serializedval) assert type(whatwegot) == type(serializedval), 'Marshaller for %s (%s) gave wrong type (%s instead of %s)' % (valtype, marshaller, type(whatwegot), type(serializedval)) def test_date(self): # separate test because it will deserialize as datetime - assert DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 3), 3) == datetime(2015, 11, 2) + assert DateType(3).from_binary(DateType(3).to_binary(date(2015, 11, 2))) == datetime(2015, 11, 2) def test_decimal(self): # testing implicit numeric conversion @@ -132,6 +132,7 @@ def test_decimal(self): converted_types = (10001, (0, (1, 0, 0, 0, 0, 1), -3), 100.1, -87.629798) for proto_ver in range(3, ProtocolVersion.MAX_SUPPORTED + 1): + decimal_type = DecimalType(proto_ver) for n in converted_types: expected = Decimal(n) - assert DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver) == expected + assert decimal_type.from_binary(decimal_type.to_binary(n)) == expected diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py index 156bbd5f30..a52f42290e 100644 --- a/tests/unit/test_orderedmap.py +++ b/tests/unit/test_orderedmap.py @@ -167,17 +167,18 @@ def test_delitem(self): class OrderedMapSerializedKeyTest(unittest.TestCase): def test_init(self): - om = OrderedMapSerializedKey(UTF8Type, 3) + om = OrderedMapSerializedKey(UTF8Type(3)) assert om == {} def test_normalized_lookup(self): - key_type = lookup_casstype('MapType(UTF8Type, Int32Type)') + key_type_class = lookup_casstype('MapType(UTF8Type, Int32Type)') protocol_version = 3 - om = OrderedMapSerializedKey(key_type, protocol_version) + key_type = key_type_class(protocol_version) + om = OrderedMapSerializedKey(key_type) key_ascii = {'one': 1} key_unicode = {u'two': 2} - om._insert_unchecked(key_ascii, key_type.serialize(key_ascii, protocol_version), object()) - om._insert_unchecked(key_unicode, key_type.serialize(key_unicode, protocol_version), object()) + om._insert_unchecked(key_ascii, key_type.serialize(key_ascii), object()) + om._insert_unchecked(key_unicode, key_type.serialize(key_unicode), object()) # type lookup is normalized by key_type # PYTHON-231 diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..824332ead4 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -135,7 +135,7 @@ def test_keyspace_flag_raises_before_v5(self): io = Mock(name='io') with pytest.raises(UnsupportedOperation, match='Keyspaces.*set'): - keyspace_message.send_body(io, protocol_version=4) + keyspace_message.send_body(io, 4) io.assert_not_called() def test_keyspace_written_with_length(self): @@ -147,9 +147,8 @@ def test_keyspace_written_with_length(self): (b'\x00\x00\x00\x80',), # options w/ keyspace flag ] - QueryMessage('a', consistency_level=3, keyspace='ks').send_body( - io, protocol_version=5 - ) + msg = QueryMessage('a', consistency_level=3, keyspace='ks') + msg.send_body(io, 5) self._check_calls(io, base_expected + [ (b'\x00\x02',), # length of keyspace string (b'ks',), @@ -157,9 +156,8 @@ def test_keyspace_written_with_length(self): io.reset_mock() - QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( - io, protocol_version=5 - ) + msg = QueryMessage('a', consistency_level=3, keyspace='keyspace') + msg.send_body(io, 5) self._check_calls(io, base_expected + [ (b'\x00\x08',), # length of keyspace string (b'keyspace',), @@ -177,7 +175,7 @@ def test_batch_message_with_keyspace(self): consistency_level=3, keyspace='ks' ) - batch.send_body(io, protocol_version=5) + batch.send_body(io, 5) self._check_calls(io, ((b'\x00',), (b'\x00\x03',), (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt a',), diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 3390f6dbd6..aa9f2cda58 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -218,28 +218,28 @@ def test_datetype(self): now_timestamp = now_time_seconds * 1e3 # same results serialized - assert DateType.serialize(now_datetime, 0) == DateType.serialize(now_timestamp, 0) + assert DateType(0).serialize(now_datetime) == DateType(0).serialize(now_timestamp) # deserialize # epoc expected = 0 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None) + assert DateType(0).deserialize(int64_pack(1000 * expected)) == datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None) # beyond 32b expected = 2 ** 33 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType(0).deserialize(int64_pack(1000 * expected)) == datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # less than epoc (PYTHON-119) expected = -770172256 - assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType(0).deserialize(int64_pack(1000 * expected)) == datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # work around rounding difference among Python versions (PYTHON-230) expected = 1424817268.274 - assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType(0).deserialize(int64_pack(int(1000 * expected))) == datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # Large date overflow (PYTHON-452) expected = 2177403010.123 - assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) + assert DateType(0).deserialize(int64_pack(int(1000 * expected))) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) def test_collection_null_support(self): """ @@ -254,10 +254,10 @@ def test_collection_null_support(self): int32_pack(4) + # size of item2 int32_pack(42) # item2 ) - assert [None, 42] == int_list.deserialize(value, 3) + assert [None, 42] == int_list(3).deserialize(value) set_list = SetType.apply_parameters([Int32Type]) - assert {None, 42} == set(set_list.deserialize(value, 3)) + assert {None, 42} == set(set_list(3).deserialize(value)) value = ( int32_pack(2) + # num items @@ -271,7 +271,7 @@ def test_collection_null_support(self): map_list = MapType.apply_parameters([Int32Type, Int32Type]) - assert [(42, None), (None, 42)] == map_list.deserialize(value, 3)._items # OrderedMapSerializedKey + assert [(42, None), (None, 42)] == map_list(3).deserialize(value)._items # OrderedMapSerializedKey def test_write_read_string(self): with tempfile.TemporaryFile() as f: @@ -340,11 +340,11 @@ def _round_trip_compare_fn(self, first, second): def _round_trip_test(self, data, ctype_str): ctype = parse_casstype_args(ctype_str) - data_bytes = ctype.serialize(data, 0) + data_bytes = ctype(0).serialize(data) serialized_size = ctype.subtype.serial_size() if serialized_size: assert serialized_size * len(data) == len(data_bytes) - result = ctype.deserialize(data_bytes, 0) + result = ctype(0).deserialize(data_bytes) assert len(data) == len(result) for idx in range(0,len(data)): self._round_trip_compare_fn(data[idx], result[idx]) @@ -464,50 +464,50 @@ def test_cql_parameterized_type(self): def test_serialization_fixed_size_too_small(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4"): - ctype.serialize([1.2, 3.4, 5.6, 7.8], 0) + ctype(0).serialize([1.2, 3.4, 5.6, 7.8]) def test_serialization_fixed_size_too_big(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type float and dimension 4, observed sequence of length 5"): - ctype.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) + ctype(0).serialize([1.2, 3.4, 5.6, 7.8, 9.10]) def test_serialization_variable_size_too_small(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4"): - ctype.serialize([1, 2, 3, 4], 0) + ctype(0).serialize([1, 2, 3, 4]) def test_serialization_variable_size_too_big(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5"): - ctype.serialize([1, 2, 3, 4, 5], 0) + ctype(0).serialize([1, 2, 3, 4, 5]) def test_deserialization_fixed_size_too_small(self): ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - ctype_four_bytes = ctype_four.serialize([1.2, 3.4, 5.6, 7.8], 0) + ctype_four_bytes = ctype_four(0).serialize([1.2, 3.4, 5.6, 7.8]) ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") with pytest.raises(ValueError, match="Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead"): - ctype_five.deserialize(ctype_four_bytes, 0) + ctype_five(0).deserialize(ctype_four_bytes) def test_deserialization_fixed_size_too_big(self): ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") - ctype_five_bytes = ctype_five.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) + ctype_five_bytes = ctype_five(0).serialize([1.2, 3.4, 5.6, 7.8, 9.10]) ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") with pytest.raises(ValueError, match="Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead"): - ctype_four.deserialize(ctype_five_bytes, 0) + ctype_four(0).deserialize(ctype_five_bytes) def test_deserialization_variable_size_too_small(self): ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") - ctype_four_bytes = ctype_four.serialize([1, 2, 3, 4], 0) + ctype_four_bytes = ctype_four(0).serialize([1, 2, 3, 4]) ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") with pytest.raises(ValueError, match="Error reading additional data during vector deserialization after successfully adding 4 elements"): - ctype_five.deserialize(ctype_four_bytes, 0) + ctype_five(0).deserialize(ctype_four_bytes) def test_deserialization_variable_size_too_big(self): ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") - ctype_five_bytes = ctype_five.serialize([1, 2, 3, 4, 5], 0) + ctype_five_bytes = ctype_five(0).serialize([1, 2, 3, 4, 5]) ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") with pytest.raises(ValueError, match="Additional bytes remaining after vector deserialization completed"): - ctype_four.deserialize(ctype_five_bytes, 0) + ctype_four(0).deserialize(ctype_five_bytes) ZERO = datetime.timedelta(0) @@ -575,7 +575,7 @@ def test_deserialize_single_value(self): serialized = (int8_pack(0) + int64_pack(self.timestamp) + int8_pack(3)) - assert DateRangeType.deserialize(serialized, 5) == util.DateRange(value=util.DateRangeBound( + assert DateRangeType(5).deserialize(serialized) == util.DateRange(value=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), precision='HOUR') ) @@ -586,7 +586,7 @@ def test_deserialize_closed_range(self): int8_pack(2) + int64_pack(self.timestamp) + int8_pack(6)) - assert DateRangeType.deserialize(serialized, 5) == util.DateRange( + assert DateRangeType(5).deserialize(serialized) == util.DateRange( lower_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 0, 0), precision='DAY' @@ -601,7 +601,7 @@ def test_deserialize_open_high(self): serialized = (int8_pack(2) + int64_pack(self.timestamp) + int8_pack(3)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 0), @@ -614,7 +614,7 @@ def test_deserialize_open_low(self): serialized = (int8_pack(3) + int64_pack(self.timestamp) + int8_pack(4)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.OPEN_BOUND, upper_bound=util.DateRangeBound( @@ -624,13 +624,13 @@ def test_deserialize_open_low(self): ) def test_deserialize_single_open(self): - assert util.DateRange(value=util.OPEN_BOUND) == DateRangeType.deserialize(int8_pack(5), 5) + assert util.DateRange(value=util.OPEN_BOUND) == DateRangeType(5).deserialize(int8_pack(5)) def test_serialize_single_value(self): serialized = (int8_pack(0) + int64_pack(self.timestamp) + int8_pack(5)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( value=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 42, 12), @@ -644,7 +644,7 @@ def test_serialize_closed_range(self): int8_pack(5) + int64_pack(self.timestamp) + int8_pack(0)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15, 42, 12), @@ -660,7 +660,7 @@ def test_serialize_open_high(self): serialized = (int8_pack(2) + int64_pack(self.timestamp) + int8_pack(2)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1), @@ -673,7 +673,7 @@ def test_serialize_open_low(self): serialized = (int8_pack(2) + int64_pack(self.timestamp) + int8_pack(3)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.DateRangeBound( value=datetime.datetime(2017, 2, 1, 15), @@ -684,40 +684,40 @@ def test_serialize_open_low(self): def test_deserialize_both_open(self): serialized = (int8_pack(4)) - deserialized = DateRangeType.deserialize(serialized, 5) + deserialized = DateRangeType(5).deserialize(serialized) assert deserialized == util.DateRange( lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND ) def test_serialize_single_open(self): - serialized = DateRangeType.serialize(util.DateRange( + serialized = DateRangeType(5).serialize(util.DateRange( value=util.OPEN_BOUND, - ), 5) + )) assert int8_pack(5) == serialized def test_serialize_both_open(self): - serialized = DateRangeType.serialize(util.DateRange( + serialized = DateRangeType(5).serialize(util.DateRange( lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND - ), 5) + )) assert int8_pack(4) == serialized def test_failure_to_serialize_no_value_object(self): with pytest.raises(ValueError): - DateRangeType.serialize(object(), 5) + DateRangeType(5).serialize(object()) def test_failure_to_serialize_no_bounds_object(self): class no_bounds_object(object): value = lower_bound = None with pytest.raises(ValueError): - DateRangeType.serialize(no_bounds_object, 5) + DateRangeType(5).serialize(no_bounds_object) def test_serialized_value_round_trip(self): vals = [b'\x01\x00\x00\x01%\xe9a\xf9\xd1\x06\x00\x00\x01v\xbb>o\xff\x00', b'\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00'] for serialized in vals: - assert serialized == DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) + assert serialized == DateRangeType(0).serialize(DateRangeType(0).deserialize(serialized)) def test_serialize_zero_datetime(self): """ @@ -731,10 +731,10 @@ def test_serialize_zero_datetime(self): @test_category data_types """ - DateRangeType.serialize(util.DateRange( + DateRangeType(5).serialize(util.DateRange( lower_bound=(datetime.datetime(1970, 1, 1), 'YEAR'), upper_bound=(datetime.datetime(1970, 1, 1), 'YEAR') - ), 5) + )) def test_deserialize_zero_datetime(self): """ @@ -748,11 +748,10 @@ def test_deserialize_zero_datetime(self): @test_category data_types """ - DateRangeType.deserialize( + DateRangeType(5).deserialize( (int8_pack(1) + int64_pack(0) + int8_pack(0) + - int64_pack(0) + int8_pack(0)), - 5 + int64_pack(0) + int8_pack(0)) )