diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 7cfbdbc7..0d59584b 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -147,9 +147,9 @@ def get_comment( - proto_file: "FileDescriptorProto", path: List[int], indent: int = 4 + proto_file: "FileDescriptorProto", + path: List[int], ) -> str: - pad = " " * indent for sci_loc in proto_file.source_code_info.location: if list(sci_loc.path) == path: all_comments = list(sci_loc.leading_detached_comments) @@ -176,12 +176,7 @@ def get_comment( # We don't add this space to the generated file. lines = [line[1:] if line and line[0] == " " else line for line in lines] - # This is a field, message, enum, service, or method - if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: - return f'{pad}"""{lines[0]}"""' - else: - joined = f"\n{pad}".join(lines) - return f'{pad}"""\n{pad}{joined}\n{pad}"""' + return "\n".join(lines) return "" @@ -192,7 +187,6 @@ class ProtoContentBase: source_file: FileDescriptorProto typing_compiler: TypingCompiler path: List[int] - comment_indent: int = 4 parent: Union["betterproto.Message", "OutputTemplate"] __dataclass_fields__: Dict[str, object] @@ -225,9 +219,7 @@ def comment(self) -> str: """Crawl the proto source code and retrieve comments for this object. """ - return get_comment( - proto_file=self.source_file, path=self.path, indent=self.comment_indent - ) + return get_comment(proto_file=self.source_file, path=self.path) @property def deprecated(self) -> bool: @@ -444,7 +436,7 @@ def ready(self) -> None: # Check for new imports self.add_imports_to(self.output_file) - def get_field_string(self, indent: int = 4) -> str: + def get_field_string(self) -> str: """Construct string representation of this field as a field.""" name = f"{self.py_name}" annotations = f": {self.annotation}" @@ -727,7 +719,6 @@ class ServiceMethodCompiler(ProtoContentBase): parent: ServiceCompiler proto_obj: MethodDescriptorProto path: List[int] = PLACEHOLDER - comment_indent: int = 8 def __post_init__(self) -> None: # Add method to service diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 7517def7..44b09f02 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -1,15 +1,19 @@ {% if output_file.enums %}{% for _, enum in output_file.enums|dictsort(by="key") %} class {{ enum.py_name }}(betterproto.Enum): {% if enum.comment %} -{{ enum.comment }} - + """ + {{ enum.comment | indent(4) }} + """ {% endif %} + {% for entry in enum.entries %} {{ entry.name }} = {{ entry.value }} - {% if entry.comment %} -{{ entry.comment }} + {% if entry.comment %} + """ + {{ entry.comment | indent(4) }} + """ + {% endif %} - {% endif %} {% endfor %} {% if output_file.pydantic_dataclasses %} @@ -30,16 +34,20 @@ class {{ enum.py_name }}(betterproto.Enum): {% endif %} class {{ message.py_name }}(betterproto.Message): {% if message.comment %} -{{ message.comment }} - + """ + {{ message.comment | indent(4) }} + """ {% endif %} + {% for field in message.fields %} {{ field.get_field_string() }} - {% if field.comment %} -{{ field.comment }} - - {% endif %} + {% if field.comment %} + """ + {{ field.comment | indent(4) }} + """ + {% endif %} {% endfor %} + {% if not message.fields %} pass {% endif %} @@ -66,11 +74,13 @@ class {{ message.py_name }}(betterproto.Message): {% for _, service in output_file.services|dictsort(by="key") %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} -{{ service.comment }} - + """ + {{ service.comment | indent(4) }} + """ {% elif not service.methods %} pass {% endif %} + {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} @@ -86,13 +96,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} -{{ method.comment }} - + """ + {{ method.comment | indent(8) }} + """ {% endif %} + {% if method.proto_obj.options and method.proto_obj.options.deprecated %} warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning) - {% endif %} + {% if method.server_streaming %} {% if method.client_streaming %} async for response in self._stream_stream( @@ -150,8 +162,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% for _, service in output_file.services|dictsort(by="key") %} class {{ service.py_name }}Base(ServiceBase): {% if service.comment %} -{{ service.comment }} - + """ + {{ service.comment | indent(4) }} + """ {% endif %} {% for method in service.methods %} @@ -164,9 +177,11 @@ class {{ service.py_name }}Base(ServiceBase): {%- endif -%} ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} -{{ method.comment }} - + """ + {{ method.comment | indent(8) }} + """ {% endif %} + raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) {% if method.server_streaming %} yield {{ method.py_output_message_type }}() diff --git a/tests/inputs/nestedtwice/test_nestedtwice.py b/tests/inputs/nestedtwice/test_nestedtwice.py index 606467c2..ca0557a7 100644 --- a/tests/inputs/nestedtwice/test_nestedtwice.py +++ b/tests/inputs/nestedtwice/test_nestedtwice.py @@ -22,4 +22,4 @@ ], ) def test_comment(cls, expected_comment): - assert cls.__doc__ == expected_comment + assert cls.__doc__.strip() == expected_comment