diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 103ce775..3b9bab9b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: ruff-format args: ["--diff", "src", "tests"] - id: ruff - args: ["--select", "I", "src", "tests"] + args: ["src", "tests"] - repo: https://github.com/PyCQA/doc8 rev: 0.10.1 @@ -16,10 +16,3 @@ repos: - id: doc8 additional_dependencies: - toml - - - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.10.0 - hooks: - - id: pretty-format-java - args: [--autofix, --aosp] - files: ^.*\.java$ diff --git a/docs/conf.py b/docs/conf.py index 564208fe..1f7a5fbb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,6 @@ import toml - # -- Project information ----------------------------------------------------- project = "betterproto" diff --git a/pyproject.toml b/pyproject.toml index 4f693029..0990b869 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,10 +53,25 @@ rust-codec = ["betterproto-rust-codec"] [tool.ruff] extend-exclude = ["tests/output_*"] target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "F401", # Unused imports + "F841", # Unused local variables + "F821", # Undefined names + "E501", # Line length violations + + "SIM101", # Simplify unnecessary if-else blocks + "SIM102", # Simplify return or yield statements + "SIM103", # Simplify list/set/dict comprehensions + + "I", +] + [tool.ruff.lint.isort] combine-as-imports = true -lines-after-imports = 2 # Dev workflow tasks diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index cf54e0e6..dfa9dd5e 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +__all__ = ["__version__"] + import dataclasses import enum as builtin_enum import json @@ -55,7 +57,6 @@ hybridmethod, ) - if TYPE_CHECKING: from _typeshed import ( SupportsRead, @@ -211,11 +212,7 @@ def dataclass_field( return 
dataclasses.field( default_factory=default_factory, - metadata={ - "betterproto": FieldMetadata( - number, proto_type, map_types, group, wraps, optional - ) - }, + metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps, optional)}, ) @@ -263,9 +260,7 @@ def int32_field( optional: bool = False, repeated: bool = False, ) -> Any: - return dataclass_field( - number, TYPE_INT32, int, group=group, optional=optional, repeated=repeated - ) + return dataclass_field(number, TYPE_INT32, int, group=group, optional=optional, repeated=repeated) def int64_field( @@ -274,9 +269,7 @@ def int64_field( optional: bool = False, repeated: bool = False, ) -> Any: - return dataclass_field( - number, TYPE_INT64, int, group=group, optional=optional, repeated=repeated - ) + return dataclass_field(number, TYPE_INT64, int, group=group, optional=optional, repeated=repeated) def uint32_field( @@ -489,12 +482,8 @@ def message_field( ) -def map_field( - number: int, key_type: str, value_type: str, group: Optional[str] = None -) -> Any: - return dataclass_field( - number, TYPE_MAP, dict, map_types=(key_type, value_type), group=group - ) +def map_field(number: int, key_type: str, value_type: str, group: Optional[str] = None) -> Any: + return dataclass_field(number, TYPE_MAP, dict, map_types=(key_type, value_type), group=group) def _pack_fmt(proto_type: str) -> str: @@ -513,7 +502,7 @@ def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None: """Encodes a single varint and dumps it into the provided stream.""" if value < -(1 << 63): raise ValueError( - "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes." + "Negative value is not representable as a 64-bit integer" " - unable to encode a varint within 10 bytes." 
) elif value < 0: value += 1 << 64 @@ -728,9 +717,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: elif wire_type == WIRE_FIXED_32: decoded, i = value[i : i + 4], i + 4 - yield ParsedField( - number=number, wire_type=wire_type, value=decoded, raw=value[start:i] - ) + yield ParsedField(number=number, wire_type=wire_type, value=decoded, raw=value[start:i]) class ProtoClassMetadata: @@ -775,22 +762,16 @@ def __init__(self, cls: Type["Message"]): self.oneof_field_by_group = by_group self.field_name_by_number = by_field_number self.meta_by_field_name = by_field_name - self.sorted_field_names = tuple( - by_field_number[number] for number in sorted(by_field_number) - ) + self.sorted_field_names = tuple(by_field_number[number] for number in sorted(by_field_number)) self.default_gen = self._get_default_gen(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields) @staticmethod - def _get_default_gen( - cls: Type["Message"], fields: Iterable[dataclasses.Field] - ) -> Dict[str, Callable[[], Any]]: + def _get_default_gen(cls: Type["Message"], fields: Iterable[dataclasses.Field]) -> Dict[str, Callable[[], Any]]: return {field.name: field.default_factory for field in fields} @staticmethod - def _get_cls_by_field( - cls: Type["Message"], fields: Iterable[dataclasses.Field] - ) -> Dict[str, Type]: + def _get_cls_by_field(cls: Type["Message"], fields: Iterable[dataclasses.Field]) -> Dict[str, Type]: field_cls = {} for field in fields: @@ -970,8 +951,8 @@ def __bytes__(self) -> bytes: item, wraps=meta.wraps or "", ) - # if it's an empty message it still needs to be represented - # as an item in the repeated list + # if it's an empty message it still needs to be + # represented as an item in the repeated list or b"\n\x00" ) @@ -980,9 +961,7 @@ def __bytes__(self) -> bytes: assert meta.map_types sk = _serialize_single(1, meta.map_types[0], k) sv = _serialize_single(2, meta.map_types[1], v) - stream.write( - _serialize_single(meta.number, 
meta.proto_type, sk + sv) - ) + stream.write(_serialize_single(meta.number, meta.proto_type, sk + sv)) else: stream.write( _serialize_single( @@ -1034,9 +1013,8 @@ def _type_hints(cls) -> Dict[str, Type]: def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: """Get the message class for a field from the type hints.""" field_cls = cls._type_hint(field.name) - if hasattr(field_cls, "__args__") and index >= 0: - if field_cls.__args__ is not None: - field_cls = field_cls.__args__[index] + if hasattr(field_cls, "__args__") and index >= 0 and field_cls.__args__ is not None: + field_cls = field_cls.__args__[index] return field_cls def _get_field_default(self, field_name: str) -> Any: @@ -1045,9 +1023,7 @@ def _get_field_default(self, field_name: str) -> Any: warnings.filterwarnings("ignore", category=DeprecationWarning) return self._betterproto.default_gen[field_name]() - def _postprocess_single( - self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any - ) -> Any: + def _postprocess_single(self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any) -> Any: """Adjusts values after parsing.""" if wire_type == WIRE_VARINT: if meta.proto_type in (TYPE_INT32, TYPE_INT64): @@ -1141,14 +1117,10 @@ def load( else: decoded, pos = decode_varint(parsed.value, pos) wire_type = WIRE_VARINT - decoded = self._postprocess_single( - wire_type, meta, field_name, decoded - ) + decoded = self._postprocess_single(wire_type, meta, field_name, decoded) value.append(decoded) else: - value = self._postprocess_single( - parsed.wire_type, meta, field_name, parsed.value - ) + value = self._postprocess_single(parsed.wire_type, meta, field_name, parsed.value) current = getattr(self, field_name) @@ -1227,9 +1199,7 @@ def FromString(cls: Type[T], data: bytes) -> T: """ return cls().parse(data) - def to_dict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> Dict[str, Any]: + def to_dict(self, casing: Casing = Casing.CAMEL, 
include_default_values: bool = False) -> Dict[str, Any]: """ Returns a JSON serializable dict representation of this object. @@ -1271,9 +1241,7 @@ def to_dict( elif cls == timedelta: value = [_Duration.delta_to_json(i) for i in value] else: - value = [ - i.to_dict(casing, include_default_values) for i in value - ] + value = [i.to_dict(casing, include_default_values) for i in value] if value or include_default_values: output[cased_name] = value elif value is None: @@ -1300,9 +1268,7 @@ def to_dict( output[cased_name] = str(value) elif meta.proto_type == TYPE_BYTES: if field_is_repeated: - output[cased_name] = [ - b64encode(b).decode("utf8") for b in value - ] + output[cased_name] = [b64encode(b).decode("utf8") for b in value] elif value is None and include_default_values: output[cased_name] = value else: @@ -1310,9 +1276,7 @@ def to_dict( elif meta.proto_type == TYPE_ENUM: if field_is_repeated: enum_class = field_types[field_name].__args__[0] - if isinstance(value, typing.Iterable) and not isinstance( - value, str - ): + if isinstance(value, typing.Iterable) and not isinstance(value, str): output[cased_name] = [enum_class(el).name for el in value] else: # transparently upgrade single value to repeated @@ -1350,11 +1314,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: if meta.proto_type == TYPE_MESSAGE: sub_cls = cls._betterproto.cls_by_field[field_name] if sub_cls == datetime: - value = ( - [isoparse(item) for item in value] - if isinstance(value, list) - else isoparse(value) - ) + value = [isoparse(item) for item in value] if isinstance(value, list) else isoparse(value) elif sub_cls == timedelta: value = ( [timedelta(seconds=float(item[:-1])) for item in value] @@ -1372,17 +1332,9 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: value = {k: sub_cls.from_dict(v) for k, v in value.items()} else: if meta.proto_type in INT_64_TYPES: - value = ( - [int(n) for n in value] - if isinstance(value, list) - else 
int(value) - ) + value = [int(n) for n in value] if isinstance(value, list) else int(value) elif meta.proto_type == TYPE_BYTES: - value = ( - [b64decode(n) for n in value] - if isinstance(value, list) - else b64decode(value) - ) + value = [b64decode(n) for n in value] if isinstance(value, list) else b64decode(value) elif meta.proto_type == TYPE_ENUM: enum_cls = cls._betterproto.cls_by_field[field_name] if isinstance(value, list): @@ -1390,11 +1342,7 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: elif isinstance(value, str): value = enum_cls.from_string(value) elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): - value = ( - [_parse_float(n) for n in value] - if isinstance(value, list) - else _parse_float(value) - ) + value = [_parse_float(n) for n in value] if isinstance(value, list) else _parse_float(value) init_kwargs[field_name] = value return init_kwargs @@ -1494,9 +1442,7 @@ def from_json(self: T, value: Union[str, bytes]) -> T: """ return self.from_dict(json.loads(value)) - def to_pydict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> Dict[str, Any]: + def to_pydict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> Dict[str, Any]: """ Returns a python dict representation of this object. 
@@ -1526,18 +1472,14 @@ def to_pydict( if ( value != DATETIME_ZERO or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) + or self._include_default_value_for_oneof(field_name=field_name, meta=meta) ): output[cased_name] = value elif isinstance(value, timedelta): if ( value != timedelta(0) or include_default_values - or self._include_default_value_for_oneof( - field_name=field_name, meta=meta - ) + or self._include_default_value_for_oneof(field_name=field_name, meta=meta) ): output[cased_name] = value elif meta.wraps: @@ -1639,21 +1581,16 @@ def _validate_field_groups(cls, values): field_name = field.name meta = field_name_to_meta[field_name] - # This is a synthetic oneof; we should ignore it's presence and not consider it as a oneof. + # This is a synthetic oneof; we should ignore it's presence and not + # consider it as a oneof. if meta.optional: continue - set_fields = [ - field.name - for field in field_set - if getattr(values, field.name, None) is not None - ] + set_fields = [field.name for field in field_set if getattr(values, field.name, None) is not None] if len(set_fields) > 1: set_fields_str = ", ".join(set_fields) - raise ValueError( - f"Group {group} has more than one value; fields {set_fields_str} are not None" - ) + raise ValueError(f"Group {group} has more than one value;" f" fields {set_fields_str} are not None") return values @@ -1676,9 +1613,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]] if v is not None: if field_name: - raise RuntimeError( - f"more than one field set in oneof: {field.name} and {field_name}" - ) + raise RuntimeError(f"more than one field set in oneof: {field.name} and {field_name}") field_name, value = field.name, v return field_name, value @@ -1703,9 +1638,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]] class _Duration(Duration): @classmethod - def from_timedelta( - cls, delta: timedelta, *, 
_1_microsecond: timedelta = timedelta(microseconds=1) - ) -> "_Duration": + def from_timedelta(cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)) -> "_Duration": total_ms = delta // _1_microsecond seconds = int(total_ms / 1e6) nanos = int((total_ms % 1e6) * 1e3) @@ -1732,9 +1665,7 @@ def from_datetime(cls, dt: datetime) -> "_Timestamp": offset = dt - DATETIME_ZERO # below is the same as timedelta.total_seconds() but without dividing by 1e6 # so we end up with microseconds as integers instead of seconds as float - offset_us = ( - offset.days * 24 * 60 * 60 + offset.seconds - ) * 10**6 + offset.microseconds + offset_us = (offset.days * 24 * 60 * 60 + offset.seconds) * 10**6 + offset.microseconds seconds, us = divmod(offset_us, 10**6) return cls(seconds, us * 1000) diff --git a/src/betterproto/_types.py b/src/betterproto/_types.py index 616d550d..dd91dce9 100644 --- a/src/betterproto/_types.py +++ b/src/betterproto/_types.py @@ -3,7 +3,6 @@ TypeVar, ) - if TYPE_CHECKING: from grpclib._typing import IProtoMessage diff --git a/src/betterproto/_version.py b/src/betterproto/_version.py index 6b9a1ae6..353f700a 100644 --- a/src/betterproto/_version.py +++ b/src/betterproto/_version.py @@ -1,4 +1,3 @@ from importlib import metadata - __version__ = metadata.version("betterproto") diff --git a/src/betterproto/casing.py b/src/betterproto/casing.py index f7d0832b..741adf76 100644 --- a/src/betterproto/casing.py +++ b/src/betterproto/casing.py @@ -1,7 +1,6 @@ import keyword import re - # Word delimiters and symbols that will not be preserved when re-casing. # language=PythonRegExp SYMBOLS = "[^a-zA-Z0-9]*" @@ -47,9 +46,7 @@ def substitute_word(symbols: str, word: str, is_start: bool) -> str: elif is_start: delimiter_count = len(symbols) elif word.isupper() or word.islower(): - delimiter_count = max( - 1, len(symbols) - ) # Preserve all delimiters if not strict. + delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict. 
else: delimiter_count = len(symbols) + 1 # Extra underscore for leading capital. diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 90ea035b..82bd0a21 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import re from typing import ( TYPE_CHECKING, Dict, @@ -15,7 +14,6 @@ from ..lib.google import protobuf as google_protobuf from .naming import pythonize_class_name - if TYPE_CHECKING: from ..plugin.models import PluginRequestCompiler from ..plugin.typing_compiler import TypingCompiler @@ -33,16 +31,14 @@ } -def parse_source_type_name( - field_type_name: str, request: "PluginRequestCompiler" -) -> Tuple[str, str]: +def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]: """ Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') - The function goes through the symbols that have been defined (names, enums, packages) to find the actual package and - name of the object that is referenced. + The function goes through the symbols that have been defined (names, enums, + packages) to find the actual package and name of the object that is referenced. 
""" if field_type_name[0] != ".": raise RuntimeError("relative names are not supported") @@ -58,12 +54,16 @@ def parse_source_type_name( for i in range(len(parts)): package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:]) - if package := request.output_packages.get(package_name): - if object_name in package.messages or object_name in package.enums: - if answer: - # This should have already been handeled by protoc - raise ValueError(f"ambiguous definition: {field_type_name}") - answer = package_name, object_name + package = request.output_packages.get(package_name) + + if not package: + continue + + if object_name in package.messages or object_name in package.enums: + if answer: + # This should have already been handeled by protoc + raise ValueError(f"ambiguous definition: {field_type_name}") + answer = package_name, object_name if answer: return answer @@ -105,9 +105,7 @@ def get_type_reference( compiling_google_protobuf = current_package == ["google", "protobuf"] importing_google_protobuf = py_package == ["google", "protobuf"] if importing_google_protobuf and not compiling_google_protobuf: - py_package = ( - ["betterproto", "lib"] + (["pydantic"] if pydantic else []) + py_package - ) + py_package = ["betterproto", "lib"] + (["pydantic"] if pydantic else []) + py_package if py_package[:1] == ["betterproto"]: return reference_absolute(imports, py_package, py_type) @@ -141,9 +139,7 @@ def reference_sibling(py_type: str) -> str: return f"{py_type}" -def reference_descendent( - current_package: List[str], imports: Set[str], py_package: List[str], py_type: str -) -> str: +def reference_descendent(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: """ Returns a reference to a python type in a package that is a descendent of the current package, and adds the required import that is aliased to avoid name @@ -161,9 +157,7 @@ def reference_descendent( return f"{string_import}.{py_type}" -def reference_ancestor( - 
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str -) -> str: +def reference_ancestor(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: """ Returns a reference to a python type in a package which is an ancestor to the current package, and adds the required import that is aliased (if possible) to avoid @@ -184,24 +178,16 @@ def reference_ancestor( return string_alias -def reference_cousin( - current_package: List[str], imports: Set[str], py_package: List[str], py_type: str -) -> str: +def reference_cousin(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str: """ Returns a reference to a python type in a package that is not descendent, ancestor or sibling, and adds the required import that is aliased to avoid name conflicts. """ shared_ancestry = os.path.commonprefix([current_package, py_package]) # type: ignore distance_up = len(current_package) - len(shared_ancestry) - string_from = f".{'.' * distance_up}" + ".".join( - py_package[len(shared_ancestry) : -1] - ) + string_from = f".{'.' 
* distance_up}" + ".".join(py_package[len(shared_ancestry) : -1]) string_import = py_package[-1] # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34) - string_alias = ( - f"{'_' * distance_up}" - + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) - + "__" - ) + string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__" imports.add(f"from {string_from} import {string_import} as {string_alias}") return f"{string_alias}.{py_type}" diff --git a/src/betterproto/enum.py b/src/betterproto/enum.py index f5f14dcc..9a1d677b 100644 --- a/src/betterproto/enum.py +++ b/src/betterproto/enum.py @@ -13,7 +13,6 @@ Tuple, ) - if TYPE_CHECKING: from collections.abc import ( Generator, @@ -27,36 +26,27 @@ def _is_descriptor(obj: object) -> bool: - return ( - hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") - ) + return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class EnumType(EnumMeta if TYPE_CHECKING else type): _value_map_: Mapping[int, Enum] _member_map_: Mapping[str, Enum] - def __new__( - mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any] - ) -> Self: + def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self: value_map = {} member_map = {} new_mcs = type( f"{name}Type", tuple( - dict.fromkeys( - [base.__class__ for base in bases if base.__class__ is not type] - + [EnumType, type] - ) + dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type]) ), # reorder the bases so EnumType and type are last to avoid conflicts {"_value_map_": value_map, "_member_map_": member_map}, ) members = { - name: value - for name, value in namespace.items() - if not _is_descriptor(value) and not name.startswith("__") + name: value for name, value in namespace.items() if not _is_descriptor(value) and not name.startswith("__") } cls = type.__new__( @@ -139,14 
+129,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}.{self.name}" def __setattr__(self, key: str, value: Any) -> Never: - raise AttributeError( - f"{self.__class__.__name__} Cannot reassign a member's attributes." - ) + raise AttributeError(f"{self.__class__.__name__} Cannot reassign a member's attributes.") def __delattr__(self, item: Any) -> Never: - raise AttributeError( - f"{self.__class__.__name__} Cannot delete a member's attributes." - ) + raise AttributeError(f"{self.__class__.__name__} Cannot delete a member's attributes.") def __copy__(self) -> Self: return self diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index b19e8061..ab24cedb 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -15,15 +15,12 @@ import grpclib.const - if TYPE_CHECKING: from grpclib.client import Channel from grpclib.metadata import Deadline from .._types import ( - ST, IProtoMessage, - Message, T, ) @@ -156,9 +153,7 @@ async def _stream_stream( **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: await stream.send_request() - sending_task = asyncio.ensure_future( - self._send_messages(stream, request_iterator) - ) + sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator)) try: async for response in stream: yield response diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index 3e280311..61d4710b 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -3,7 +3,6 @@ from typing import ( Any, Callable, - Dict, ) import grpclib diff --git a/src/betterproto/grpc/util/async_channel.py b/src/betterproto/grpc/util/async_channel.py index 9f18dbfd..9a9345e0 100644 --- a/src/betterproto/grpc/util/async_channel.py +++ b/src/betterproto/grpc/util/async_channel.py @@ -8,7 +8,6 @@ Union, ) - T = TypeVar("T") @@ -118,9 +117,7 @@ def done(self) -> bool: # 
receiver per enqueued item. return self._closed and self._queue.qsize() <= self._waiting_receivers - async def send_from( - self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ) -> "AsyncChannel[T]": + async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a diff --git a/src/betterproto/lib/pydantic/google/protobuf/__init__.py b/src/betterproto/lib/pydantic/google/protobuf/__init__.py index 1d60f5b2..8a91dd2b 100644 --- a/src/betterproto/lib/pydantic/google/protobuf/__init__.py +++ b/src/betterproto/lib/pydantic/google/protobuf/__init__.py @@ -1,5 +1,4 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto # plugin: python-betterproto # This file has been @generated import warnings @@ -12,7 +11,6 @@ from betterproto import hybridmethod - if TYPE_CHECKING: from dataclasses import dataclass else: @@ -663,9 +661,7 @@ class Type(betterproto.Message): source_context: "SourceContext" = betterproto.message_field(5) """The source context.""" - syntax: "Syntax" = betterproto.enum_field( - 6, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(6, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax.""" edition: str = betterproto.string_field(7) @@ -678,9 +674,7 @@ class Type(betterproto.Message): class Field(betterproto.Message): """A single field of a message type.""" - kind: "FieldKind" = betterproto.enum_field( - 1, 
enum_default_value=lambda: FieldKind.try_value(0) - ) + kind: "FieldKind" = betterproto.enum_field(1, enum_default_value=lambda: FieldKind.try_value(0)) """The field type.""" cardinality: "FieldCardinality" = betterproto.enum_field( @@ -728,9 +722,7 @@ class Enum(betterproto.Message): name: str = betterproto.string_field(1) """Enum type name.""" - enumvalue: List["EnumValue"] = betterproto.message_field( - 2, wraps=betterproto.TYPE_ENUM - ) + enumvalue: List["EnumValue"] = betterproto.message_field(2, wraps=betterproto.TYPE_ENUM) """Enum value definitions.""" options: List["Option"] = betterproto.message_field(3) @@ -739,9 +731,7 @@ class Enum(betterproto.Message): source_context: "SourceContext" = betterproto.message_field(4) """The source context.""" - syntax: "Syntax" = betterproto.enum_field( - 5, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(5, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax.""" edition: str = betterproto.string_field(6) @@ -846,9 +836,7 @@ class Api(betterproto.Message): mixins: List["Mixin"] = betterproto.message_field(6) """Included interfaces. See [Mixin][].""" - syntax: "Syntax" = betterproto.enum_field( - 7, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(7, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax of the service.""" @@ -874,9 +862,7 @@ class Method(betterproto.Message): options: List["Option"] = betterproto.message_field(6) """Any metadata attached to the method.""" - syntax: "Syntax" = betterproto.enum_field( - 7, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(7, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax of this method.""" @@ -1024,9 +1010,7 @@ class FileDescriptorProto(betterproto.Message): If `edition` is present, this value must be "editions". 
""" - edition: "Edition" = betterproto.enum_field( - 14, enum_default_value=lambda: Edition.try_value(0) - ) + edition: "Edition" = betterproto.enum_field(14, enum_default_value=lambda: Edition.try_value(0)) """The edition of the proto file.""" @@ -1039,9 +1023,7 @@ class DescriptorProto(betterproto.Message): extension: List["FieldDescriptorProto"] = betterproto.message_field(6) nested_type: List["DescriptorProto"] = betterproto.message_field(3) enum_type: List["EnumDescriptorProto"] = betterproto.message_field(4) - extension_range: List["DescriptorProtoExtensionRange"] = betterproto.message_field( - 5 - ) + extension_range: List["DescriptorProtoExtensionRange"] = betterproto.message_field(5) oneof_decl: List["OneofDescriptorProto"] = betterproto.message_field(8) options: "MessageOptions" = betterproto.message_field(7) reserved_range: List["DescriptorProtoReservedRange"] = betterproto.message_field(9) @@ -1225,9 +1207,7 @@ class EnumDescriptorProto(betterproto.Message): name: str = betterproto.string_field(1) value: List["EnumValueDescriptorProto"] = betterproto.message_field(2) options: "EnumOptions" = betterproto.message_field(3) - reserved_range: List["EnumDescriptorProtoEnumReservedRange"] = ( - betterproto.message_field(4) - ) + reserved_range: List["EnumDescriptorProtoEnumReservedRange"] = betterproto.message_field(4) """ Range of reserved numeric values. Reserved numeric values may not be used by enum values in the same enum declaration. Reserved ranges may not @@ -1535,9 +1515,7 @@ def __post_init__(self) -> None: @dataclass(eq=False, repr=False) class FieldOptions(betterproto.Message): - ctype: "FieldOptionsCType" = betterproto.enum_field( - 1, enum_default_value=lambda: FieldOptionsCType.try_value(0) - ) + ctype: "FieldOptionsCType" = betterproto.enum_field(1, enum_default_value=lambda: FieldOptionsCType.try_value(0)) """ The ctype option instructs the C++ code generator to use a different representation of the field than it normally would. 
See the specific @@ -1558,9 +1536,7 @@ class FieldOptions(betterproto.Message): the behavior. """ - jstype: "FieldOptionsJsType" = betterproto.enum_field( - 6, enum_default_value=lambda: FieldOptionsJsType.try_value(0) - ) + jstype: "FieldOptionsJsType" = betterproto.enum_field(6, enum_default_value=lambda: FieldOptionsJsType.try_value(0)) """ The jstype option determines the JavaScript type used for values of the field. The option is permitted only for 64 bit integral and fixed types @@ -1641,9 +1617,7 @@ class FieldOptions(betterproto.Message): @dataclass(eq=False, repr=False) class FieldOptionsEditionDefault(betterproto.Message): - edition: "Edition" = betterproto.enum_field( - 3, enum_default_value=lambda: Edition.try_value(0) - ) + edition: "Edition" = betterproto.enum_field(3, enum_default_value=lambda: Edition.try_value(0)) value: str = betterproto.string_field(2) @@ -1837,20 +1811,14 @@ class FeatureSetDefaults(betterproto.Message): for the closest matching edition, followed by proto merges. """ - defaults: List["FeatureSetDefaultsFeatureSetEditionDefault"] = ( - betterproto.message_field(1) - ) - minimum_edition: "Edition" = betterproto.enum_field( - 4, enum_default_value=lambda: Edition.try_value(0) - ) + defaults: List["FeatureSetDefaultsFeatureSetEditionDefault"] = betterproto.message_field(1) + minimum_edition: "Edition" = betterproto.enum_field(4, enum_default_value=lambda: Edition.try_value(0)) """ The minimum supported edition (inclusive) when this was constructed. Editions before this will not have defaults. """ - maximum_edition: "Edition" = betterproto.enum_field( - 5, enum_default_value=lambda: Edition.try_value(0) - ) + maximum_edition: "Edition" = betterproto.enum_field(5, enum_default_value=lambda: Edition.try_value(0)) """ The maximum known edition (inclusive) when this was constructed. Editions after this will not have reliable defaults. @@ -1866,9 +1834,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message): be used. 
This field must be in strict ascending order by edition. """ - edition: "Edition" = betterproto.enum_field( - 3, enum_default_value=lambda: Edition.try_value(0) - ) + edition: "Edition" = betterproto.enum_field(3, enum_default_value=lambda: Edition.try_value(0)) features: "FeatureSet" = betterproto.message_field(2) @@ -2381,9 +2347,7 @@ class Struct(betterproto.Message): The JSON representation for `Struct` is JSON object. """ - fields: Dict[str, "Value"] = betterproto.map_field( - 1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE - ) + fields: Dict[str, "Value"] = betterproto.map_field(1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE) """Unordered map of dynamically typed values.""" @hybridmethod @@ -2432,27 +2396,19 @@ class Value(betterproto.Message): ) """Represents a null value.""" - number_value: Optional[float] = betterproto.double_field( - 2, optional=True, group="kind" - ) + number_value: Optional[float] = betterproto.double_field(2, optional=True, group="kind") """Represents a double value.""" - string_value: Optional[str] = betterproto.string_field( - 3, optional=True, group="kind" - ) + string_value: Optional[str] = betterproto.string_field(3, optional=True, group="kind") """Represents a string value.""" bool_value: Optional[bool] = betterproto.bool_field(4, optional=True, group="kind") """Represents a boolean value.""" - struct_value: Optional["Struct"] = betterproto.message_field( - 5, optional=True, group="kind" - ) + struct_value: Optional["Struct"] = betterproto.message_field(5, optional=True, group="kind") """Represents a structured value.""" - list_value: Optional["ListValue"] = betterproto.message_field( - 6, optional=True, group="kind" - ) + list_value: Optional["ListValue"] = betterproto.message_field(6, optional=True, group="kind") """Represents a repeated `Value`.""" @model_validator(mode="after") diff --git a/src/betterproto/lib/pydantic/google/protobuf/compiler/__init__.py 
b/src/betterproto/lib/pydantic/google/protobuf/compiler/__init__.py index ba16fac9..6f167546 100644 --- a/src/betterproto/lib/pydantic/google/protobuf/compiler/__init__.py +++ b/src/betterproto/lib/pydantic/google/protobuf/compiler/__init__.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING - if TYPE_CHECKING: from dataclasses import dataclass else: @@ -53,9 +52,7 @@ class CodeGeneratorRequest(betterproto.Message): parameter: str = betterproto.string_field(2) """The generator parameter passed on the command-line.""" - proto_file: List["betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"] = ( - betterproto.message_field(15) - ) + proto_file: List["betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"] = betterproto.message_field(15) """ FileDescriptorProtos for all files in files_to_generate and everything they import. The files will appear in topological order, so each file @@ -78,9 +75,9 @@ class CodeGeneratorRequest(betterproto.Message): fully qualified. """ - source_file_descriptors: List[ - "betterproto_lib_pydantic_google_protobuf.FileDescriptorProto" - ] = betterproto.message_field(17) + source_file_descriptors: List["betterproto_lib_pydantic_google_protobuf.FileDescriptorProto"] = ( + betterproto.message_field(17) + ) """ File descriptors with all options, including source-retention options. These descriptors are only provided for the files listed in @@ -195,9 +192,7 @@ class CodeGeneratorResponseFile(betterproto.Message): content: str = betterproto.string_field(15) """The file contents.""" - generated_code_info: "betterproto_lib_pydantic_google_protobuf.GeneratedCodeInfo" = betterproto.message_field( - 16 - ) + generated_code_info: "betterproto_lib_pydantic_google_protobuf.GeneratedCodeInfo" = betterproto.message_field(16) """ Information describing the file content being inserted. 
If an insertion point is used, this information will be appropriately offset and inserted diff --git a/src/betterproto/lib/std/google/protobuf/__init__.py b/src/betterproto/lib/std/google/protobuf/__init__.py index 76770f4b..4bb32789 100644 --- a/src/betterproto/lib/std/google/protobuf/__init__.py +++ b/src/betterproto/lib/std/google/protobuf/__init__.py @@ -1,5 +1,8 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto +# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, +# google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, +# google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, +# google/protobuf/type.proto, google/protobuf/wrappers.proto # plugin: python-betterproto # This file has been @generated @@ -511,9 +514,7 @@ class Type(betterproto.Message): source_context: "SourceContext" = betterproto.message_field(5) """The source context.""" - syntax: "Syntax" = betterproto.enum_field( - 6, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(6, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax.""" @@ -521,9 +522,7 @@ class Type(betterproto.Message): class Field(betterproto.Message): """A single field of a message type.""" - kind: "FieldKind" = betterproto.enum_field( - 1, enum_default_value=lambda: FieldKind.try_value(0) - ) + kind: "FieldKind" = betterproto.enum_field(1, enum_default_value=lambda: FieldKind.try_value(0)) """The field type.""" cardinality: "FieldCardinality" = betterproto.enum_field( @@ -571,9 
+570,7 @@ class Enum(betterproto.Message): name: str = betterproto.string_field(1) """Enum type name.""" - enumvalue: List["EnumValue"] = betterproto.message_field( - 2, wraps=betterproto.TYPE_ENUM, repeated=True - ) + enumvalue: List["EnumValue"] = betterproto.message_field(2, wraps=betterproto.TYPE_ENUM, repeated=True) """Enum value definitions.""" options: List["Option"] = betterproto.message_field(3, repeated=True) @@ -582,9 +579,7 @@ class Enum(betterproto.Message): source_context: "SourceContext" = betterproto.message_field(4) """The source context.""" - syntax: "Syntax" = betterproto.enum_field( - 5, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(5, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax.""" @@ -684,9 +679,7 @@ class Api(betterproto.Message): mixins: List["Mixin"] = betterproto.message_field(6, repeated=True) """Included interfaces. See [Mixin][].""" - syntax: "Syntax" = betterproto.enum_field( - 7, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(7, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax of the service.""" @@ -712,9 +705,7 @@ class Method(betterproto.Message): options: List["Option"] = betterproto.message_field(6, repeated=True) """Any metadata attached to the method.""" - syntax: "Syntax" = betterproto.enum_field( - 7, enum_default_value=lambda: Syntax.try_value(0) - ) + syntax: "Syntax" = betterproto.enum_field(7, enum_default_value=lambda: Syntax.try_value(0)) """The source syntax of this method.""" @@ -854,16 +845,12 @@ class FileDescriptorProto(betterproto.Message): """ - service: List["ServiceDescriptorProto"] = betterproto.message_field( - 6, repeated=True - ) + service: List["ServiceDescriptorProto"] = betterproto.message_field(6, repeated=True) """ """ - extension: List["FieldDescriptorProto"] = betterproto.message_field( - 7, repeated=True - ) + extension: List["FieldDescriptorProto"] = 
betterproto.message_field(7, repeated=True) """ """ @@ -902,9 +889,7 @@ class DescriptorProto(betterproto.Message): """ - extension: List["FieldDescriptorProto"] = betterproto.message_field( - 6, repeated=True - ) + extension: List["FieldDescriptorProto"] = betterproto.message_field(6, repeated=True) """ """ @@ -919,16 +904,12 @@ class DescriptorProto(betterproto.Message): """ - extension_range: List["DescriptorProtoExtensionRange"] = betterproto.message_field( - 5, repeated=True - ) + extension_range: List["DescriptorProtoExtensionRange"] = betterproto.message_field(5, repeated=True) """ """ - oneof_decl: List["OneofDescriptorProto"] = betterproto.message_field( - 8, repeated=True - ) + oneof_decl: List["OneofDescriptorProto"] = betterproto.message_field(8, repeated=True) """ """ @@ -938,9 +919,7 @@ class DescriptorProto(betterproto.Message): """ - reserved_range: List["DescriptorProtoReservedRange"] = betterproto.message_field( - 9, repeated=True - ) + reserved_range: List["DescriptorProtoReservedRange"] = betterproto.message_field(9, repeated=True) """ """ @@ -987,9 +966,7 @@ class DescriptorProtoReservedRange(betterproto.Message): class ExtensionRangeOptions(betterproto.Message): """ """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. 
See above.""" @@ -1115,9 +1092,7 @@ class EnumDescriptorProto(betterproto.Message): """ - value: List["EnumValueDescriptorProto"] = betterproto.message_field( - 2, repeated=True - ) + value: List["EnumValueDescriptorProto"] = betterproto.message_field(2, repeated=True) """ """ @@ -1127,9 +1102,7 @@ class EnumDescriptorProto(betterproto.Message): """ - reserved_range: List["EnumDescriptorProtoEnumReservedRange"] = ( - betterproto.message_field(4, repeated=True) - ) + reserved_range: List["EnumDescriptorProtoEnumReservedRange"] = betterproto.message_field(4, repeated=True) """ Range of reserved numeric values. Reserved numeric values may not be used by enum values in the same enum declaration. Reserved ranges may not @@ -1412,9 +1385,7 @@ class FileOptions(betterproto.Message): determining the ruby package. """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """ The parser stores options it doesn't recognize here. See the documentation for the "Options" section above. @@ -1495,9 +1466,7 @@ class MessageOptions(betterproto.Message): parser. """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1505,9 +1474,7 @@ class MessageOptions(betterproto.Message): class FieldOptions(betterproto.Message): """ """ - ctype: "FieldOptionsCType" = betterproto.enum_field( - 1, enum_default_value=lambda: FieldOptionsCType.try_value(0) - ) + ctype: "FieldOptionsCType" = betterproto.enum_field(1, enum_default_value=lambda: FieldOptionsCType.try_value(0)) """ The ctype option instructs the C++ code generator to use a different representation of the field than it normally would. 
See the specific @@ -1524,9 +1491,7 @@ class FieldOptions(betterproto.Message): false will avoid using packed encoding. """ - jstype: "FieldOptionsJsType" = betterproto.enum_field( - 6, enum_default_value=lambda: FieldOptionsJsType.try_value(0) - ) + jstype: "FieldOptionsJsType" = betterproto.enum_field(6, enum_default_value=lambda: FieldOptionsJsType.try_value(0)) """ The jstype option determines the JavaScript type used for values of the field. The option is permitted only for 64 bit integral and fixed types @@ -1583,9 +1548,7 @@ class FieldOptions(betterproto.Message): weak: bool = betterproto.bool_field(10) """For Google-internal migration only. Do not use.""" - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1593,9 +1556,7 @@ class FieldOptions(betterproto.Message): class OneofOptions(betterproto.Message): """ """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1617,9 +1578,7 @@ class EnumOptions(betterproto.Message): is a formalization for deprecating enums. """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1635,9 +1594,7 @@ class EnumValueOptions(betterproto.Message): this is a formalization for deprecating enum values. 
""" - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1658,9 +1615,7 @@ class ServiceOptions(betterproto.Message): this is a formalization for deprecating services. """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1688,9 +1643,7 @@ class MethodOptions(betterproto.Message): """ - uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field( - 999, repeated=True - ) + uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999, repeated=True) """The parser stores options it doesn't recognize here. See above.""" @@ -1705,9 +1658,7 @@ class UninterpretedOption(betterproto.Message): in them. """ - name: List["UninterpretedOptionNamePart"] = betterproto.message_field( - 2, repeated=True - ) + name: List["UninterpretedOptionNamePart"] = betterproto.message_field(2, repeated=True) """ """ @@ -1775,9 +1726,7 @@ class SourceCodeInfo(betterproto.Message): FileDescriptorProto was generated. """ - location: List["SourceCodeInfoLocation"] = betterproto.message_field( - 1, repeated=True - ) + location: List["SourceCodeInfoLocation"] = betterproto.message_field(1, repeated=True) """ A Location identifies a piece of source code in a .proto file which corresponds to a particular definition. This information is intended @@ -1935,9 +1884,7 @@ class GeneratedCodeInfo(betterproto.Message): source file, but may contain references to different source .proto files. 
""" - annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field( - 1, repeated=True - ) + annotation: List["GeneratedCodeInfoAnnotation"] = betterproto.message_field(1, repeated=True) """ An Annotation connects some span of text in generated code to an element of its generating .proto file. @@ -2289,9 +2236,7 @@ class Struct(betterproto.Message): The JSON representation for `Struct` is JSON object. """ - fields: Dict[str, "Value"] = betterproto.map_field( - 1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE - ) + fields: Dict[str, "Value"] = betterproto.map_field(1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE) """Unordered map of dynamically typed values.""" @hybridmethod @@ -2332,9 +2277,7 @@ class Value(betterproto.Message): The JSON representation for `Value` is JSON value. """ - null_value: "NullValue" = betterproto.enum_field( - 1, enum_default_value=lambda: NullValue.try_value(0), group="kind" - ) + null_value: "NullValue" = betterproto.enum_field(1, enum_default_value=lambda: NullValue.try_value(0), group="kind") """Represents a null value.""" number_value: float = betterproto.double_field(2, group="kind") diff --git a/src/betterproto/lib/std/google/protobuf/compiler/__init__.py b/src/betterproto/lib/std/google/protobuf/compiler/__init__.py index dc5334a3..3455a19a 100644 --- a/src/betterproto/lib/std/google/protobuf/compiler/__init__.py +++ b/src/betterproto/lib/std/google/protobuf/compiler/__init__.py @@ -46,8 +46,8 @@ class CodeGeneratorRequest(betterproto.Message): parameter: str = betterproto.string_field(2) """The generator parameter passed on the command-line.""" - proto_file: List["betterproto_lib_google_protobuf.FileDescriptorProto"] = ( - betterproto.message_field(15, repeated=True) + proto_file: List["betterproto_lib_google_protobuf.FileDescriptorProto"] = betterproto.message_field( + 15, repeated=True ) """ FileDescriptorProtos for all files in files_to_generate and everything @@ -71,9 +71,9 @@ class 
CodeGeneratorRequest(betterproto.Message): fully qualified. """ - source_file_descriptors: List[ - "betterproto_lib_google_protobuf.FileDescriptorProto" - ] = betterproto.message_field(17, repeated=True) + source_file_descriptors: List["betterproto_lib_google_protobuf.FileDescriptorProto"] = betterproto.message_field( + 17, repeated=True + ) """ File descriptors with all options, including source-retention options. These descriptors are only provided for the files listed in @@ -122,9 +122,7 @@ class CodeGeneratorResponse(betterproto.Message): effect for plugins that have FEATURE_SUPPORTS_EDITIONS set. """ - file: List["CodeGeneratorResponseFile"] = betterproto.message_field( - 15, repeated=True - ) + file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15, repeated=True) @dataclass(eq=False, repr=False) @@ -190,9 +188,7 @@ class CodeGeneratorResponseFile(betterproto.Message): content: str = betterproto.string_field(15) """The file contents.""" - generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = ( - betterproto.message_field(16) - ) + generated_code_info: "betterproto_lib_google_protobuf.GeneratedCodeInfo" = betterproto.message_field(16) """ Information describing the file content being inserted. 
If an insertion point is used, this information will be appropriately offset and inserted diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py index c28a133f..8c86109c 100644 --- a/src/betterproto/plugin/__init__.py +++ b/src/betterproto/plugin/__init__.py @@ -1 +1,3 @@ +__all__ = ["main"] + from .main import main diff --git a/src/betterproto/plugin/__main__.py b/src/betterproto/plugin/__main__.py index bd95daea..5d6a8109 100644 --- a/src/betterproto/plugin/__main__.py +++ b/src/betterproto/plugin/__main__.py @@ -1,4 +1,3 @@ from .main import main - main() diff --git a/src/betterproto/plugin/compiler.py b/src/betterproto/plugin/compiler.py index e91a1c45..e8d261c1 100644 --- a/src/betterproto/plugin/compiler.py +++ b/src/betterproto/plugin/compiler.py @@ -4,7 +4,6 @@ from .module_validation import ModuleValidator - try: # betterproto[compiler] specific dependencies import jinja2 @@ -23,9 +22,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: - templates_folder = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "templates") - ) + templates_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "templates")) env = jinja2.Environment( trim_blocks=True, @@ -48,9 +45,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: ) # Format the code - code = subprocess.check_output( - ["ruff", "format", "-"], input=code, encoding="utf-8" - ) + code = subprocess.check_output(["ruff", "format", "-"], input=code, encoding="utf-8") # Validate the generated code. 
validator = ModuleValidator(iter(code.splitlines())) diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py index 29079204..c9dcc617 100755 --- a/src/betterproto/plugin/main.py +++ b/src/betterproto/plugin/main.py @@ -5,7 +5,6 @@ from betterproto.lib.google.protobuf.compiler import ( CodeGeneratorRequest, - CodeGeneratorResponse, ) # from betterproto.plugin.models import monkey_patch_oneof_index diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 5e824c58..1dfbc789 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -55,7 +55,6 @@ from betterproto.lib.google.protobuf import ( DescriptorProto, EnumDescriptorProto, - Field, FieldDescriptorProto, FieldDescriptorProtoLabel, FieldDescriptorProtoType, @@ -76,7 +75,6 @@ TypingCompiler, ) - # Create a unique placeholder to deal with # https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses PLACEHOLDER = object() @@ -164,9 +162,7 @@ def get_comment( lines.append("") # Remove consecutive empty lines - lines = [ - line for i, line in enumerate(lines) if line or (i == 0 or lines[i - 1]) - ] + lines = [line for i, line in enumerate(lines) if line or (i == 0 or lines[i - 1])] if lines and not lines[-1]: lines.pop() # Remove the last empty line @@ -239,11 +235,7 @@ def all_messages(self) -> List["MessageCompiler"]: List[MessageCompiler] List of all of the messages in this request. 
""" - return [ - msg - for output in self.output_packages.values() - for msg in output.messages.values() - ] + return [msg for output in self.output_packages.values() for msg in output.messages.values()] @dataclass @@ -298,9 +290,7 @@ class MessageCompiler(ProtoContentBase): parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER - fields: List[Union["FieldCompiler", "MessageCompiler"]] = field( - default_factory=list - ) + fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list) builtins_types: Set[str] = field(default_factory=set) def __post_init__(self) -> None: @@ -345,9 +335,7 @@ def has_message_field(self) -> bool: ) -def is_map( - proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto -) -> bool: +def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool: """True if proto_field_obj is a map, otherwise False.""" if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE: if not hasattr(parent_message, "nested_type"): @@ -358,10 +346,7 @@ def is_map( map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" if message_type == map_entry: for nested in parent_message.nested_type: # parent message - if ( - nested.name.replace("_", "").lower() == map_entry - and nested.options.map_entry - ): + if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry: return True return False @@ -381,9 +366,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: us to tell whether it was set, via the which_one_of interface. 
""" - return ( - not proto_field_obj.proto3_optional and proto_field_obj.oneof_index is not None - ) + return not proto_field_obj.proto3_optional and proto_field_obj.oneof_index is not None @dataclass @@ -405,12 +388,8 @@ def __post_init__(self) -> None: def get_field_string(self) -> str: """Construct string representation of this field as a field.""" name = f"{self.py_name}" - field_args = ", ".join( - ([""] + self.betterproto_field_args) if self.betterproto_field_args else [] - ) - betterproto_field_type = ( - f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})" - ) + field_args = ", ".join(([""] + self.betterproto_field_args) if self.betterproto_field_args else []) + betterproto_field_type = f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})" if self.py_name in dir(builtins): self.parent.builtins_types.add(self.py_name) return f'{name}: "{self.annotation}" = {betterproto_field_type}' @@ -438,9 +417,7 @@ def use_builtins(self) -> bool: @property def field_wraps(self) -> Optional[str]: """Returns betterproto wrapped field type or None.""" - match_wrapper = re.match( - r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name - ) + match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name) if match_wrapper: wrapped_type = "TYPE_" + match_wrapper.group(1).upper() if hasattr(betterproto, wrapped_type): @@ -449,25 +426,18 @@ def field_wraps(self) -> Optional[str]: @property def repeated(self) -> bool: - return ( - self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED - and not is_map(self.proto_obj, self.parent) + return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map( + self.proto_obj, self.parent ) @property def optional(self) -> bool: - return self.proto_obj.proto3_optional or ( - self.field_type == "message" and not self.repeated - ) + return self.proto_obj.proto3_optional or (self.field_type == "message" and not self.repeated) @property def 
field_type(self) -> str: """String representation of proto field type.""" - return ( - FieldDescriptorProtoType(self.proto_obj.type) - .name.lower() - .replace("type_", "") - ) + return FieldDescriptorProtoType(self.proto_obj.type).name.lower().replace("type_", "") @property def packed(self) -> bool: @@ -546,10 +516,7 @@ class MapEntryCompiler(FieldCompiler): def __post_init__(self): map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" for nested in self.parent.proto_obj.nested_type: - if ( - nested.name.replace("_", "").lower() == map_entry - and nested.options.map_entry - ): + if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry: pass return super().__post_init__() @@ -557,10 +524,7 @@ def ready(self) -> None: """Explore nested types and set k_type and v_type if unset.""" map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" for nested in self.parent.proto_obj.nested_type: - if ( - nested.name.replace("_", "").lower() == map_entry - and nested.options.map_entry - ): + if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry: # Get Python types self.py_k_type = FieldCompiler( source_file=self.source_file, @@ -618,13 +582,9 @@ def __post_init__(self) -> None: # Get entries/allowed values for this Enum self.entries = [ self.EnumEntry( - name=pythonize_enum_member_name( - entry_proto_value.name, self.proto_obj.name - ), + name=pythonize_enum_member_name(entry_proto_value.name, self.proto_obj.name), value=entry_proto_value.number, - comment=get_comment( - proto_file=self.source_file, path=self.path + [2, entry_number] - ), + comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]), ) for entry_number, entry_proto_value in enumerate(self.proto_obj.value) ] @@ -678,9 +638,7 @@ def proto_name(self) -> str: @property def route(self) -> str: - package_part = ( - f"{self.output_file.package}." 
if self.output_file.package else "" - ) + package_part = f"{self.output_file.package}." if self.output_file.package else "" return f"/{package_part}{self.parent.proto_name}/{self.proto_name}" @property diff --git a/src/betterproto/plugin/module_validation.py b/src/betterproto/plugin/module_validation.py index 4cf05fdc..19a06a96 100644 --- a/src/betterproto/plugin/module_validation.py +++ b/src/betterproto/plugin/module_validation.py @@ -17,9 +17,7 @@ class ModuleValidator: line_iterator: Iterator[str] line_number: int = field(init=False, default=0) - collisions: Dict[str, List[Tuple[int, str]]] = field( - init=False, default_factory=lambda: defaultdict(list) - ) + collisions: Dict[str, List[Tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list)) def add_import(self, imp: str, number: int, full_line: str): """ diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 19f17521..6c1f7ba1 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -45,29 +45,19 @@ def traverse( proto_file: FileDescriptorProto, -) -> Generator[ - Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None -]: +) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]: # Todo: Keep information about nested hierarchy def _traverse( path: List[int], items: Union[List[EnumDescriptorProto], List[DescriptorProto]], prefix: str = "", - ) -> Generator[ - Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None - ]: + ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. 
# Todo: don't change the name, but include full name in returned tuple - should_rename = ( - not isinstance(item, DescriptorProto) - or not item.options - or not item.options.map_entry - ) + should_rename = not isinstance(item, DescriptorProto) or not item.options or not item.options.map_entry - item.name = next_prefix = ( - f"{prefix}.{item.name}" if prefix and should_rename else item.name - ) + item.name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name yield item, [*path, i] if isinstance(item, DescriptorProto): @@ -97,40 +87,27 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # Add this input file to the output corresponding to this package request_data.output_packages[output_package_name].input_files.append(proto_file) - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): + if proto_file.package == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options: # If not INCLUDE_GOOGLE, # skip outputting Google's well-known types request_data.output_packages[output_package_name].output = False if "pydantic_dataclasses" in plugin_options: - request_data.output_packages[ - output_package_name - ].pydantic_dataclasses = True + request_data.output_packages[output_package_name].pydantic_dataclasses = True # Gather any typing generation options. - typing_opts = [ - opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.") - ] + typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")] if len(typing_opts) > 1: raise ValueError("Multiple typing options provided") # Set the compiler type. 
typing_opt = typing_opts[0] if typing_opts else "direct" if typing_opt == "direct": - request_data.output_packages[ - output_package_name - ].typing_compiler = DirectImportTypingCompiler() + request_data.output_packages[output_package_name].typing_compiler = DirectImportTypingCompiler() elif typing_opt == "root": - request_data.output_packages[ - output_package_name - ].typing_compiler = TypingImportTypingCompiler() + request_data.output_packages[output_package_name].typing_compiler = TypingImportTypingCompiler() elif typing_opt == "310": - request_data.output_packages[ - output_package_name - ].typing_compiler = NoTyping310TypingCompiler() + request_data.output_packages[output_package_name].typing_compiler = NoTyping310TypingCompiler() # Read Messages and Enums # We need to read Messages before Services in so that we can @@ -245,9 +222,7 @@ def read_protobuf_type( typing_compiler=output_package.typing_compiler, ) elif is_oneof(field): - _make_one_of_field_compiler( - output_package, source_file, message_data, field, path + [2, index] - ) + _make_one_of_field_compiler(output_package, source_file, message_data, field, path + [2, index]) else: FieldCompiler( source_file=source_file, diff --git a/src/betterproto/plugin/typing_compiler.py b/src/betterproto/plugin/typing_compiler.py index c77f38bb..aa4f2135 100644 --- a/src/betterproto/plugin/typing_compiler.py +++ b/src/betterproto/plugin/typing_compiler.py @@ -14,32 +14,32 @@ class TypingCompiler(metaclass=abc.ABCMeta): @abc.abstractmethod - def optional(self, type: str) -> str: - raise NotImplementedError() + def optional(self, type_: str) -> str: + raise NotImplementedError @abc.abstractmethod - def list(self, type: str) -> str: - raise NotImplementedError() + def list(self, type_: str) -> str: + raise NotImplementedError @abc.abstractmethod def dict(self, key: str, value: str) -> str: - raise NotImplementedError() + raise NotImplementedError @abc.abstractmethod def union(self, *types: str) -> str: - raise 
NotImplementedError() + raise NotImplementedError @abc.abstractmethod - def iterable(self, type: str) -> str: - raise NotImplementedError() + def iterable(self, type_: str) -> str: + raise NotImplementedError @abc.abstractmethod - def async_iterable(self, type: str) -> str: - raise NotImplementedError() + def async_iterable(self, type_: str) -> str: + raise NotImplementedError @abc.abstractmethod - def async_iterator(self, type: str) -> str: - raise NotImplementedError() + def async_iterator(self, type_: str) -> str: + raise NotImplementedError @abc.abstractmethod def imports(self) -> Dict[str, Optional[Set[str]]]: @@ -47,7 +47,7 @@ def imports(self) -> Dict[str, Optional[Set[str]]]: Returns either the direct import as a key with none as value, or a set of values to import from the key. """ - raise NotImplementedError() + raise NotImplementedError def import_lines(self) -> Iterator: imports = self.imports() @@ -65,13 +65,13 @@ def import_lines(self) -> Iterator: class DirectImportTypingCompiler(TypingCompiler): _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) - def optional(self, type: str) -> str: + def optional(self, type_: str) -> str: self._imports["typing"].add("Optional") - return f"Optional[{type}]" + return f"Optional[{type_}]" - def list(self, type: str) -> str: + def list(self, type_: str) -> str: self._imports["typing"].add("List") - return f"List[{type}]" + return f"List[{type_}]" def dict(self, key: str, value: str) -> str: self._imports["typing"].add("Dict") @@ -81,17 +81,17 @@ def union(self, *types: str) -> str: self._imports["typing"].add("Union") return f"Union[{', '.join(types)}]" - def iterable(self, type: str) -> str: + def iterable(self, type_: str) -> str: self._imports["typing"].add("Iterable") - return f"Iterable[{type}]" + return f"Iterable[{type_}]" - def async_iterable(self, type: str) -> str: + def async_iterable(self, type_: str) -> str: self._imports["typing"].add("AsyncIterable") - return 
f"AsyncIterable[{type}]" + return f"AsyncIterable[{type_}]" - def async_iterator(self, type: str) -> str: + def async_iterator(self, type_: str) -> str: self._imports["typing"].add("AsyncIterator") - return f"AsyncIterator[{type}]" + return f"AsyncIterator[{type_}]" def imports(self) -> Dict[str, Optional[Set[str]]]: return {k: v if v else None for k, v in self._imports.items()} @@ -101,13 +101,13 @@ def imports(self) -> Dict[str, Optional[Set[str]]]: class TypingImportTypingCompiler(TypingCompiler): _imported: bool = False - def optional(self, type: str) -> str: + def optional(self, type_: str) -> str: self._imported = True - return f"typing.Optional[{type}]" + return f"typing.Optional[{type_}]" - def list(self, type: str) -> str: + def list(self, type_: str) -> str: self._imported = True - return f"typing.List[{type}]" + return f"typing.List[{type_}]" def dict(self, key: str, value: str) -> str: self._imported = True @@ -117,17 +117,17 @@ def union(self, *types: str) -> str: self._imported = True return f"typing.Union[{', '.join(types)}]" - def iterable(self, type: str) -> str: + def iterable(self, type_: str) -> str: self._imported = True - return f"typing.Iterable[{type}]" + return f"typing.Iterable[{type_}]" - def async_iterable(self, type: str) -> str: + def async_iterable(self, type_: str) -> str: self._imported = True - return f"typing.AsyncIterable[{type}]" + return f"typing.AsyncIterable[{type_}]" - def async_iterator(self, type: str) -> str: + def async_iterator(self, type_: str) -> str: self._imported = True - return f"typing.AsyncIterator[{type}]" + return f"typing.AsyncIterator[{type_}]" def imports(self) -> Dict[str, Optional[Set[str]]]: if self._imported: @@ -139,11 +139,11 @@ def imports(self) -> Dict[str, Optional[Set[str]]]: class NoTyping310TypingCompiler(TypingCompiler): _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) - def optional(self, type: str) -> str: - return f"{type} | None" + def optional(self, type_: 
str) -> str: + return f"{type_} | None" - def list(self, type: str) -> str: - return f"list[{type}]" + def list(self, type_: str) -> str: + return f"list[{type_}]" def dict(self, key: str, value: str) -> str: return f"dict[{key}, {value}]" @@ -151,17 +151,17 @@ def dict(self, key: str, value: str) -> str: def union(self, *types: str) -> str: return f"{' | '.join(types)}" - def iterable(self, type: str) -> str: + def iterable(self, type_: str) -> str: self._imports["collections.abc"].add("Iterable") - return f"Iterable[{type}]" + return f"Iterable[{type_}]" - def async_iterable(self, type: str) -> str: + def async_iterable(self, type_: str) -> str: self._imports["collections.abc"].add("AsyncIterable") - return f"AsyncIterable[{type}]" + return f"AsyncIterable[{type_}]" - def async_iterator(self, type: str) -> str: + def async_iterator(self, type_: str) -> str: self._imports["collections.abc"].add("AsyncIterator") - return f"AsyncIterator[{type}]" + return f"AsyncIterator[{type_}]" def imports(self) -> Dict[str, Optional[Set[str]]]: return {k: v if v else None for k, v in self._imports.items()} diff --git a/src/betterproto/utils.py b/src/betterproto/utils.py index b977fc71..3603b7ed 100644 --- a/src/betterproto/utils.py +++ b/src/betterproto/utils.py @@ -15,7 +15,6 @@ Self, ) - SelfT = TypeVar("SelfT") P = ParamSpec("P") HybridT = TypeVar("HybridT", covariant=True) @@ -24,9 +23,7 @@ class hybridmethod(Generic[SelfT, P, HybridT]): def __init__( self, - func: Callable[ - Concatenate[type[SelfT], P], HybridT - ], # Must be the classmethod version + func: Callable[Concatenate[type[SelfT], P], HybridT], # Must be the classmethod version ): self.cls_func = func self.__doc__ = func.__doc__ @@ -35,9 +32,7 @@ def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self self.instance_func = func return self - def __get__( - self, instance: Optional[SelfT], owner: Type[SelfT] - ) -> Callable[P, HybridT]: + def __get__(self, instance: Optional[SelfT], owner: 
Type[SelfT]) -> Callable[P, HybridT]: if instance is None or self.instance_func is None: # either bound to the class, or no instance method available return self.cls_func.__get__(owner, None) diff --git a/tests/generate.py b/tests/generate.py index 2a2b07a1..67dad859 100755 --- a/tests/generate.py +++ b/tests/generate.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import asyncio import os -import platform import shutil import sys from pathlib import Path @@ -16,7 +15,6 @@ protoc, ) - # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -44,37 +42,25 @@ async def generate(whitelist: Set[str], verbose: bool): generation_tasks = [] for test_case_name in sorted(test_case_names): test_case_input_path = inputs_path.joinpath(test_case_name).resolve() - if ( - whitelist - and str(test_case_input_path) not in path_whitelist - and test_case_name not in name_whitelist - ): + if whitelist and str(test_case_input_path) not in path_whitelist and test_case_name not in name_whitelist: continue - generation_tasks.append( - generate_test_case_output(test_case_input_path, test_case_name, verbose) - ) + generation_tasks.append(generate_test_case_output(test_case_input_path, test_case_name, verbose)) failed_test_cases = [] # Wait for all subprocs and match any failures to names to report - for test_case_name, result in zip( - sorted(test_case_names), await asyncio.gather(*generation_tasks) - ): + for test_case_name, result in zip(sorted(test_case_names), await asyncio.gather(*generation_tasks)): if result != 0: failed_test_cases.append(test_case_name) if len(failed_test_cases) > 0: - sys.stderr.write( - "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" - ) + sys.stderr.write("\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n") for failed_test_case in failed_test_cases: sys.stderr.write(f"- 
{failed_test_case}\n") sys.exit(1) -async def generate_test_case_output( - test_case_input_path: Path, test_case_name: str, verbose: bool -) -> int: +async def generate_test_case_output(test_case_input_path: Path, test_case_name: str, verbose: bool) -> int: """ Returns the max of the subprocess return values """ @@ -97,17 +83,13 @@ async def generate_test_case_output( ) = await asyncio.gather( protoc(test_case_input_path, test_case_output_path_reference, True), protoc(test_case_input_path, test_case_output_path_betterproto, False), - protoc( - test_case_input_path, test_case_output_path_betterproto_pyd, False, True - ), + protoc(test_case_input_path, test_case_output_path_betterproto_pyd, False, True), ) if ref_code == 0: print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m") else: - print( - f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m" - ) + print(f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m") print(ref_err.decode()) if verbose: @@ -124,9 +106,7 @@ async def generate_test_case_output( if plg_code == 0: print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m") else: - print( - f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m" - ) + print(f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m") print(plg_err.decode()) if verbose: @@ -141,13 +121,9 @@ async def generate_test_case_output( sys.stderr.buffer.flush() if plg_code_pyd == 0: - print( - f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m" - ) + print(f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m") else: - print( - f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m" - ) + print(f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m") print(plg_err_pyd.decode()) if verbose: @@ 
-169,7 +145,7 @@ async def generate_test_case_output( "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]", "Generate python classes for standard tests.", "", - "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", + "DIRECTORIES One or more relative or absolute directories of test-cases to generate" " classes for.", " python generate.py inputs/bool inputs/double inputs/enum", "", "NAMES One or more test-case names to generate classes for.", diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py index 96cd158e..2b6a7ad5 100644 --- a/tests/grpc/test_grpclib_client.py +++ b/tests/grpc/test_grpclib_client.py @@ -53,9 +53,7 @@ async def test_simple_service_call(): @pytest.mark.asyncio -async def test_trailer_only_error_unary_unary( - mocker, handler_trailer_only_unauthenticated -): +async def test_trailer_only_error_unary_unary(mocker, handler_trailer_only_unauthenticated): service = ThingService() mocker.patch.object( service, @@ -70,9 +68,7 @@ async def test_trailer_only_error_unary_unary( @pytest.mark.asyncio -async def test_trailer_only_error_stream_unary( - mocker, handler_trailer_only_unauthenticated -): +async def test_trailer_only_error_stream_unary(mocker, handler_trailer_only_unauthenticated): service = ThingService() mocker.patch.object( service, @@ -105,23 +101,15 @@ async def test_service_call_with_upfront_request_params(): # Setting deadline deadline = grpclib.metadata.Deadline.from_timeout(22) metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] - ) as channel: - await _test_client( - ThingServiceClient(channel, deadline=deadline, metadata=metadata) - ) + async with ChannelFor([ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]) as channel: + await _test_client(ThingServiceClient(channel, deadline=deadline, metadata=metadata)) # Setting timeout timeout = 99 deadline 
= grpclib.metadata.Deadline.from_timeout(timeout) metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] - ) as channel: - await _test_client( - ThingServiceClient(channel, timeout=timeout, metadata=metadata) - ) + async with ChannelFor([ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]) as channel: + await _test_client(ThingServiceClient(channel, timeout=timeout, metadata=metadata)) @pytest.mark.asyncio @@ -133,9 +121,7 @@ async def test_service_call_lower_level_with_overrides(): metadata = {"authorization": "12345"} kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] - ) as channel: + async with ChannelFor([ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]) as channel: client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) response = await client._unary_unary( "/service.Test/DoThing", @@ -195,9 +181,7 @@ async def test_service_call_high_level_with_overrides(mocker, overrides_gen): [ ThingService( test_hook=_assert_request_meta_received( - deadline=grpclib.metadata.Deadline.from_timeout( - overrides.get("timeout", 99) - ), + deadline=grpclib.metadata.Deadline.from_timeout(overrides.get("timeout", 99)), metadata=overrides.get("metadata", defaults.get("metadata")), ) ) @@ -227,9 +211,7 @@ async def test_async_gen_for_unary_stream_request(): async with ChannelFor([ThingService()]) as channel: client = ThingServiceClient(channel) expected_versions = [5, 4, 3, 2, 1] - async for response in client.get_thing_versions( - GetThingRequest(name=thing_name) - ): + async for response in client.get_thing_versions(GetThingRequest(name=thing_name)): assert response.name == thing_name assert response.version == expected_versions.pop() @@ -264,9 +246,7 @@ async def 
test_async_gen_for_stream_stream_request(): else: # No more things to send make sure channel is closed request_chan.close() - assert response_index == len( - expected_things - ), "Didn't receive all expected responses" + assert response_index == len(expected_things), "Didn't receive all expected responses" @pytest.mark.asyncio @@ -287,7 +267,5 @@ async def test_stream_stream_with_empty_iterable(): async with ChannelFor([ThingService()]) as channel: client = ThingServiceClient(channel) requests = [GetThingRequest(name) for name in things] - responses = [ - response async for response in client.get_different_things(requests) - ] + responses = [response async for response in client.get_different_things(requests)] assert len(responses) == 0 diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py index 9a1e5b89..9ce95b5f 100644 --- a/tests/grpc/test_stream_stream.py +++ b/tests/grpc/test_stream_stream.py @@ -40,27 +40,19 @@ def client(): @pytest.mark.asyncio -async def test_send_from_before_connect_and_close_automatically( - client, expected_responses -): +async def test_send_from_before_connect_and_close_automatically(client, expected_responses): requests = AsyncChannel() - await requests.send_from( - [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True - ) + await requests.send_from([Message(body="Hello world 1"), Message(body="Hello world 2")], close=True) responses = client.connect(requests) assert await to_list(responses) == expected_responses @pytest.mark.asyncio -async def test_send_from_after_connect_and_close_automatically( - client, expected_responses -): +async def test_send_from_after_connect_and_close_automatically(client, expected_responses): requests = AsyncChannel() responses = client.connect(requests) - await requests.send_from( - [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True - ) + await requests.send_from([Message(body="Hello world 1"), Message(body="Hello world 2")], 
close=True) assert await to_list(responses) == expected_responses @@ -69,9 +61,7 @@ async def test_send_from_after_connect_and_close_automatically( async def test_send_from_close_manually_immediately(client, expected_responses): requests = AsyncChannel() responses = client.connect(requests) - await requests.send_from( - [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False - ) + await requests.send_from([Message(body="Hello world 1"), Message(body="Hello world 2")], close=False) requests.close() assert await to_list(responses) == expected_responses diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py index 7723a29f..8693e628 100644 --- a/tests/grpc/thing_service.py +++ b/tests/grpc/thing_service.py @@ -16,45 +16,33 @@ def __init__(self, test_hook=None): # This lets us pass assertions to the servicer ;) self.test_hook = test_hook - async def do_thing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): + async def do_thing(self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"): request = await stream.recv_message() if self.test_hook is not None: self.test_hook(stream) await stream.send_message(DoThingResponse([request.name])) - async def do_many_things( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): + async def do_many_things(self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"): thing_names = [request.name async for request in stream] if self.test_hook is not None: self.test_hook(stream) await stream.send_message(DoThingResponse(thing_names)) - async def get_thing_versions( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): + async def get_thing_versions(self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"): request = await stream.recv_message() if self.test_hook is not None: self.test_hook(stream) for version_num in range(1, 6): - await stream.send_message( - 
GetThingResponse(name=request.name, version=version_num) - ) + await stream.send_message(GetThingResponse(name=request.name, version=version_num)) - async def get_different_things( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): + async def get_different_things(self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"): if self.test_hook is not None: self.test_hook(stream) # Respond to each input item immediately response_num = 0 async for request in stream: response_num += 1 - await stream.send_message( - GetThingResponse(name=request.name, version=response_num) - ) + await stream.send_message(GetThingResponse(name=request.name, version=response_num)) def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: return { diff --git a/tests/inputs/casing/test_casing.py b/tests/inputs/casing/test_casing.py index 9ca42435..feee009a 100644 --- a/tests/inputs/casing/test_casing.py +++ b/tests/inputs/casing/test_casing.py @@ -4,20 +4,14 @@ def test_message_attributes(): message = Test() - assert hasattr( - message, "snake_case_message" - ), "snake_case field name is same in python" + assert hasattr(message, "snake_case_message"), "snake_case field name is same in python" assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python" def test_message_casing(): - assert hasattr( - casing, "SnakeCaseMessage" - ), "snake_case Message name is converted to CamelCase in python" + assert hasattr(casing, "SnakeCaseMessage"), "snake_case Message name is converted to CamelCase in python" def test_enum_casing(): - assert hasattr( - casing, "MyEnum" - ), "snake_case Enum name is converted to CamelCase in python" + assert hasattr(casing, "MyEnum"), "snake_case Enum name is converted to CamelCase in python" diff --git a/tests/inputs/casing_inner_class/test_casing_inner_class.py b/tests/inputs/casing_inner_class/test_casing_inner_class.py index 
50103c92..2560b6c2 100644 --- a/tests/inputs/casing_inner_class/test_casing_inner_class.py +++ b/tests/inputs/casing_inner_class/test_casing_inner_class.py @@ -2,13 +2,9 @@ def test_message_casing_inner_class_name(): - assert hasattr( - casing_inner_class, "TestInnerClass" - ), "Inline defined Message is correctly converted to CamelCase" + assert hasattr(casing_inner_class, "TestInnerClass"), "Inline defined Message is correctly converted to CamelCase" def test_message_casing_inner_class_attributes(): message = casing_inner_class.Test(inner=casing_inner_class.TestInnerClass()) - assert hasattr( - message.inner, "old_exp" - ), "Inline defined Message attribute is snake_case" + assert hasattr(message.inner, "old_exp"), "Inline defined Message attribute is snake_case" diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py index 2b32b530..6dc69256 100644 --- a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py +++ b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py @@ -3,12 +3,6 @@ def test_message_casing(): message = Test() - assert hasattr( - message, "uppercase" - ), "UPPERCASE attribute is converted to 'uppercase' in python" - assert hasattr( - message, "uppercase_v2" - ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" - assert hasattr( - message, "upper_camel_case" - ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python" + assert hasattr(message, "uppercase"), "UPPERCASE attribute is converted to 'uppercase' in python" + assert hasattr(message, "uppercase_v2"), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" + assert hasattr(message, "upper_camel_case"), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python" diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index 21a5ac3b..20c9a4d5 100644 --- 
a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -27,13 +27,8 @@ def test_enum_is_comparable_with_int(): def test_enum_to_dict(): - assert ( - "choice" not in Test(choice=Choice.ZERO).to_dict() - ), "Default enum value is not serialized" - assert ( - Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] - == "ZERO" - ) + assert "choice" not in Test(choice=Choice.ZERO).to_dict(), "Default enum value is not serialized" + assert Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] == "ZERO" assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE" assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE" assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR" @@ -59,9 +54,7 @@ def test_repeated_enum_to_dict(): assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"] assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"] - all_enums_dict = Test( - choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR] - ).to_dict() + all_enums_dict = Test(choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]).to_dict() assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"] diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py index 23b2e3b4..cd2cc40f 100644 --- a/tests/inputs/example_service/test_example_service.py +++ b/tests/inputs/example_service/test_example_service.py @@ -1,5 +1,4 @@ from typing import ( - AsyncIterable, AsyncIterator, ) @@ -15,17 +14,13 @@ class ExampleService(TestBase): - async def example_unary_unary( - self, example_request: ExampleRequest - ) -> "ExampleResponse": + async def example_unary_unary(self, example_request: ExampleRequest) -> "ExampleResponse": return ExampleResponse( example_string=example_request.example_string, example_integer=example_request.example_integer, ) - async def example_unary_stream( - self, example_request: ExampleRequest - ) -> 
AsyncIterator["ExampleResponse"]: + async def example_unary_stream(self, example_request: ExampleRequest) -> AsyncIterator["ExampleResponse"]: response = ExampleResponse( example_string=example_request.example_string, example_integer=example_request.example_integer, diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py index 97587cfa..b6ed5e0f 100644 --- a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py +++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -7,7 +7,6 @@ from google.protobuf import json_format from google.protobuf.timestamp_pb2 import Timestamp -import betterproto from tests.output_betterproto.google_impl_behavior_equivalence import ( Empty, Foo, @@ -72,10 +71,7 @@ def test_datetime_clamping(dt): # see #407 assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString() message_bytes = bytes(Spam(dt)) - assert ( - Spam().parse(message_bytes).ts.timestamp() - == ReferenceSpam.FromString(message_bytes).ts.seconds - ) + assert Spam().parse(message_bytes).ts.timestamp() == ReferenceSpam.FromString(message_bytes).ts.seconds def test_empty_message_field(): diff --git a/tests/inputs/googletypes_request/test_googletypes_request.py b/tests/inputs/googletypes_request/test_googletypes_request.py index ffb2608f..f1cd4f0b 100644 --- a/tests/inputs/googletypes_request/test_googletypes_request.py +++ b/tests/inputs/googletypes_request/test_googletypes_request.py @@ -16,7 +16,6 @@ TestStub, ) - test_cases = [ (TestStub.send_double, protobuf.DoubleValue, 2.5), (TestStub.send_float, protobuf.FloatValue, 2.5), diff --git a/tests/inputs/googletypes_response/test_googletypes_response.py b/tests/inputs/googletypes_response/test_googletypes_response.py index 6e1ed29c..e1aebc6d 100644 --- a/tests/inputs/googletypes_response/test_googletypes_response.py +++ 
b/tests/inputs/googletypes_response/test_googletypes_response.py @@ -13,7 +13,6 @@ TestStub, ) - test_cases = [ (TestStub.get_double, protobuf.DoubleValue, 2.5), (TestStub.get_float, protobuf.FloatValue, 2.5), diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py index 29dd8163..d2ff494e 100644 --- a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py +++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -1,7 +1,5 @@ import datetime -import pytest - import betterproto from tests.output_betterproto.oneof_default_value_serialization import ( Message, @@ -60,11 +58,7 @@ def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): """ Nested messages with oneofs should also be handled """ - message = Test( - wrapped_nested_message_value=NestedMessage( - id=0, wrapped_message_value=Message(value=0) - ) - ) + message = Test(wrapped_nested_message_value=NestedMessage(id=0, wrapped_message_value=Message(value=0))) assert ( betterproto.which_one_of(message, "value_type") == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 8dda1a97..4a71223b 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -12,9 +12,7 @@ def test_which_one_of_returns_enum_with_default_value(): returns first field when it is enum and set with default value """ message = Test() - message.from_json( - get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json - ) + message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json) assert message.move is None assert message.signal == Signal.PASS @@ -26,9 +24,7 @@ def 
test_which_one_of_returns_enum_with_non_default_value(): returns first field when it is enum and set with non default value """ message = Test() - message.from_json( - get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json - ) + message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json) assert message.move is None assert message.signal == Signal.RESIGN diff --git a/tests/inputs/proto3_field_presence/test_proto3_field_presence.py b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py index e3111998..9c2d6e69 100644 --- a/tests/inputs/proto3_field_presence/test_proto3_field_presence.py +++ b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py @@ -1,9 +1,7 @@ import json from tests.output_betterproto.proto3_field_presence import ( - InnerTest, Test, - TestEnum, ) diff --git a/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py index 6008e6a4..2320dc64 100644 --- a/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py +++ b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py @@ -1,5 +1,4 @@ from tests.output_betterproto.proto3_field_presence_oneof import ( - InnerNested, Nested, Test, WithOptional, diff --git a/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py index a4438488..35783ea6 100644 --- a/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py +++ b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py @@ -8,16 +8,12 @@ from tests.output_betterproto.timestamp_dict_encode import Test - # Current World Timezone range (UTC-12 to UTC+14) MIN_UTC_OFFSET_MIN = -12 * 60 MAX_UTC_OFFSET_MIN = 14 * 60 # Generate all timezones in range in 15 min increments -timezones = [ - timezone(timedelta(minutes=x)) - for x in range(MIN_UTC_OFFSET_MIN, 
MAX_UTC_OFFSET_MIN + 1, 15) -] +timezones = [timezone(timedelta(minutes=x)) for x in range(MIN_UTC_OFFSET_MIN, MAX_UTC_OFFSET_MIN + 1, 15)] @pytest.mark.parametrize("tz", timezones) diff --git a/tests/test_features.py b/tests/test_features.py index 84a4f199..f8297e41 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,9 +1,4 @@ import json -import sys -from copy import ( - copy, - deepcopy, -) from dataclasses import dataclass from datetime import ( datetime, @@ -14,14 +9,11 @@ signature, ) from typing import ( - Dict, List, Optional, ) from unittest.mock import ANY -import pytest - import betterproto @@ -48,9 +40,7 @@ class TestEnum(betterproto.Enum): @dataclass class Foo(betterproto.Message): - bar: TestEnum = betterproto.enum_field( - 1, enum_default_value=lambda: TestEnum.try_value(0) - ) + bar: TestEnum = betterproto.enum_field(1, enum_default_value=lambda: TestEnum.try_value(0)) # JSON strings are supported, but ints should still be supported too. foo = Foo().from_dict({"bar": 1}) @@ -142,9 +132,7 @@ class CasingTest(betterproto.Message): kabob_case: int = betterproto.int32_field(4) # Parsing should accept almost any input - test = CasingTest().from_dict( - {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4} - ) + test = CasingTest().from_dict({"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}) assert test == CasingTest(1, 2, 3, 4) @@ -173,9 +161,7 @@ class CasingTest(betterproto.Message): kabob_case: int = betterproto.int32_field(4) # Parsing should accept almost any input - test = CasingTest().from_dict( - {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4} - ) + test = CasingTest().from_dict({"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}) assert test == CasingTest(1, 2, 3, 4) @@ -230,22 +216,14 @@ class Request(betterproto.Message): # Check dict serialization assert Request().to_dict() == {} assert Request().to_dict(include_default_values=True) == {"date": None} - 
assert Request(date=datetime(2020, 1, 1)).to_dict() == { - "date": "2020-01-01T00:00:00Z" - } - assert Request(date=datetime(2020, 1, 1)).to_dict(include_default_values=True) == { - "date": "2020-01-01T00:00:00Z" - } + assert Request(date=datetime(2020, 1, 1)).to_dict() == {"date": "2020-01-01T00:00:00Z"} + assert Request(date=datetime(2020, 1, 1)).to_dict(include_default_values=True) == {"date": "2020-01-01T00:00:00Z"} # Check pydict serialization assert Request().to_pydict() == {} assert Request().to_pydict(include_default_values=True) == {"date": None} - assert Request(date=datetime(2020, 1, 1)).to_pydict() == { - "date": datetime(2020, 1, 1) - } - assert Request(date=datetime(2020, 1, 1)).to_pydict( - include_default_values=True - ) == {"date": datetime(2020, 1, 1)} + assert Request(date=datetime(2020, 1, 1)).to_pydict() == {"date": datetime(2020, 1, 1)} + assert Request(date=datetime(2020, 1, 1)).to_pydict(include_default_values=True) == {"date": datetime(2020, 1, 1)} def test_to_json_default_values(): @@ -267,9 +245,7 @@ class TestMessage(betterproto.Message): } # All default values - test = TestMessage().from_dict( - {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} - ) + test = TestMessage().from_dict({"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}) assert json.loads(test.to_json(include_default_values=True)) == { "someInt": 0, @@ -398,15 +374,11 @@ class TestDatetimeMessage(betterproto.Message): bar: datetime = betterproto.message_field(1) baz: timedelta = betterproto.message_field(2) - test = TestDatetimeMessage().from_dict( - {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} - ) + test = TestDatetimeMessage().from_dict({"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}) assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} - test = TestDatetimeMessage().from_pydict( - {"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)} - ) + test = 
TestDatetimeMessage().from_pydict({"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)}) assert test.to_pydict() == { "bar": datetime(year=2020, month=1, day=1), diff --git a/tests/test_inputs.py b/tests/test_inputs.py index f114a6ab..6b384421 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -10,7 +10,6 @@ Dict, List, Set, - Tuple, ) import pytest @@ -25,7 +24,6 @@ inputs_path, ) - # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -43,9 +41,7 @@ def __init__( _all = set(get_directories(path)) - {"__pycache__"} _services = services _messages = (_all - services) - {"__pycache__"} - _messages_with_json = { - test for test in _messages if get_test_case_json_data(test) - } + _messages_with_json = {test for test in _messages if get_test_case_json_data(test)} unknown_xfail_tests = xfail - _all if unknown_xfail_tests: @@ -58,10 +54,7 @@ def __init__( @staticmethod def apply_xfail_marks(test_set: Set[str], xfail: Set[str]): - return [ - pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test - for test in test_set - ] + return [pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test for test in test_set] test_cases = TestCases( @@ -133,9 +126,7 @@ def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]: def test_data(request, reset_sys_path): test_case_name = request.param - reference_module_root = os.path.join( - *reference_output_package.split("."), test_case_name - ) + reference_module_root = os.path.join(*reference_output_package.split("."), test_case_name) sys.path.append(reference_module_root) plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") @@ -186,9 +177,7 @@ def test_message_json(test_data: TestData) -> None: message.from_json(sample.json) message_json = message.to_json(indent=0) - assert 
dict_replace_nans(json.loads(message_json)) == dict_replace_nans( - json.loads(sample.json) - ) + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(json.loads(sample.json)) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @@ -204,12 +193,8 @@ def test_binary_compatibility(test_data: TestData) -> None: reference_instance = Parse(sample.json, reference_module().Test()) reference_binary_output = reference_instance.SerializeToString() - plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( - sample.json - ) - plugin_instance_from_binary = plugin_module.Test.FromString( - reference_binary_output - ) + plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(sample.json) + plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output) # Generally this can't be relied on, but here we are aiming to match the # existing Python implementation and aren't doing anything tricky. @@ -218,6 +203,6 @@ def test_binary_compatibility(test_data: TestData) -> None: assert bytes(plugin_instance_from_binary) == reference_binary_output assert plugin_instance_from_json == plugin_instance_from_binary - assert dict_replace_nans( - plugin_instance_from_json.to_dict() - ) == dict_replace_nans(plugin_instance_from_binary.to_dict()) + assert dict_replace_nans(plugin_instance_from_json.to_dict()) == dict_replace_nans( + plugin_instance_from_binary.to_dict() + ) diff --git a/tests/test_pickling.py b/tests/test_pickling.py index 2264192e..f45e7a67 100644 --- a/tests/test_pickling.py +++ b/tests/test_pickling.py @@ -36,9 +36,7 @@ class Fo(betterproto.Message): @dataclass(eq=False, repr=False) class NestedData(betterproto.Message): - struct_foo: Dict[str, "google.Struct"] = betterproto.map_field( - 1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE - ) + struct_foo: Dict[str, "google.Struct"] = betterproto.map_field(1, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE) 
map_str_any_bar: Dict[str, "google.Any"] = betterproto.map_field( 2, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE ) @@ -51,9 +49,7 @@ class Complex(betterproto.Message): fi: "Fi" = betterproto.message_field(4, group="grp") fo: "Fo" = betterproto.message_field(5, group="grp") nested_data: "NestedData" = betterproto.message_field(6) - mapping: Dict[str, "google.Any"] = betterproto.map_field( - 7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE - ) + mapping: Dict[str, "google.Any"] = betterproto.map_field(7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE) def complex_msg(): @@ -64,11 +60,7 @@ def complex_msg(): struct_foo={ "foo": google.Struct( fields={ - "hello": google.Value( - list_value=google.ListValue( - values=[google.Value(string_value="world")] - ) - ) + "hello": google.Value(list_value=google.ListValue(values=[google.Value(string_value="world")])) } ), }, @@ -91,13 +83,7 @@ def test_pickling_complex_message(): assert msg.is_set("fi") is not True assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi"))) assert msg.mapping["string"].value.decode() == "howdy" - assert ( - msg.nested_data.struct_foo["foo"] - .fields["hello"] - .list_value.values[0] - .string_value - == "world" - ) + assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world" def test_recursive_message_defaults(): @@ -113,9 +99,7 @@ def test_recursive_message_defaults(): assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) # lazy initialized works modifies the message - assert msg != RecursiveMessage( - name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") - ) + assert msg != RecursiveMessage(name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")) msg.child = RecursiveMessage(child=RecursiveMessage(name="jude")) assert msg == RecursiveMessage( name="bob", @@ -174,10 +158,4 @@ def use_cache(): assert msg.is_set("fi") is not True assert msg.mapping["message"] == 
google.Any(value=bytes(Fi(abc="hi"))) assert msg.mapping["string"].value.decode() == "howdy" - assert ( - msg.nested_data.struct_foo["foo"] - .fields["hello"] - .list_value.values[0] - .string_value - == "world" - ) + assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world" diff --git a/tests/test_streams.py b/tests/test_streams.py index 3a9ea2df..cf259c8c 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,9 +1,7 @@ -from dataclasses import dataclass from io import BytesIO from pathlib import Path from shutil import which from subprocess import run -from typing import Optional import pytest @@ -16,10 +14,7 @@ repeatedpacked, ) - -oneof_example = oneof.Test().from_dict( - {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"} -) +oneof_example = oneof.Test().from_dict({"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"}) len_oneof = len(bytes(oneof_example)) @@ -46,9 +41,7 @@ def test_load_varint_too_long(): - with BytesIO( - b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01" - ) as stream, pytest.raises(ValueError): + with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream, pytest.raises(ValueError): betterproto.load_varint(stream) with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream: @@ -86,13 +79,9 @@ def test_dump_varint_file(tmp_path): with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open( streams_path / "message_dump_file_single.expected", "rb" ) as exp_stream: - assert betterproto.load_varint(test_stream) == betterproto.load_varint( - exp_stream - ) + assert betterproto.load_varint(test_stream) == betterproto.load_varint(exp_stream) exp_stream.read(2) - assert betterproto.load_varint(test_stream) == betterproto.load_varint( - exp_stream - ) + assert betterproto.load_varint(test_stream) == betterproto.load_varint(exp_stream) def test_parse_fields(): @@ -160,9 +149,7 @@ def test_message_load_file_multiple(): def 
test_message_load_too_small(): - with open( - streams_path / "message_dump_file_single.expected", "rb" - ) as stream, pytest.raises(ValueError): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream, pytest.raises(ValueError): oneof.Test().load(stream, len_oneof - 1) @@ -175,9 +162,7 @@ def test_message_load_delimited(): def test_message_load_too_large(): - with open( - streams_path / "message_dump_file_single.expected", "rb" - ) as stream, pytest.raises(ValueError): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream, pytest.raises(ValueError): oneof.Test().load(stream, len_oneof + 1) @@ -252,9 +237,7 @@ def compile_jar(): # Compile the JAR proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"]) if proc_maven.returncode != 0: - pytest.skip( - "Maven compatibility-test.jar build failed (maybe Java version <11?)" - ) + pytest.skip("Maven compatibility-test.jar build failed (maybe Java version <11?)") jar = "tests/streams/java/target/compatibility-test.jar" @@ -361,7 +344,7 @@ def test_infinite_messages(compile_jar, tmp_path): # Write delimited messages to file with open(tmp_path / "py_infinite_messages.out", "wb") as stream: - for x in range(num_messages): + for _ in range(num_messages): oneof_example.dump(stream, True) # Have Java read and return the messages diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py index 0b859cd6..1fc6f55c 100644 --- a/tests/test_typing_compiler.py +++ b/tests/test_typing_compiler.py @@ -1,5 +1,3 @@ -import pytest - from betterproto.plugin.typing_compiler import ( DirectImportTypingCompiler, NoTyping310TypingCompiler, @@ -19,13 +17,9 @@ def test_direct_import_typing_compiler(): assert compiler.union("str", "int") == "Union[str, int]" assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}} assert compiler.iterable("str") == "Iterable[str]" - assert compiler.imports() == { - "typing": {"Optional", "List", "Dict", 
"Union", "Iterable"} - } + assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union", "Iterable"}} assert compiler.async_iterable("str") == "AsyncIterable[str]" - assert compiler.imports() == { - "typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"} - } + assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}} assert compiler.async_iterator("str") == "AsyncIterator[str]" assert compiler.imports() == { "typing": { @@ -73,6 +67,4 @@ def test_no_typing_311_typing_compiler(): assert compiler.iterable("str") == "Iterable[str]" assert compiler.async_iterable("str") == "AsyncIterable[str]" assert compiler.async_iterator("str") == "AsyncIterator[str]" - assert compiler.imports() == { - "collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"} - } + assert compiler.imports() == {"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}} diff --git a/tests/test_version.py b/tests/test_version.py index 461f3663..87bbd758 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -4,7 +4,6 @@ from betterproto import __version__ - PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve() diff --git a/tests/util.py b/tests/util.py index 22c4f901..6db92b0a 100644 --- a/tests/util.py +++ b/tests/util.py @@ -18,7 +18,6 @@ Union, ) - os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" root_path = Path(__file__).resolve().parent @@ -53,9 +52,7 @@ async def protoc( plugin_path = Path("src/betterproto/plugin/main.py") if "Win" in platform.system(): - with tempfile.NamedTemporaryFile( - "w", encoding="UTF-8", suffix=".bat", delete=False - ) as tf: + with tempfile.NamedTemporaryFile("w", encoding="UTF-8", suffix=".bat", delete=False) as tf: # See https://stackoverflow.com/a/42622705 tf.writelines( [ @@ -103,13 +100,11 @@ class TestCaseJsonFile: test_name: str file_name: str - def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, 
...]]): - return self.file_name in non_symmetrical_json.get(self.test_name, tuple()) + def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]) -> bool: + return self.file_name in non_symmetrical_json.get(self.test_name, ()) -def get_test_case_json_data( - test_case_name: str, *json_file_names: str -) -> List[TestCaseJsonFile]: +def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[TestCaseJsonFile]: """ :return: A list of all files found in "{inputs_path}/test_case_name" with names matching @@ -128,18 +123,12 @@ def get_test_case_json_data( if not test_data_file_path.exists(): continue with test_data_file_path.open("r") as fh: - result.append( - TestCaseJsonFile( - fh.read(), test_case_name, test_data_file_path.name.split(".")[0] - ) - ) + result.append(TestCaseJsonFile(fh.read(), test_case_name, test_data_file_path.name.split(".")[0])) return result -def find_module( - module: ModuleType, predicate: Callable[[ModuleType], bool] -) -> Optional[ModuleType]: +def find_module(module: ModuleType, predicate: Callable[[ModuleType], bool]) -> Optional[ModuleType]: """ Recursively search module tree for a module that matches the search predicate. Assumes that the submodules are directories containing __init__.py.