diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 1187be7d..b62da178 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -59,7 +59,7 @@ build-backend = "hatchling.build" # ] [tool.ruff] -extend-exclude = ["tests/output_*", "src/betterproto2/internal_lib"] +extend-exclude = ["tests/outputs", "src/betterproto2/internal_lib"] target-version = "py310" line-length = 120 diff --git a/betterproto2/src/betterproto2/__init__.py b/betterproto2/src/betterproto2/__init__.py index 7ceda2f0..24a02926 100644 --- a/betterproto2/src/betterproto2/__init__.py +++ b/betterproto2/src/betterproto2/__init__.py @@ -34,7 +34,7 @@ from ._types import T from ._version import __version__, check_compiler_version from .casing import camel_case, safe_snake_case, snake_case -from .enum import Enum as Enum +from .enum_ import Enum as Enum from .grpc.grpclib_client import ServiceStub as ServiceStub from .utils import classproperty @@ -585,9 +585,10 @@ def _value_to_dict( if proto_type in INT_64_TYPES: return str(value), not bool(value) if proto_type == TYPE_BYTES: - return b64encode(value).decode("utf8"), not (bool(value)) + return b64encode(value).decode("utf8"), not bool(value) if proto_type == TYPE_ENUM: - return field_type(value).name, not bool(value) + enum_value = field_type(value) + return enum_value.proto_name or enum_value.name, not bool(value) if proto_type in (TYPE_FLOAT, TYPE_DOUBLE): return _dump_float(value), not bool(value) return value, not bool(value) diff --git a/betterproto2/src/betterproto2/enum.py b/betterproto2/src/betterproto2/enum_.py similarity index 50% rename from betterproto2/src/betterproto2/enum.py rename to betterproto2/src/betterproto2/enum_.py index 75446edb..081a6078 100644 --- a/betterproto2/src/betterproto2/enum.py +++ b/betterproto2/src/betterproto2/enum_.py @@ -1,14 +1,42 @@ -from enum import IntEnum +import sys +from enum import EnumMeta, IntEnum from typing_extensions import Self -class Enum(IntEnum): +class _EnumMeta(EnumMeta): + def __new__(metacls, cls, bases, classdict): + # Find the proto names if defined + if sys.version_info >= (3, 11): + proto_names = classdict.pop("betterproto_proto_names", {}) + classdict._member_names.pop("betterproto_proto_names", None) + else: + proto_names = {} + if "betterproto_proto_names" in classdict: + proto_names = classdict.pop("betterproto_proto_names") + classdict._member_names.remove("betterproto_proto_names") + + enum_class = super().__new__(metacls, cls, bases, classdict) + + # Attach extra info to each enum member + for member in enum_class: + value = member.value # type: ignore[reportAttributeAccessIssue] + extra = proto_names.get(value) + member._proto_name = extra # type: ignore[reportAttributeAccessIssue] + + return enum_class + + +class Enum(IntEnum, metaclass=_EnumMeta): + @property + def proto_name(self) -> str | None: + return self._proto_name # type: ignore[reportAttributeAccessIssue] + @classmethod def _missing_(cls, value): # If the given value is not an integer, let the standard enum implementation raise an error if not isinstance(value, int): - return None + return # Create a new "unknown" instance with the given value. obj = int.__new__(cls, value) diff --git a/betterproto2/tests/test_all_definition.py b/betterproto2/tests/test_all_definition.py index 92ed50b2..eb2ec30d 100644 --- a/betterproto2/tests/test_all_definition.py +++ b/betterproto2/tests/test_all_definition.py @@ -17,4 +17,4 @@ def test_all_definition(): "TestSyncStub", "ThingType", ) - assert enum.__all__ == ("ArithmeticOperator", "Choice", "HttpCode", "NoStriping", "Test") + assert enum.__all__ == ("ArithmeticOperator", "Choice", "EnumMessage", "HttpCode", "NoStriping", "Test") diff --git a/betterproto2/tests/test_enum.py b/betterproto2/tests/test_enum.py index b2dcfed2..f8e0bc17 100644 --- a/betterproto2/tests/test_enum.py +++ b/betterproto2/tests/test_enum.py @@ -82,3 +82,19 @@ def test_enum_renaming() -> None: assert set(ArithmeticOperator.__members__) == {"NONE", "PLUS", "MINUS", "_0_PREFIXED"} assert set(HttpCode.__members__) == {"UNSPECIFIED", "OK", "NOT_FOUND"} assert set(NoStriping.__members__) == {"NO_STRIPING_NONE", "NO_STRIPING_A", "B"} + + +def test_enum_to_dict() -> None: + from tests.outputs.enum.enum import ArithmeticOperator, EnumMessage, NoStriping + + msg = EnumMessage( + arithmetic_operator=ArithmeticOperator.PLUS, + no_striping=NoStriping.NO_STRIPING_A, + ) + + print(ArithmeticOperator.PLUS.proto_name) + + assert msg.to_dict() == { + "arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved + "noStriping": "NO_STRIPING_A", + } diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index 6181a640..fb1de78b 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -60,7 +60,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.ruff] -extend-exclude = ["tests/output_*", "src/betterproto2_compiler/lib"] +extend-exclude = ["tests/outputs", "src/betterproto2_compiler/lib"] target-version = "py310" line-length = 120 diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 953aba56..cc65b01e 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -610,6 +610,7 @@ class EnumEntry: """Representation of an Enum entry.""" name: str + proto_name: str value: int comment: str @@ -617,6 +618,7 @@ def __post_init__(self) -> None: self.entries = [ self.EnumEntry( name=entry_proto_value.name, + proto_name=entry_proto_value.name, value=entry_proto_value.number, comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]), ) @@ -672,6 +674,10 @@ def descriptor_name(self) -> str: """ return self.output_file.get_descriptor_name(self.source_file) + @property + def has_renamed_entries(self) -> bool: + return any(entry.proto_name != entry.name for entry in self.entries) + @dataclass(kw_only=True) class ServiceCompiler(ProtoContentBase): diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 09562fb5..3ed23621 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -31,6 +31,16 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): return core_schema.int_schema(ge=0) {% endif %} + {% if enum.has_renamed_entries %} + betterproto_proto_names = { + {% for entry in enum.entries %} + {% if entry.proto_name != entry.name %} + {{ entry.value }}: "{{ entry.proto_name }}", + {% endif %} + {% endfor %} + } + {% endif %} + {% endfor %} {% for _, message in output_file.messages|dictsort(by="key") %} {% if output_file.settings.pydantic_dataclasses %} diff --git a/betterproto2_compiler/tests/inputs/enum/enum.proto b/betterproto2_compiler/tests/inputs/enum/enum.proto index d37133a6..fb2aa9fc 100644 --- a/betterproto2_compiler/tests/inputs/enum/enum.proto +++ b/betterproto2_compiler/tests/inputs/enum/enum.proto @@ -4,16 +4,16 @@ package enum; // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values message Test { - Choice choice = 1; - repeated Choice choices = 2; + Choice choice = 1; + repeated Choice choices = 2; } enum Choice { - ZERO = 0; - ONE = 1; - // TWO = 2; - FOUR = 4; - THREE = 3; + ZERO = 0; + ONE = 1; + // TWO = 2; + FOUR = 4; + THREE = 3; } // A "C" like enum with the enum name prefixed onto members, these should be stripped @@ -38,3 +38,8 @@ enum HTTPCode { HTTP_CODE_OK = 200; HTTP_CODE_NOT_FOUND = 404; } + +message EnumMessage { + ArithmeticOperator arithmetic_operator = 1; + NoStriping no_striping = 2; +}