Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions compiler/fory_compiler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def resolve_imports(
imported_messages = []
imported_unions = []
resolved_import_files = []
source_packages: Dict[str, Optional[str]] = {str(file_path): schema.package}

for imp in schema.imports:
# Resolve import path using search paths
Expand Down Expand Up @@ -159,6 +160,7 @@ def resolve_imports(
imported_enums.extend(imported_schema.enums)
imported_messages.extend(imported_schema.messages)
imported_unions.extend(imported_schema.unions)
source_packages.update(imported_schema.source_packages)

# Create merged schema with imported types first (so they can be referenced)
merged_schema = Schema(
Expand All @@ -173,6 +175,7 @@ def resolve_imports(
source_file=schema.source_file,
source_format=schema.source_format,
resolved_import_files=list(dict.fromkeys(resolved_import_files)),
source_packages=source_packages,
)

cache[file_path] = copy.deepcopy(merged_schema)
Expand Down Expand Up @@ -611,14 +614,20 @@ def compile_file(
emit_fdl_path: Optional[Path] = None,
resolve_cache: Optional[Dict[Path, Schema]] = None,
grpc: bool = False,
*,
generated_outputs: Optional[Dict[Path, Path]] = None,
) -> bool:
"""Compile a single IDL file with import resolution.

Args:
file_path: Path to the IDL file
lang_output_dirs: Dictionary mapping language name to output directory
import_paths: List of import search paths
generated_outputs: output file path -> source IDL path
"""
file_path = file_path.resolve()
if generated_outputs is None:
generated_outputs = {}
print(f"Compiling {file_path}...")

# Parse and resolve imports
Expand Down Expand Up @@ -689,6 +698,31 @@ def compile_file(
print(f"Error: {e}", file=sys.stderr)
return False

if lang == "rust":
# Special error handling for Rust
output_targets: List[Path] = []
for f in files:
target = (lang_output / f.path).resolve()
# Reject overwriting existing non-generated files
if target.exists() and not is_generated_file(target):
print(
f"Error: Rust output path collision: {target} already exists",
file=sys.stderr,
)
return False
# Check if distinct source files map to the same output file, e.g. due to naming normalization
previous_source = generated_outputs.get(target)
if previous_source is not None and previous_source != file_path:
print(
"Error: Rust output path collision: "
f"{previous_source} and {file_path} both generate {target}",
file=sys.stderr,
)
return False
output_targets.append(target)
for target in output_targets:
generated_outputs[target] = file_path

generator.write_files(files)

for f in files:
Expand All @@ -709,6 +743,7 @@ def compile_file_recursive(
stack: Set[Path],
resolve_cache: Dict[Path, Schema],
go_module_root: Optional[Path],
generated_outputs: Dict[Path, Path],
grpc: bool = False,
) -> bool:
file_path = file_path.resolve()
Expand Down Expand Up @@ -773,6 +808,7 @@ def compile_file_recursive(
stack,
resolve_cache,
go_module_root,
generated_outputs,
grpc,
):
stack.remove(file_path)
Expand All @@ -789,6 +825,7 @@ def compile_file_recursive(
emit_fdl_path,
resolve_cache,
grpc,
generated_outputs=generated_outputs,
)
if ok:
generated.add(file_path)
Expand Down Expand Up @@ -868,6 +905,7 @@ def cmd_compile(args: argparse.Namespace) -> int:
success = True
generated: Set[Path] = set()
resolve_cache: Dict[Path, Schema] = {}
generated_outputs: Dict[Path, Path] = {}
for file_path in args.files:
if not file_path.exists():
print(f"Error: File not found: {file_path}", file=sys.stderr)
Expand All @@ -887,6 +925,7 @@ def cmd_compile(args: argparse.Namespace) -> int:
set(),
resolve_cache,
None,
generated_outputs,
args.grpc,
):
success = False
Expand Down
2 changes: 1 addition & 1 deletion compiler/fory_compiler/frontend/fbs/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class FbsUnion:
"""A FlatBuffers union declaration."""

name: str
types: List[str] = field(default_factory=list)
types: List[FbsTypeName] = field(default_factory=list)
attributes: Dict[str, object] = field(default_factory=dict)
line: int = 0
column: int = 0
Expand Down
11 changes: 9 additions & 2 deletions compiler/fory_compiler/frontend/fbs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,20 @@ def parse_union(self) -> FbsUnion:
attributes = self.parse_metadata()
self.consume(TokenType.LBRACE, "Expected '{' after union name")

types: List[str] = []
types: List[FbsTypeName] = []
while not self.check(TokenType.RBRACE):
if self.check(TokenType.COMMA):
self.advance()
continue
type_start = self.current()
type_name = self.parse_qualified_ident()
types.append(type_name)
types.append(
FbsTypeName(
name=type_name,
line=type_start.line,
column=type_start.column,
)
)
if self.match(TokenType.COMMA):
continue
if self.check(TokenType.RBRACE):
Expand Down
11 changes: 6 additions & 5 deletions compiler/fory_compiler/frontend/fbs/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,20 @@ def _translate_field_attributes(

def _translate_union(self, fbs_union: FbsUnion) -> Union:
fields: List[Field] = []
for index, type_name in enumerate(fbs_union.types, start=1):
for index, type_ref in enumerate(fbs_union.types, start=1):
type_name = type_ref.name
field_name = self._lower_name(type_name)
fields.append(
Field(
name=field_name,
field_type=NamedType(
type_name,
location=self._location(fbs_union.line, fbs_union.column),
location=self._location(type_ref.line, type_ref.column),
),
number=index,
line=fbs_union.line,
column=fbs_union.column,
location=self._location(fbs_union.line, fbs_union.column),
line=type_ref.line,
column=type_ref.column,
location=self._location(type_ref.line, type_ref.column),
)
)
return Union(
Expand Down
Loading
Loading