|
26 | 26 |
|
27 | 27 |
|
28 | 28 | __all__ = ['DecodeError', 'TlvModel', 'ProcedureArgument', 'OffsetMarker', 'UintField', 'BoolField', |
29 | | - 'NameField', 'BytesField', 'ModelField', 'RepeatedField', 'IncludeBase', 'IncludeBaseError'] |
| 29 | + 'NameField', 'BytesField', 'ModelField', 'RepeatedField', 'IncludeBase', 'IncludeBaseError', |
| 30 | + 'MapField'] |
30 | 31 |
|
31 | 32 |
|
32 | 33 | class DecodeError(Exception): |
@@ -724,6 +725,8 @@ def asdict(self, dict_factory=dict): |
724 | 725 | result.append((field.name, field.__get__(self, None).asdict())) |
725 | 726 | elif isinstance(field, RepeatedField): |
726 | 727 | result.append((field.name, field.aslist(self))) |
| 728 | + elif isinstance(field, MapField): |
| 729 | + result.append((field.name, field.asdict(self))) |
727 | 730 | elif isinstance(field, BytesField): |
728 | 731 | val = field.__get__(self, None) |
729 | 732 | if isinstance(val, str): |
@@ -819,11 +822,25 @@ def parse(cls, wire: BinaryStr, markers: Optional[dict] = None, ignore_critical: |
819 | 822 | for j in range(field_pos, i): |
820 | 823 | ret._encoded_fields[j].skipping_process(markers, wire, offset_btl) |
821 | 824 | # Parse that field |
822 | | - val = ret._encoded_fields[i].parse_from(ret, markers, wire, offset, length, offset_btl) |
823 | | - ret._encoded_fields[i].__set__(ret, val) |
| 825 | + cur_field = ret._encoded_fields[i] |
| 826 | + val = cur_field.parse_from(ret, markers, wire, offset, length, offset_btl) |
| 827 | + cur_field.__set__(ret, val) |
824 | 828 | # Set next field |
825 | | - if isinstance(ret._encoded_fields[i], RepeatedField): |
| 829 | + if isinstance(cur_field, RepeatedField): |
826 | 830 | field_pos = i |
| 831 | + elif isinstance(cur_field, MapField): |
| 832 | + # Parse the value part for a map |
| 833 | + field_pos = i |
| 834 | + offset += length |
| 835 | + |
| 836 | + offset_btl = offset |
| 837 | + typ, size_typ = parse_tl_num(wire, offset) |
| 838 | + offset += size_typ |
| 839 | + length, size_len = parse_tl_num(wire, offset) |
| 840 | + offset += size_len |
| 841 | + |
| 842 | + val = cur_field.parse_value(ret, markers, wire, offset, length, offset_btl) |
| 843 | + cur_field.__set__(ret, val) |
827 | 844 | else: |
828 | 845 | field_pos = i + 1 |
829 | 846 | else: |
@@ -966,3 +983,87 @@ def aslist(self, instance): |
966 | 983 | else: |
967 | 984 | ret.append(x) |
968 | 985 | return ret |
| 986 | + |
| 987 | + |
| 988 | +class MapField(Field): |
| 989 | + r""" |
| 990 | + Field for an unordered string or int map of a specific type. |
| 991 | + All elements will be directly encoded into TLV wire in order, sharing the same Type. |
| 992 | + The ``type_num`` of ``element_type`` is used. |
| 993 | +
|
| 994 | + Type: :class:`list` |
| 995 | +
|
| 996 | + :vartype value_type: :any:`Field` |
| 997 | + :ivar value_type: the type of values in the dict. |
| 998 | +
|
| 999 | + .. warning:: |
| 1000 | +
|
| 1001 | + Please always create a new :any:`Field` instance. |
| 1002 | + Don't use an existing one. |
| 1003 | + """ |
| 1004 | + |
| 1005 | + def __init__(self, key_type: Field, value_type: Field): |
| 1006 | + # default should be None here to prevent unintended modification |
| 1007 | + if not isinstance(key_type, BytesField) and not isinstance(key_type, UintField): |
| 1008 | + raise TypeError('MapField only supports string and uint to be keys') |
| 1009 | + super().__init__(key_type.type_num, None) |
| 1010 | + self.key_type = key_type |
| 1011 | + self.value_type = value_type |
| 1012 | + |
| 1013 | + def get_value(self, instance): |
| 1014 | + if self.name not in instance.__dict__: |
| 1015 | + instance.__dict__[self.name] = {} |
| 1016 | + return instance.__dict__[self.name] |
| 1017 | + |
| 1018 | + def encoded_length(self, val, markers: dict) -> int: |
| 1019 | + if not val: |
| 1020 | + return 0 |
| 1021 | + |
| 1022 | + ret = 0 |
| 1023 | + for i, (key, val) in enumerate(val.items()): |
| 1024 | + self.key_type.name = f'{self.name}[{i}#k]' |
| 1025 | + ret += self.key_type.encoded_length(key, markers) |
| 1026 | + self.value_type.name = f'{self.name}[{i}#v]' |
| 1027 | + ret += self.value_type.encoded_length(val, markers) |
| 1028 | + |
| 1029 | + return ret |
| 1030 | + |
| 1031 | + def encode_into(self, val, markers: dict, wire: VarBinaryStr, offset: int) -> int: |
| 1032 | + if val is None: |
| 1033 | + return 0 |
| 1034 | + else: |
| 1035 | + origin_offset = offset |
| 1036 | + for i, (key, val) in enumerate(val.items()): |
| 1037 | + self.key_type.name = f'{self.name}[{i}#k]' |
| 1038 | + offset += self.key_type.encode_into(key, markers, wire, offset) |
| 1039 | + self.value_type.name = f'{self.name}[{i}#v]' |
| 1040 | + offset += self.value_type.encode_into(val, markers, wire, offset) |
| 1041 | + return offset - origin_offset |
| 1042 | + |
| 1043 | + def parse_from(self, instance, markers: dict, wire: BinaryStr, offset: int, length: int, offset_btl: int): |
| 1044 | + # parse_from only parses keys and will not update the value |
| 1045 | + dct = self.get_value(instance) |
| 1046 | + self.key_type.name = f'{self.name}[{len(dct)}#k]' |
| 1047 | + new_key = self.key_type.parse_from(instance, markers, wire, offset, length, offset_btl) |
| 1048 | + markers[f'{self.name}#last_key'] = new_key |
| 1049 | + return dct |
| 1050 | + |
| 1051 | + def parse_value(self, instance, markers: dict, wire: BinaryStr, offset: int, length: int, offset_btl: int): |
| 1052 | + # parse_value parses the value associated with the key last parsed. |
| 1053 | + dct = self.get_value(instance) |
| 1054 | + last_key = markers.get(f'{self.name}#last_key') |
| 1055 | + self.value_type.name = f'{self.name}[{len(dct)}#v]' |
| 1056 | + val = self.value_type.parse_from(instance, markers, wire, offset, length, offset_btl) |
| 1057 | + dct[last_key] = val |
| 1058 | + return dct |
| 1059 | + |
| 1060 | + def asdict(self, instance): |
| 1061 | + ret = {} |
| 1062 | + for key, val in self.__get__(instance, None).items(): |
| 1063 | + if isinstance(val, TlvModel): |
| 1064 | + ret[key] = val.asdict() |
| 1065 | + elif isinstance(val, memoryview): |
| 1066 | + ret[key] = bytes(val) |
| 1067 | + else: |
| 1068 | + ret[key] = val |
| 1069 | + return ret |
0 commit comments