diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/any.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/any.py index 57a23a9..8184f34 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/any.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/any.py @@ -1,6 +1,7 @@ import typing import betterproto2 +from typing_extensions import Self from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny @@ -60,7 +61,7 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]: # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: value = dict(value) # Make a copy type_url = value.pop("@type", None) diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/duration.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/duration.py index e6dcda1..c5a8757 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/duration.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/duration.py @@ -3,6 +3,7 @@ import typing import betterproto2 +from typing_extensions import Self from betterproto2_compiler.lib.google.protobuf import Duration as VanillaDuration @@ -30,13 +31,13 @@ def delta_to_json(delta: datetime.timedelta) -> str: # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): if not re.match(r"^\d+(\.\d+)?s$", value): raise ValueError(f"Invalid duration string: {value}") seconds = float(value[:-1]) - return Duration(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9)) + return cls(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9)) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py index e543901..8c95566 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py @@ -1,6 +1,7 @@ import typing import betterproto2 +from typing_extensions import Self from betterproto2_compiler.lib.google.protobuf import ( BoolValue as VanillaBoolValue, @@ -24,9 +25,9 @@ def to_wrapped(self) -> bool: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, bool): - return BoolValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -48,9 +49,9 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return Int32Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -72,9 +73,9 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return Int64Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -96,9 +97,9 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return UInt32Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -120,9 +121,9 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return UInt64Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -144,9 +145,9 @@ def to_wrapped(self) -> float: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, float): - return FloatValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -168,9 +169,9 @@ def to_wrapped(self) -> float: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, float): - return DoubleValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -192,9 +193,9 @@ def to_wrapped(self) -> str: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): - return StringValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -216,9 +217,9 @@ def to_wrapped(self) -> bytes: return self.value @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, bytes): - return BytesValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/struct.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/struct.py index 6bcffad..96e4bfb 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/struct.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/struct.py @@ -1,6 +1,7 @@ import typing import betterproto2 +from typing_extensions import Self from betterproto2_compiler.lib.google.protobuf import ( ListValue as VanillaListValue, @@ -13,7 +14,7 @@ class Struct(VanillaStruct): # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: assert isinstance(value, dict) fields: dict[str, Value] = {} @@ -47,7 +48,7 @@ def to_dict( class Value(VanillaValue): # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: match value: case bool() as b: return cls(bool_value=b) @@ -94,7 +95,7 @@ def to_dict( class ListValue(VanillaListValue): # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: return cls(values=[Value.from_dict(v) for v in value]) # TODO typing diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/timestamp.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/timestamp.py index 3620c01..bd8f6e7 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/timestamp.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/timestamp.py @@ -3,13 +3,14 @@ import betterproto2 import dateutil.parser +from typing_extensions import Self from betterproto2_compiler.lib.google.protobuf import Timestamp as VanillaTimestamp class Timestamp(VanillaTimestamp): @classmethod - def from_datetime(cls, dt: datetime.datetime) -> "Timestamp": + def from_datetime(cls, dt: datetime.datetime) -> Self: if not dt.tzinfo: raise ValueError("datetime must be timezone aware") @@ -55,11 +56,11 @@ def timestamp_to_json(dt: datetime.datetime) -> str: # TODO typing @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): dt = dateutil.parser.isoparse(value) dt = dt.astimezone(datetime.timezone.utc) - return Timestamp.from_datetime(dt) + return cls.from_datetime(dt) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) diff --git a/betterproto2_compiler/src/betterproto2_compiler/lib/google/protobuf/__init__.py b/betterproto2_compiler/src/betterproto2_compiler/lib/google/protobuf/__init__.py index bb21770..bee52d7 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/lib/google/protobuf/__init__.py +++ b/betterproto2_compiler/src/betterproto2_compiler/lib/google/protobuf/__init__.py @@ -89,6 +89,7 @@ import datetime import re import typing +from typing_extensions import Self import warnings from dataclasses import dataclass @@ -917,7 +918,7 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]: return output @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: value = dict(value) # Make a copy type_url = value.pop("@type", None) @@ -1024,9 +1025,9 @@ class BoolValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, bool): - return BoolValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -1063,9 +1064,9 @@ class BytesValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, bytes): - return BytesValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -1182,9 +1183,9 @@ class DoubleValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, float): - return DoubleValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -1308,13 +1309,13 @@ def delta_to_json(delta: datetime.timedelta) -> str: return f"{'.'.join(parts)}s" @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): if not re.match(r"^\d+(\.\d+)?s$", value): raise ValueError(f"Invalid duration string: {value}") seconds = float(value[:-1]) - return Duration(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9)) + return cls(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9)) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) @@ -2599,9 +2600,9 @@ class FloatValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, float): - return FloatValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -2690,9 +2691,9 @@ class Int32Value(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return Int32Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -2729,9 +2730,9 @@ class Int64Value(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return Int64Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -2768,7 +2769,7 @@ class ListValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: return cls(values=[Value.from_dict(v) for v in value]) def to_dict( @@ -3393,9 +3394,9 @@ class StringValue(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): - return StringValue(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -3439,7 +3440,7 @@ class Struct(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: assert isinstance(value, dict) fields: dict[str, Value] = {} @@ -3580,7 +3581,7 @@ class Timestamp(betterproto2.Message): """ @classmethod - def from_datetime(cls, dt: datetime.datetime) -> "Timestamp": + def from_datetime(cls, dt: datetime.datetime) -> Self: if not dt.tzinfo: raise ValueError("datetime must be timezone aware") @@ -3625,11 +3626,11 @@ def timestamp_to_json(dt: datetime.datetime) -> str: return f"{result}.{nanos:09d}" @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, str): dt = dateutil.parser.isoparse(value) dt = dt.astimezone(datetime.timezone.utc) - return Timestamp.from_datetime(dt) + return cls.from_datetime(dt) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) @@ -3643,7 +3644,7 @@ def to_dict( # If the output format is PYTHON, we should have kept the wraped type without building the real class assert output_format == betterproto2.OutputFormat.PROTO_JSON - return Timestamp.timestamp_to_json(self.to_datetime()) + return self.timestamp_to_json(self.to_datetime()) @staticmethod def from_wrapped(wrapped: datetime.datetime) -> "Timestamp": @@ -3715,9 +3716,9 @@ class UInt32Value(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return UInt32Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -3754,9 +3755,9 @@ class UInt64Value(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: if isinstance(value, int): - return UInt64Value(value=value) + return cls(value=value) return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( @@ -3875,7 +3876,7 @@ class Value(betterproto2.Message): """ @classmethod - def from_dict(cls, value, *, ignore_unknown_fields: bool = False): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self: match value: case bool() as b: return cls(bool_value=b) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index e21afa5..750e1f6 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -19,6 +19,7 @@ import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator import typing from typing import TYPE_CHECKING +from typing_extensions import Self {% if output_file.settings.pydantic_dataclasses %} import pydantic