Skip to content

Commit de22e2d

Browse files
committed
Make protocol objects immutable
All objects are now fully constucted via `__init__` or `deserialize`. Additionally, array fields are now stored as tuples. Besides giving us deeper immutability, this fixes a bug where array fields could be out of sync with their corresponding length fields if the array was mutated after being assigned via setter.
1 parent b3def84 commit de22e2d

File tree

5 files changed

+154
-130
lines changed

5 files changed

+154
-130
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919

2020
### Changed
2121

22+
- Make protocol objects immutable.
23+
- Protocol data structures are now fully instantiated via `__init__`, and can't be modified later.
24+
- Array fields are now typed as tuples.
2225
- Rename `QuestReportServerPacket.npc_id` field to `npc_index`.
2326
- Make `CastReplyServerPacket.caster_tp` field optional.
2427
- Make `CastSpecServerPacket.caster_tp` field optional.
@@ -34,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3437

3538
### Removed
3639

40+
- All protocol object field setters.
3741
- `EffectPlayerServerPacket.player_id` field.
3842
- `EffectPlayerServerPacket.effect_id` field.
3943
- `EffectAgreeServerPacket.coords` field.

protocol_code_generator/generate/field_code_generator.py

Lines changed: 82 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import namedtuple
1+
from html import escape
22
from protocol_code_generator.generate.code_block import CodeBlock
33
from protocol_code_generator.generate.object_code_generator import FieldData
44
from protocol_code_generator.type.basic_type import BasicType
@@ -14,19 +14,6 @@
1414
from protocol_code_generator.util.docstring_utils import generate_docstring
1515
from protocol_code_generator.util.number_utils import try_parse_int
1616

17-
DeprecatedField = namedtuple(
18-
"DeprecatedField", ["type_name", "old_field_name", "new_field_name", "since"]
19-
)
20-
21-
DEPRECATED_FIELDS = []
22-
23-
24-
def get_deprecated_field(type_name, field_name):
25-
for field in DEPRECATED_FIELDS:
26-
if field.type_name == type_name and field.new_field_name == field_name:
27-
return field
28-
return None
29-
3017

