Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@


def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
proto_file: "FileDescriptorProto",
path: List[int],
) -> str:
pad = " " * indent
for sci_loc in proto_file.source_code_info.location:
if list(sci_loc.path) == path:
all_comments = list(sci_loc.leading_detached_comments)
Expand All @@ -176,12 +176,7 @@ def get_comment(
# We don't add this space to the generated file.
lines = [line[1:] if line and line[0] == " " else line for line in lines]

# This is a field, message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
return f'{pad}"""{lines[0]}"""'
else:
joined = f"\n{pad}".join(lines)
return f'{pad}"""\n{pad}{joined}\n{pad}"""'
return "\n".join(lines)

return ""

Expand All @@ -192,7 +187,6 @@ class ProtoContentBase:
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]

__dataclass_fields__: Dict[str, object]
Expand Down Expand Up @@ -225,9 +219,7 @@ def comment(self) -> str:
"""Crawl the proto source code and retrieve comments
for this object.
"""
return get_comment(
proto_file=self.source_file, path=self.path, indent=self.comment_indent
)
return get_comment(proto_file=self.source_file, path=self.path)

@property
def deprecated(self) -> bool:
Expand Down Expand Up @@ -444,7 +436,7 @@ def ready(self) -> None:
# Check for new imports
self.add_imports_to(self.output_file)

def get_field_string(self, indent: int = 4) -> str:
def get_field_string(self) -> str:
"""Construct string representation of this field as a field."""
name = f"{self.py_name}"
annotations = f": {self.annotation}"
Expand Down Expand Up @@ -727,7 +719,6 @@ class ServiceMethodCompiler(ProtoContentBase):
parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
comment_indent: int = 8

def __post_init__(self) -> None:
# Add method to service
Expand Down
55 changes: 35 additions & 20 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
{% if output_file.enums %}{% for _, enum in output_file.enums|dictsort(by="key") %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}

"""
{{ enum.comment | indent(4) }}
"""
{% endif %}

{% for entry in enum.entries %}
{{ entry.name }} = {{ entry.value }}
{% if entry.comment %}
{{ entry.comment }}
{% if entry.comment %}
"""
{{ entry.comment | indent(4) }}
"""
{% endif %}

{% endif %}
{% endfor %}

{% if output_file.pydantic_dataclasses %}
Expand All @@ -30,16 +34,20 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endif %}
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}

"""
{{ message.comment | indent(4) }}
"""
{% endif %}

{% for field in message.fields %}
{{ field.get_field_string() }}
{% if field.comment %}
{{ field.comment }}

{% endif %}
{% if field.comment %}
"""
{{ field.comment | indent(4) }}
"""
{% endif %}
{% endfor %}

{% if not message.fields %}
pass
{% endif %}
Expand All @@ -66,11 +74,13 @@ class {{ message.py_name }}(betterproto.Message):
{% for _, service in output_file.services|dictsort(by="key") %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}

"""
{{ service.comment | indent(4) }}
"""
{% elif not service.methods %}
pass
{% endif %}

{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
Expand All @@ -86,13 +96,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
{% if method.comment %}
{{ method.comment }}

"""
{{ method.comment | indent(8) }}
"""
{% endif %}

{% if method.proto_obj.options and method.proto_obj.options.deprecated %}
warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning)

{% endif %}

{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
Expand Down Expand Up @@ -150,8 +162,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for _, service in output_file.services|dictsort(by="key") %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
{{ service.comment }}

"""
{{ service.comment | indent(4) }}
"""
{% endif %}

{% for method in service.methods %}
Expand All @@ -164,9 +177,11 @@ class {{ service.py_name }}Base(ServiceBase):
{%- endif -%}
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}

"""
{{ method.comment | indent(8) }}
"""
{% endif %}

raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
{% if method.server_streaming %}
yield {{ method.py_output_message_type }}()
Expand Down
2 changes: 1 addition & 1 deletion tests/inputs/nestedtwice/test_nestedtwice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
],
)
def test_comment(cls, expected_comment):
assert cls.__doc__ == expected_comment
assert cls.__doc__.strip() == expected_comment
Loading