Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions betterproto2/src/betterproto2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any:

if meta.proto_type == TYPE_ENUM:
if isinstance(value, str):
if (int_value := field_type.betterproto_renamed_proto_names_to_value().get(value)) is not None:
return field_type(int_value)
return field_type.from_string(value)
if isinstance(value, int):
return field_type(value)
Expand Down
20 changes: 9 additions & 11 deletions betterproto2/src/betterproto2/enum_.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
import sys
from enum import EnumMeta, IntEnum

from typing_extensions import Self


class _EnumMeta(EnumMeta):
def __new__(metacls, cls, bases, classdict):
# Find the proto names if defined
if sys.version_info >= (3, 11):
proto_names = classdict.pop("betterproto_proto_names", {})
classdict._member_names.pop("betterproto_proto_names", None)
else:
proto_names = {}
if "betterproto_proto_names" in classdict:
proto_names = classdict.pop("betterproto_proto_names")
classdict._member_names.remove("betterproto_proto_names")

enum_class = super().__new__(metacls, cls, bases, classdict)
proto_names = enum_class.betterproto_value_to_renamed_proto_names() # type: ignore[reportAttributeAccessIssue]

# Attach extra info to each enum member
for member in enum_class:
Expand All @@ -32,6 +22,14 @@ class Enum(IntEnum, metaclass=_EnumMeta):
def proto_name(self) -> str | None:
return self._proto_name # type: ignore[reportAttributeAccessIssue]

@classmethod
def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]:
return {}

@classmethod
def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]:
return {}

@classmethod
def _missing_(cls, value):
# If the given value is not an integer, let the standard enum implementation raise an error
Expand Down
4 changes: 2 additions & 2 deletions betterproto2/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def test_enum_to_dict() -> None:
no_striping=NoStriping.NO_STRIPING_A,
)

print(ArithmeticOperator.PLUS.proto_name)

assert msg.to_dict() == {
"arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved
"noStriping": "NO_STRIPING_A",
}

assert EnumMessage.from_dict(msg.to_dict()) == msg


def test_unknown_variant_to_dict() -> None:
from tests.outputs.enum.enum import NewVersion, NewVersionMessage, OldVersionMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,25 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum):
{% endif %}

{% if enum.has_renamed_entries %}
betterproto_proto_names = {
{% for entry in enum.entries %}
{% if entry.proto_name != entry.name %}
{{ entry.value }}: "{{ entry.proto_name }}",
{% endif %}
{% endfor %}
}
@classmethod
def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]:
return {
{% for entry in enum.entries %}
{% if entry.proto_name != entry.name %}
{{ entry.value }}: "{{ entry.proto_name }}",
{% endif %}
{% endfor %}
}

@classmethod
def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]:
return {
{% for entry in enum.entries %}
{% if entry.proto_name != entry.name %}
"{{ entry.proto_name }}": {{ entry.value }},
{% endif %}
{% endfor %}
}
{% endif %}

{% endfor %}
Expand Down
Loading