diff --git a/betterproto2/src/betterproto2/__init__.py b/betterproto2/src/betterproto2/__init__.py index dc159eb5..57d674c9 100644 --- a/betterproto2/src/betterproto2/__init__.py +++ b/betterproto2/src/betterproto2/__init__.py @@ -611,6 +611,8 @@ def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any: if meta.proto_type == TYPE_ENUM: if isinstance(value, str): + if (int_value := field_type.betterproto_renamed_proto_names_to_value().get(value)) is not None: + return field_type(int_value) return field_type.from_string(value) if isinstance(value, int): return field_type(value) diff --git a/betterproto2/src/betterproto2/enum_.py b/betterproto2/src/betterproto2/enum_.py index 081a6078..5708bf8e 100644 --- a/betterproto2/src/betterproto2/enum_.py +++ b/betterproto2/src/betterproto2/enum_.py @@ -1,4 +1,3 @@ -import sys from enum import EnumMeta, IntEnum from typing_extensions import Self @@ -6,17 +5,8 @@ class _EnumMeta(EnumMeta): def __new__(metacls, cls, bases, classdict): - # Find the proto names if defined - if sys.version_info >= (3, 11): - proto_names = classdict.pop("betterproto_proto_names", {}) - classdict._member_names.pop("betterproto_proto_names", None) - else: - proto_names = {} - if "betterproto_proto_names" in classdict: - proto_names = classdict.pop("betterproto_proto_names") - classdict._member_names.remove("betterproto_proto_names") - enum_class = super().__new__(metacls, cls, bases, classdict) + proto_names = enum_class.betterproto_value_to_renamed_proto_names() # type: ignore[reportAttributeAccessIssue] # Attach extra info to each enum member for member in enum_class: @@ -32,6 +22,14 @@ class Enum(IntEnum, metaclass=_EnumMeta): def proto_name(self) -> str | None: return self._proto_name # type: ignore[reportAttributeAccessIssue] + @classmethod + def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]: + return {} + + @classmethod + def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]: + return {} + @classmethod def _missing_(cls, value): # If the given value is not an integer, let the standard enum implementation raise an error diff --git a/betterproto2/tests/test_enum.py b/betterproto2/tests/test_enum.py index 6dc12f78..7c28f9df 100644 --- a/betterproto2/tests/test_enum.py +++ b/betterproto2/tests/test_enum.py @@ -92,13 +92,13 @@ def test_enum_to_dict() -> None: no_striping=NoStriping.NO_STRIPING_A, ) - print(ArithmeticOperator.PLUS.proto_name) - assert msg.to_dict() == { "arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved "noStriping": "NO_STRIPING_A", } + assert EnumMessage.from_dict(msg.to_dict()) == msg + def test_unknown_variant_to_dict() -> None: from tests.outputs.enum.enum import NewVersion, NewVersionMessage, OldVersionMessage diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 3ed23621..36d78036 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -32,13 +32,25 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): {% endif %} {% if enum.has_renamed_entries %} - betterproto_proto_names = { - {% for entry in enum.entries %} - {% if entry.proto_name != entry.name %} - {{ entry.value }}: "{{ entry.proto_name }}", - {% endif %} - {% endfor %} - } + @classmethod + def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]: + return { + {% for entry in enum.entries %} + {% if entry.proto_name != entry.name %} + {{ entry.value }}: "{{ entry.proto_name }}", + {% endif %} + {% endfor %} + } + + @classmethod + def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]: + return { + {% for entry in enum.entries %} + {% if entry.proto_name != entry.name %} + "{{ entry.proto_name }}": {{ entry.value }}, + {% endif %} + {% endfor %} + } {% endif %} {% endfor %}