3118
class FieldCodeGenerator:
3219
def __init__(
@@ -182,28 +169,24 @@ def generate_field(self):
182169
field_type = self._get_type()
183170

184171
python_type_name = self._get_python_type_name(field_type)
172+
python_param_type_name = python_type_name
173+
185174
if self._array_field:
186-
python_type_name = f"list[{python_type_name}]"
175+
python_type_name = f"tuple[{python_type_name}, ...]"
176+
python_param_type_name = f"Iterable[{python_param_type_name}]"
187177
self._data.fields.add_import('annotations', '__future__')
178+
self._data.fields.add_import('Iterable', 'collections.abc')
188179

189180
if self._optional:
190181
python_type_name = f"Optional[{python_type_name}]"
182+
python_param_type_name = f"Optional[{python_param_type_name}]"
191183
self._data.fields.add_import('Optional', 'typing')
192184

193-
if self._hardcoded_value is None:
194-
initializer = None
195-
elif isinstance(field_type, StringType):
196-
initializer = f'"{self._hardcoded_value}"'
197-
else:
198-
initializer = self._hardcoded_value
199-
200185
self._context.accessible_fields[self._name] = FieldData(
201186
self._name, field_type, self._offset, self._array_field
202187
)
203188

204-
self._data.fields.add_line(
205-
f"_{self._name}: {python_type_name} = {initializer} # type: ignore [assignment]"
206-
)
189+
self._data.fields.add_line(f"_{self._name}: {python_type_name}")
207190

208191
if isinstance(field_type, CustomType):
209192
self._data.fields.add_import_by_type(field_type)
@@ -212,82 +195,54 @@ def generate_field(self):
212195
self._context.length_field_is_referenced_map[self._name] = False
213196
return
214197

215-
docstring = self._generate_accessor_docstring()
198+
self._data.init_docstring_params.append(
199+
CodeBlock()
200+
.add(f'{self._name} ({python_param_type_name}): ')
201+
.add_code_block(self._generate_init_docstring())
202+
)
216203

217204
self._data.add_method(
218205
CodeBlock()
219206
.add_line('@property')
220207
.add_line(f'def {self._name}(self) -> {python_type_name}:')
221208
.indent()
222-
.add_code_block(docstring)
209+
.add_code_block(generate_docstring(self._comment))
223210
.add_line(f'return self._{self._name}')
224211
.unindent()
225212
)
226213

227214
self._data.repr_fields.append(self._name)
228215

229216
if self._hardcoded_value is None:
230-
setter = (
231-
CodeBlock()
232-
.add_line(f'@{self._name}.setter')
233-
.add_line(f'def {self._name}(self, {self._name}: {python_type_name}) -> None:')
234-
.indent()
235-
.add_code_block(docstring)
236-
.add_line(f'self._{self._name} = {self._name}')
237-
)
217+
expression = self._name
218+
if self._array_field:
219+
expression = f'tuple({expression})'
220+
elif isinstance(field_type, StringType):
221+
expression = f'"{self._hardcoded_value}"'
222+
else:
223+
expression = self._hardcoded_value
238224

239-
if self._length_string in self._context.length_field_is_referenced_map:
240-
self._context.length_field_is_referenced_map[self._length_string] = True
241-
length_field_data = self._context.accessible_fields[self._length_string]
242-
setter.add_line(f'self._{length_field_data.name} = len(self._{self._name})')
243-
244-
setter.unindent()
245-
self._data.add_method(setter)
246-
247-
deprecated = get_deprecated_field(self._data.class_name, self._name)
248-
if deprecated is not None:
249-
old_name = deprecated.old_field_name
250-
deprecated_docstring = (
251-
CodeBlock()
252-
.add_line('"""')
253-
.add_line('!!! warning "Deprecated"')
254-
.add_line()
255-
.add_line(f" Use `{self._name}` instead. (Deprecated since v{deprecated.since})")
256-
.add_line('"""')
257-
)
258-
deprecation_warning = (
259-
f"'{self._data.class_name}.{deprecated.old_field_name}' is deprecated as of "
260-
f"{deprecated.since}, use '{self._name}' instead."
261-
)
262-
self._data.add_method(
263-
CodeBlock()
264-
.add_line('@property')
265-
.add_line(f'def {old_name}(self) -> {python_type_name}:')
266-
.indent()
267-
.add_code_block(deprecated_docstring)
268-
.add_line(f'warn("{deprecation_warning}", DeprecationWarning, stacklevel=2)')
269-
.add_line(f'return self.{self._name}')
270-
.unindent()
271-
.add_import("warn", "warnings")
225+
init_param = CodeBlock().add(f'{self._name}: {python_param_type_name}')
226+
if self._optional:
227+
init_param.add(' = None')
228+
229+
self._data.init_params.append(init_param)
230+
self._data.init_body.add_line(f'self._{self._name} = {expression}')
231+
232+
if self._length_string in self._context.length_field_is_referenced_map:
233+
self._context.length_field_is_referenced_map[self._length_string] = True
234+
length_field_data = self._context.accessible_fields[self._length_string]
235+
self._data.init_body.add_line(
236+
f'self._{length_field_data.name} = len(self._{self._name})'
272237
)
273-
if self._hardcoded_value is None:
274-
self._data.add_method(
275-
CodeBlock()
276-
.add_line(f'@{old_name}.setter')
277-
.add_line(f'def {old_name}(self, {self._name}: {python_type_name}) -> None:')
278-
.indent()
279-
.add_code_block(deprecated_docstring)
280-
.add_line(f'self.{self._name} = {self._name}')
281-
.unindent()
282-
)
283238

284239
def generate_serialize(self):
285240
self._generate_serialize_missing_optional_guard()
286241
self._generate_serialize_none_not_allowed_error()
287242
self._generate_serialize_length_check()
288243

289244
if self._array_field:
290-
array_size_expression = self._get_length_expression()
245+
array_size_expression = self._get_serialize_length_expression()
291246
if array_size_expression is None:
292247
array_size_expression = f"len(data._{self._name})"
293248

@@ -310,6 +265,8 @@ def generate_serialize(self):
310265

311266
def generate_deserialize(self):
312267
if self._optional:
268+
python_type = f'Optional[{self._get_python_type_name(self._get_type())}]'
269+
self._data.deserialize.add_line(f'{self._name}: {python_type} = None')
313270
self._data.deserialize.begin_control_flow("if reader.remaining > 0")
314271

315272
if self._array_field:
@@ -320,8 +277,20 @@ def generate_deserialize(self):
320277
if self._optional:
321278
self._data.deserialize.unindent()
322279

323-
def _generate_accessor_docstring(self):
324-
notes = []
280+
if self._name is not None and not self._length_field:
281+
self._data.deserialize_init_arguments.append(
282+
CodeBlock().add(f'{self._name}={self._name}')
283+
)
284+
285+
def _generate_init_docstring(self):
286+
result = CodeBlock()
287+
288+
if self._comment is not None:
289+
lines = map(str.strip, escape(self._comment, quote=False).split('\n'))
290+
for line in lines:
291+
if not result.empty:
292+
result.add(' ')
293+
result.add(line)
325294

326295
if self._length_string is not None:
327296
size_description = ""
@@ -333,14 +302,18 @@ def _generate_accessor_docstring(self):
333302
size_description = f'`{self._length_string}`'
334303
if self._padded:
335304
size_description += " or less"
336-
notes.append(f'Length must be {size_description}.')
305+
if not result.empty:
306+
result.add(' ')
307+
result.add(f'(Length must be {size_description}.)')
337308

338309
field_type = self._get_type()
339310
if isinstance(field_type, IntegerType):
340311
value_description = "Element value" if self._array_field else "Value"
341-
notes.append(f'{value_description} range is 0-{get_max_value_of(field_type)}.')
312+
if not result.empty:
313+
result.add(' ')
314+
result.add(f'({value_description} range is 0-{get_max_value_of(field_type)}.)')
342315

343-
return generate_docstring(self._comment, notes)
316+
return result
344317

345318
def _generate_serialize_missing_optional_guard(self):
346319
if not self._optional:
@@ -421,7 +394,9 @@ def _get_write_statement(self):
421394
result = CodeBlock()
422395

423396
if isinstance(type_, BasicType):
424-
length_expression = None if self._array_field else self._get_length_expression()
397+
length_expression = (
398+
None if self._array_field else self._get_serialize_length_expression()
399+
)
425400
write_statement = FieldCodeGenerator._get_write_statement_for_basic_type(
426401
type_, value_expression, length_expression, self._padded
427402
)
@@ -492,7 +467,7 @@ def _get_write_statement_for_basic_type(type_, value_expression, length_expressi
492467
raise AssertionError("Unhandled BasicType")
493468

494469
def _generate_deserialize_array(self):
495-
array_length_expression = self._get_length_expression()
470+
array_length_expression = self._get_deserialize_length_expression()
496471

497472
if array_length_expression is None and not self._delimited:
498473
element_size = self._get_type().fixed_size
@@ -503,7 +478,7 @@ def _generate_deserialize_array(self):
503478
)
504479
array_length_expression = array_length_variable_name
505480

506-
self._data.deserialize.add_line(f"data._{self._name} = []")
481+
self._data.deserialize.add_line(f"{self._name} = []")
507482

508483
if array_length_expression is None:
509484
self._data.deserialize.begin_control_flow("while reader.remaining > 0")
@@ -532,12 +507,14 @@ def _get_read_statement(self):
532507
statement = CodeBlock()
533508

534509
if self._array_field:
535-
statement.add(f"data._{self._name}.append(")
510+
statement.add(f"{self._name}.append(")
536511
elif self._name is not None:
537-
statement.add(f"data._{self._name} = ")
512+
statement.add(f"{self._name} = ")
538513

539514
if isinstance(type_, BasicType):
540-
length_expression = None if self._array_field else self._get_length_expression()
515+
length_expression = (
516+
None if self._array_field else self._get_deserialize_length_expression()
517+
)
541518
read_basic_type = FieldCodeGenerator._get_read_statement_for_basic_type(
542519
type_, length_expression, self._padded
543520
)
@@ -564,17 +541,24 @@ def _get_read_statement(self):
564541

565542
return statement.add("\n")
566543

567-
def _get_length_expression(self):
568-
if self._length_string is None:
569-
return None
544+
def _check_field_accessible(self, field):
545+
field_data = self._context.accessible_fields.get(field)
546+
if not field_data:
547+
raise RuntimeError(f'Referenced {field} field is not accessible.')
570548

549+
def _get_serialize_length_expression(self):
571550
expression = self._length_string
572-
if not expression.isdigit():
573-
field_data = self._context.accessible_fields.get(expression)
574-
if not field_data:
575-
raise RuntimeError(f'Referenced {expression} field is not accessible.')
576-
expression = f'data._{expression}'
551+
if expression is not None:
552+
if not expression.isdigit():
553+
self._check_field_accessible(expression)
554+
expression = f'data._{expression}'
555+
return expression
577556

557+
def _get_deserialize_length_expression(self):
558+
expression = self._length_string
559+
if expression is not None:
560+
if not expression.isdigit():
561+
self._check_field_accessible(expression)
578562
return expression
579563

580564
@staticmethod

0 commit comments

Comments
 (0)