From 3f1c989b505e91ad4d46da0f89202ffdc6504497 Mon Sep 17 00:00:00 2001 From: Peiyang He Date: Mon, 8 Jun 2026 20:26:26 +0800 Subject: [PATCH] feat(compiler): handle Rust identifier escaping and name collisions --- compiler/fory_compiler/cli.py | 39 + compiler/fory_compiler/frontend/fbs/ast.py | 2 +- compiler/fory_compiler/frontend/fbs/parser.py | 11 +- .../fory_compiler/frontend/fbs/translator.py | 11 +- compiler/fory_compiler/generators/rust.py | 670 ++++++++++++++++-- compiler/fory_compiler/ir/ast.py | 5 +- .../tests/test_cli_flatbuffers_options.py | 2 + .../fory_compiler/tests/test_fbs_frontend.py | 4 + .../tests/test_generated_code.py | 118 ++- .../tests/test_service_codegen.py | 37 +- .../tests/test_service_example.py | 1 + 11 files changed, 809 insertions(+), 91 deletions(-) diff --git a/compiler/fory_compiler/cli.py b/compiler/fory_compiler/cli.py index c1b0f889ac..ed81ad1e17 100644 --- a/compiler/fory_compiler/cli.py +++ b/compiler/fory_compiler/cli.py @@ -129,6 +129,7 @@ def resolve_imports( imported_messages = [] imported_unions = [] resolved_import_files = [] + source_packages: Dict[str, Optional[str]] = {str(file_path): schema.package} for imp in schema.imports: # Resolve import path using search paths @@ -159,6 +160,7 @@ def resolve_imports( imported_enums.extend(imported_schema.enums) imported_messages.extend(imported_schema.messages) imported_unions.extend(imported_schema.unions) + source_packages.update(imported_schema.source_packages) # Create merged schema with imported types first (so they can be referenced) merged_schema = Schema( @@ -173,6 +175,7 @@ def resolve_imports( source_file=schema.source_file, source_format=schema.source_format, resolved_import_files=list(dict.fromkeys(resolved_import_files)), + source_packages=source_packages, ) cache[file_path] = copy.deepcopy(merged_schema) @@ -611,6 +614,8 @@ def compile_file( emit_fdl_path: Optional[Path] = None, resolve_cache: Optional[Dict[Path, Schema]] = None, grpc: bool = False, + *, + generated_outputs: Optional[Dict[Path, Path]] = None, ) -> bool: """Compile a single IDL file with import resolution. @@ -618,7 +623,11 @@ def compile_file( file_path: Path to the IDL file lang_output_dirs: Dictionary mapping language name to output directory import_paths: List of import search paths + generated_outputs: output file path -> source IDL path """ + file_path = file_path.resolve() + if generated_outputs is None: + generated_outputs = {} print(f"Compiling {file_path}...") # Parse and resolve imports @@ -689,6 +698,31 @@ def compile_file( print(f"Error: {e}", file=sys.stderr) return False + if lang == "rust": + # Special error handling for Rust + output_targets: List[Path] = [] + for f in files: + target = (lang_output / f.path).resolve() + # Reject overwriting existing non-generated files + if target.exists() and not is_generated_file(target): + print( + f"Error: Rust output path collision: {target} already exists", + file=sys.stderr, + ) + return False + # Check if distinct source files map to the same output file, e.g. due to naming normalization + previous_source = generated_outputs.get(target) + if previous_source is not None and previous_source != file_path: + print( + "Error: Rust output path collision: " + f"{previous_source} and {file_path} both generate {target}", + file=sys.stderr, + ) + return False + output_targets.append(target) + for target in output_targets: + generated_outputs[target] = file_path + generator.write_files(files) for f in files: @@ -709,6 +743,7 @@ def compile_file_recursive( stack: Set[Path], resolve_cache: Dict[Path, Schema], go_module_root: Optional[Path], + generated_outputs: Dict[Path, Path], grpc: bool = False, ) -> bool: file_path = file_path.resolve() @@ -773,6 +808,7 @@ def compile_file_recursive( stack, resolve_cache, go_module_root, + generated_outputs, grpc, ): stack.remove(file_path) @@ -789,6 +825,7 @@ def compile_file_recursive( emit_fdl_path, resolve_cache, grpc, + generated_outputs=generated_outputs, ) if ok: generated.add(file_path) @@ -868,6 +905,7 @@ def cmd_compile(args: argparse.Namespace) -> int: success = True generated: Set[Path] = set() resolve_cache: Dict[Path, Schema] = {} + generated_outputs: Dict[Path, Path] = {} for file_path in args.files: if not file_path.exists(): print(f"Error: File not found: {file_path}", file=sys.stderr) @@ -887,6 +925,7 @@ def cmd_compile(args: argparse.Namespace) -> int: set(), resolve_cache, None, + generated_outputs, args.grpc, ): success = False diff --git a/compiler/fory_compiler/frontend/fbs/ast.py b/compiler/fory_compiler/frontend/fbs/ast.py index ad3bde9474..f659fc2ad6 100644 --- a/compiler/fory_compiler/frontend/fbs/ast.py +++ b/compiler/fory_compiler/frontend/fbs/ast.py @@ -81,7 +81,7 @@ class FbsUnion: """A FlatBuffers union declaration.""" name: str - types: List[str] = field(default_factory=list) + types: List[FbsTypeName] = field(default_factory=list) attributes: Dict[str, object] = field(default_factory=dict) line: int = 0 column: int = 0 diff --git a/compiler/fory_compiler/frontend/fbs/parser.py b/compiler/fory_compiler/frontend/fbs/parser.py index 67013c26e8..bc03be2eda 100644 --- a/compiler/fory_compiler/frontend/fbs/parser.py +++ b/compiler/fory_compiler/frontend/fbs/parser.py @@ -266,13 +266,20 @@ def parse_union(self) -> FbsUnion: attributes = self.parse_metadata() self.consume(TokenType.LBRACE, "Expected '{' after union name") - types: List[str] = [] + types: List[FbsTypeName] = [] while not self.check(TokenType.RBRACE): if self.check(TokenType.COMMA): self.advance() continue + type_start = self.current() type_name = self.parse_qualified_ident() - types.append(type_name) + types.append( + FbsTypeName( + name=type_name, + line=type_start.line, + column=type_start.column, + ) + ) if self.match(TokenType.COMMA): continue if self.check(TokenType.RBRACE): diff --git a/compiler/fory_compiler/frontend/fbs/translator.py b/compiler/fory_compiler/frontend/fbs/translator.py index 1648bf03d3..6bbef2851f 100644 --- a/compiler/fory_compiler/frontend/fbs/translator.py +++ b/compiler/fory_compiler/frontend/fbs/translator.py @@ -230,19 +230,20 @@ def _translate_field_attributes( def _translate_union(self, fbs_union: FbsUnion) -> Union: fields: List[Field] = [] - for index, type_name in enumerate(fbs_union.types, start=1): + for index, type_ref in enumerate(fbs_union.types, start=1): + type_name = type_ref.name field_name = self._lower_name(type_name) fields.append( Field( name=field_name, field_type=NamedType( type_name, - location=self._location(fbs_union.line, fbs_union.column), + location=self._location(type_ref.line, type_ref.column), ), number=index, - line=fbs_union.line, - column=fbs_union.column, - location=self._location(fbs_union.line, fbs_union.column), + line=type_ref.line, + column=type_ref.column, + location=self._location(type_ref.line, type_ref.column), ) ) return Union( diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py index a72642e66f..4569b26e6f 100644 --- a/compiler/fory_compiler/generators/rust.py +++ b/compiler/fory_compiler/generators/rust.py @@ -92,9 +92,456 @@ def primitive_type_name(self, kind: PrimitiveKind) -> str: return temporal_map[kind] return self.PRIMITIVE_MAP[kind] + # Strict and reserved keywords defined in Rust (https://doc.rust-lang.org/reference/keywords.html). + # Weak keywords are intentionally excluded because they are usable outside their special syntax contexts. + RUST_RAW_IDENTIFIER_KEYWORDS = { + "as", + "async", + "await", + "abstract", + "become", + "box", + "break", + "const", + "continue", + "do", + "dyn", + "else", + "enum", + "extern", + "false", + "final", + "fn", + "for", + "gen", + "if", + "impl", + "in", + "let", + "loop", + "macro", + "match", + "mod", + "move", + "mut", + "override", + "priv", + "pub", + "ref", + "return", + "static", + "struct", + "trait", + "true", + "try", + "type", + "typeof", + "unsafe", + "unsized", + "use", + "virtual", + "where", + "while", + "yield", + } + + # Reserved identifiers in Rust (https://doc.rust-lang.org/reference/identifiers.html#railroad-RESERVED_RAW_IDENTIFIER). + # These tokens are invalid even with an `r#` prefix, so escape them by suffixing `_` instead. + RUST_RESERVED_IDENTIFIERS = {"_", "self", "Self", "super", "crate"} + + def sanitize_identifier(self, normalized: str) -> str: + """Escape an already-normalized Rust name.""" + if normalized in self.RUST_RESERVED_IDENTIFIERS: + return f"{normalized}_" + if normalized and normalized[0].isnumeric(): + return f"_{normalized}" # Rust identifiers cannot start with a digit. + if normalized in self.RUST_RAW_IDENTIFIER_KEYWORDS: + return f"r#{normalized}" + return normalized + + def to_rust_snake(self, source: str) -> str: + """Convert an IDL name to a sanitized Rust snake_case identifier.""" + return self.sanitize_identifier(self.to_snake_case(source)) + + def get_top_level_module_identifier(self, package: Optional[str]) -> str: + """Get the Rust module identifier used to reference one schema file.""" + # e.g. `foo.bar` defined in the IDL will be `foo_bar` in the generated Rust code. + module_name = package.replace(".", "_") if package else "generated" + return self.to_rust_snake(module_name) + + def get_type_identifier(self, type_def: object) -> str: + """Get the sanitized identifier for a type declaration or reference from the cache.""" + self._ensure_name_caches(self._schema_for_node(type_def)) + return self._type_identifier_cache[self._cache_key(type_def)] + + def get_module_identifier(self, message: Message) -> str: + """Get the sanitized module name for a message's nested-type scope from the cache.""" + self._ensure_name_caches(self._schema_for_node(message)) + return self._module_identifier_cache[self._cache_key(message)] + + def get_field_identifier(self, message: Message, field: Field) -> str: + """Get the sanitized field name within one message from the cache.""" + self._ensure_name_caches(self._schema_for_node(message)) + return self._field_identifier_cache[self._cache_key(message)][ + self._cache_key(field) + ] + + def get_union_case_identifier(self, union: Union, field: Field) -> str: + """Get the sanitized variant name for one union case from the cache""" + self._ensure_name_caches(self._schema_for_node(union)) + return self._union_case_identifier_cache[self._cache_key(union)][ + self._cache_key(field) + ] + + def _cache_key(self, node: object) -> Tuple[object, ...]: + """Get a cache key for an IR node.""" + # Use the location as the key due to its stability. + location = node.location + return ( + type(node).__name__, + str(Path(location.file).resolve()), + location.line, + location.column, + ) + + def _package_for_source_file(self, file_path: str) -> Optional[str]: + """Get the package name that a file declares.""" + source_key = str(Path(file_path).resolve()) + schema_source_key = str(Path(self.schema.source_file).resolve()) + # `file_path` is the self schema file. + if source_key == schema_source_key: + return self.schema.package + # `file_path` corresponds to an imported schema file. + return self.schema.source_packages[source_key] + + def _schema_for_node(self, node: object) -> Schema: + """Get the schema an IR node belongs to.""" + file_path = node.location.file + source_key = str(Path(file_path).resolve()) + # `node` belongs to the self schema. + if source_key == str(Path(self.schema.source_file).resolve()): + return self.schema + # `node` belongs to an imported schema. + if not hasattr(self, "_source_schema_cache"): + self._source_schema_cache: Dict[str, Schema] = {} + if source_key in self._source_schema_cache: + return self._source_schema_cache[source_key] + enums = [ + enum + for enum in self.schema.enums + if str(Path(enum.location.file).resolve()) == source_key + ] + unions = [ + union + for union in self.schema.unions + if str(Path(union.location.file).resolve()) == source_key + ] + messages = [ + message + for message in self.schema.messages + if str(Path(message.location.file).resolve()) == source_key + ] + services = [ + service + for service in self.schema.services + if str(Path(service.location.file).resolve()) == source_key + ] + if enums or unions or messages or services: + schema = Schema( + package=self._package_for_source_file(file_path), + enums=enums, + messages=messages, + unions=unions, + services=services, + source_file=file_path, + source_format=self.schema.source_format, + ) + self._source_schema_cache[source_key] = schema + return schema + raise ValueError( + f"Rust generator cannot find source schema for " + f"{type(node).__name__} {getattr(node, 'name', '')!r}" + ) + + def _local_top_level_types( + self, schema: Schema + ) -> Tuple[List[Enum], List[Union], List[Message]]: + """Get top-level types that are declared directly in the schema file.""" + schema_source_key = str(Path(schema.source_file).resolve()) + enums = [ + enum + for enum in schema.enums + if str(Path(enum.location.file).resolve()) == schema_source_key + ] + unions = [ + union + for union in schema.unions + if str(Path(union.location.file).resolve()) == schema_source_key + ] + messages = [ + message + for message in schema.messages + if str(Path(message.location.file).resolve()) == schema_source_key + ] + return enums, unions, messages + + def _resolve_message_path(self, schema: Schema, parts: List[str]) -> List[Message]: + """Resolve a dotted message path to the concrete message lineage.""" + lineage: List[Message] = [] + scope = self._local_top_level_types(schema)[2] + for part in parts: + match = next((message for message in scope if message.name == part), None) + if match is None: + return [] + lineage.append(match) + scope = match.nested_messages + return lineage + + def _allocate_scoped_identifier( + self, + normalized_name: str, + used_names: Dict[str, str], + scope: str, + source_name: str, + ) -> str: + """Allocate one sanitized identifier inside a single generated scope. Throw error on collision""" + escaped = self.sanitize_identifier(normalized_name) + if not escaped: + raise ValueError(f"Rust identifier for {source_name!r} in {scope} is empty") + previous_source = used_names.get(escaped) + if previous_source is not None: + raise ValueError( + f"Rust name collision in {scope}: {previous_source!r} and " + f"{source_name!r} both map to Rust identifier {escaped!r}" + ) + used_names[escaped] = source_name + return escaped + + def _allocate_scoped_type_identifiers( + self, type_defs: List[object], scope: str + ) -> None: + """Allocate unique sanitized identifiers for type declarations in the scope and cache the results.""" + used_names: Dict[str, str] = {} + for type_def in type_defs: + self._type_identifier_cache[self._cache_key(type_def)] = ( + self._allocate_scoped_identifier( + self.to_pascal_case(type_def.name), + used_names, + scope, + type_def.name, + ) + ) + + def _allocate_scoped_module_identifiers( + self, messages: List[Message], scope: str + ) -> None: + """Allocate unique sanitized identifiers for nested-type modules in the scope and cache the results.""" + used_names: Dict[str, str] = {} + for message in messages: + self._module_identifier_cache[self._cache_key(message)] = ( + self._allocate_scoped_identifier( + self.to_snake_case(message.name), + used_names, + scope, + message.name, + ) + ) + + def _allocate_scoped_enum_identifiers(self, enum: Enum) -> None: + """Allocate unique sanitized variant names for the generated enum and cache the results.""" + used_names: Dict[str, str] = {} + allocated: Dict[Tuple[object, ...], str] = {} + for value in enum.values: + allocated[self._cache_key(value)] = self._allocate_scoped_identifier( + self.to_pascal_case(self.strip_enum_prefix(enum.name, value.name)), + used_names, + f"enum {enum.name}", + value.name, + ) + self._enum_value_identifier_cache[self._cache_key(enum)] = allocated + + def _allocate_scoped_union_identifiers(self, union: Union) -> None: + """Allocate unique sanitized variant names for the generated union and cache the results.""" + used_names: Dict[str, str] = {} + allocated: Dict[Tuple[object, ...], str] = {} + for field in union.fields: + allocated[self._cache_key(field)] = self._allocate_scoped_identifier( + self.to_pascal_case(field.name), + used_names, + f"union {union.name}", + field.name, + ) + self._union_case_identifier_cache[self._cache_key(union)] = allocated + + def _allocate_scoped_message_identifiers(self, message: Message) -> None: + """Allocate all scoped names that belong to the message.""" + used_fields: Dict[str, str] = {} + field_names: Dict[Tuple[object, ...], str] = {} + for field in message.fields: + field_names[self._cache_key(field)] = self._allocate_scoped_identifier( + self.to_snake_case(field.name), + used_fields, + f"message {message.name} fields", + field.name, + ) + self._field_identifier_cache[self._cache_key(message)] = field_names + nested_types: List[object] = ( + list(message.nested_enums) + + list(message.nested_unions) + + list(message.nested_messages) + ) + self._allocate_scoped_type_identifiers( + nested_types, f"message {message.name} types" + ) + self._allocate_scoped_module_identifiers( + list(message.nested_messages), f"message {message.name} modules" + ) + for nested_enum in message.nested_enums: + self._allocate_scoped_enum_identifiers(nested_enum) + for nested_union in message.nested_unions: + self._allocate_scoped_union_identifiers(nested_union) + for nested_message in message.nested_messages: + self._allocate_scoped_message_identifiers(nested_message) + + def _ensure_name_caches(self, schema: Schema) -> None: + """Construct the naming caches once for a schema file.""" + if not hasattr(self, "_named_schema_ids"): + # Init everything. + self._named_schema_ids: Set[int] = set() + self._type_identifier_cache: Dict[Tuple[object, ...], str] = {} + self._module_identifier_cache: Dict[Tuple[object, ...], str] = {} + self._field_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + self._enum_value_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + self._union_case_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + self._named_service_schema_ids: Set[int] = set() + self._service_trait_identifier_cache: Dict[Tuple[object, ...], str] = {} + self._service_client_module_identifier_cache: Dict[ + Tuple[object, ...], str + ] = {} + self._service_server_module_identifier_cache: Dict[ + Tuple[object, ...], str + ] = {} + self._service_name_constant_identifier_cache: Dict[ + Tuple[object, ...], str + ] = {} + self._rpc_method_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + self._rpc_stream_type_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + self._rpc_path_constant_identifier_cache: Dict[ + Tuple[object, ...], Dict[Tuple[object, ...], str] + ] = {} + schema_id = id(schema) + if schema_id in self._named_schema_ids: + # Cache exists. + return + enums, unions, messages = self._local_top_level_types(schema) + self._allocate_scoped_type_identifiers( + list(enums) + list(unions) + list(messages), "top-level Rust types" + ) + self._allocate_scoped_module_identifiers( + list(messages), "top-level Rust modules" + ) + for enum in enums: + self._allocate_scoped_enum_identifiers(enum) + for union in unions: + self._allocate_scoped_union_identifiers(union) + for message in messages: + self._allocate_scoped_message_identifiers(message) + self._named_schema_ids.add(schema_id) + def generate(self) -> List[GeneratedFile]: """Generate Rust files for the schema.""" files = [] + if self.options.grpc: + # Allocate and validate identifier naming for gRPC service definition. + self._ensure_name_caches(self.schema) + schema_id = id(self.schema) + if schema_id not in self._named_service_schema_ids: + schema_source_key = str(Path(self.schema.source_file).resolve()) + services = [ + service + for service in self.schema.services + if str(Path(service.location.file).resolve()) == schema_source_key + ] + used_traits: Dict[str, str] = {} + used_modules: Dict[str, str] = {} + used_constants: Dict[str, str] = {} + for service in services: + service_key = self._cache_key(service) + self._service_trait_identifier_cache[service_key] = ( + self._allocate_scoped_identifier( + self.to_pascal_case(service.name), + used_traits, + "Rust gRPC service traits", + service.name, + ) + ) + self._service_client_module_identifier_cache[service_key] = ( + self._allocate_scoped_identifier( + f"{self.to_snake_case(service.name)}_client", + used_modules, + "Rust gRPC service modules", + f"{service.name} client module", + ) + ) + self._service_server_module_identifier_cache[service_key] = ( + self._allocate_scoped_identifier( + f"{self.to_snake_case(service.name)}_server", + used_modules, + "Rust gRPC service modules", + f"{service.name} server module", + ) + ) + self._service_name_constant_identifier_cache[service_key] = ( + self._allocate_scoped_identifier( + f"{self.to_upper_snake_case(service.name)}_SERVICE_NAME", + used_constants, + "Rust gRPC service constants", + service.name, + ) + ) + used_methods: Dict[str, str] = {} + used_stream_types: Dict[str, str] = {} + method_names: Dict[Tuple[object, ...], str] = {} + stream_types: Dict[Tuple[object, ...], str] = {} + path_constants: Dict[Tuple[object, ...], str] = {} + for method in service.methods: + method_key = self._cache_key(method) + method_names[method_key] = self._allocate_scoped_identifier( + self.to_snake_case(method.name), + used_methods, + f"Rust gRPC service {service.name} methods", + method.name, + ) + if method.server_streaming: + stream_types[method_key] = self._allocate_scoped_identifier( + f"{self.to_pascal_case(method.name)}Stream", + used_stream_types, + f"Rust gRPC service {service.name} stream types", + method.name, + ) + path_constants[method_key] = self._allocate_scoped_identifier( + f"{self.to_upper_snake_case(service.name)}_" + f"{self.to_upper_snake_case(method.name)}_PATH", + used_constants, + "Rust gRPC service constants", + f"{service.name}.{method.name}", + ) + self._rpc_method_identifier_cache[service_key] = method_names + self._rpc_stream_type_identifier_cache[service_key] = stream_types + self._rpc_path_constant_identifier_cache[service_key] = ( + path_constants + ) + self._named_service_schema_ids.add(schema_id) # Generate a single module file with all types files.append(self.generate_module()) @@ -103,9 +550,11 @@ def generate(self) -> List[GeneratedFile]: def get_module_name(self) -> str: """Get the Rust module name.""" - if self.package: - return self.package.replace(".", "_") - return "generated" + module_name = self.get_top_level_module_identifier(self.package) + # e.g., when resolving the file for `pub mod r#type`, Rust looks for `type.rs`, not `r#type.rs`. + if module_name.startswith("r#"): + return module_name[2:] + return module_name def is_imported_type(self, type_def: object) -> bool: """Return True if a type definition comes from an imported IDL file.""" @@ -150,53 +599,98 @@ def _load_schema(self, file_path: str) -> Optional[Schema]: return schema def _module_name_for_schema(self, schema: Schema) -> str: - if schema.package: - return schema.package.replace(".", "_") - return "generated" + return self.get_top_level_module_identifier(schema.package) - def _module_name_for_type(self, type_def: object) -> Optional[str]: - location = getattr(type_def, "location", None) - file_path = getattr(location, "file", None) if location else None - schema = self._load_schema(file_path) + def _module_name_for_type(self, type_def: object) -> str: + schema = self._load_schema(type_def.location.file) if schema is None: - return None + return self.get_top_level_module_identifier( + self._package_for_source_file(type_def.location.file) + ) return self._module_name_for_schema(schema) + def _record_imported_module( + self, + module_sources: Dict[str, str], + ordered_modules: List[str], + module: str, + source: str, + ) -> None: + """Record an imported module and reject module-name collisions.""" + previous_source = module_sources.get(module) + if previous_source is not None: + if previous_source != source: + raise ValueError( + f"Rust module name collision: {previous_source!r} and " + f"{source!r} both map to Rust module {module!r}" + ) + return + module_sources[module] = source + ordered_modules.append(module) + def _collect_imported_modules(self) -> List[str]: - modules: Set[str] = set() + modules: Dict[str, str] = {} for type_def in self.schema.enums + self.schema.unions + self.schema.messages: if not self.is_imported_type(type_def): continue module = self._module_name_for_type(type_def) - if module: - modules.add(module) + source = type_def.location.file + previous_source = modules.get(module) + if previous_source is not None and previous_source != source: + raise ValueError( + f"Rust module name collision: {previous_source!r} and " + f"{source!r} both map to Rust module {module!r}" + ) + modules[module] = source ordered: List[str] = [] - used: Set[str] = set() - if self.schema.source_file: - base_dir = Path(self.schema.source_file).resolve().parent - for imp in self.schema.imports: - candidate = (base_dir / imp.path).resolve() - schema = self._load_schema(str(candidate)) - if schema is None: - continue + module_sources: Dict[str, str] = {} + base_dir = Path(self.schema.source_file).resolve().parent + for imp in self.schema.imports: + resolved_path = getattr(imp, "resolved_path", None) + candidate = ( + Path(resolved_path).resolve() + if resolved_path + else (base_dir / imp.path).resolve() + ) + schema = self._load_schema(str(candidate)) + if schema is None: + package = self.schema.source_packages.get(str(candidate)) + if str(candidate) not in self.schema.source_packages: + raise ValueError( + f"Rust generator cannot determine package for import " + f"{imp.path!r} resolved to {str(candidate)!r}" + ) + module = self.get_top_level_module_identifier(package) + else: module = self._module_name_for_schema(schema) - if module in used: - continue - ordered.append(module) - used.add(module) - for module in sorted(modules): - if module in used: - continue - ordered.append(module) + self._record_imported_module( + module_sources, ordered, module, str(candidate) + ) + for module, source in sorted(modules.items()): + self._record_imported_module(module_sources, ordered, module, source) return ordered - def _format_imported_type_name(self, type_name: str, module: str) -> str: - if "." in type_name: - parts = type_name.split(".") - parents = [self.to_snake_case(name) for name in parts[:-1]] - path = "::".join(parents + [self.to_pascal_case(parts[-1])]) + def _format_imported_type_name( + self, + type_name: str, + module: str, + type_def: object, + ) -> str: + type_path = self.schema.resolve_type_name(type_name) + if "." in type_path: + parts = type_path.split(".") + parents: List[str] = [] + schema = self._schema_for_node(type_def) + parent_messages = self._resolve_message_path(schema, parts[:-1]) + if parent_messages: + parents = [ + self.get_module_identifier(parent) for parent in parent_messages + ] + if not parents: + parents = [self.to_rust_snake(name) for name in parts[:-1]] + path = "::".join(parents + [self.get_type_identifier(type_def)]) return f"crate::{module}::{path}" - return f"crate::{module}::{self.to_pascal_case(type_name)}" + return f"crate::{module}::{self.get_type_identifier(type_def)}" def generate_bytes_impl(self, type_name: str) -> List[str]: lines = [] @@ -295,24 +789,31 @@ def get_module_path(self, parent_stack: Optional[List[Message]]) -> str: """Build module path from parent message names.""" if not parent_stack: return "" - return "::".join(self.to_snake_case(parent.name) for parent in parent_stack) + return "::".join(self.get_module_identifier(parent) for parent in parent_stack) - def get_type_path(self, name: str, parent_stack: Optional[List[Message]]) -> str: + def get_type_path( + self, type_def: object, parent_stack: Optional[List[Message]] + ) -> str: """Build a type path for nested types from the root module.""" module_path = self.get_module_path(parent_stack) + name = self.get_type_identifier(type_def) if module_path: return f"{module_path}::{name}" return name def build_relative_type_name( self, - current_parents: List[str], - target_parents: List[str], + current_parents: List[Message], + target_parents: List[Message], type_name: str, ) -> str: """Build a type path relative to the current module.""" - current_parts = [self.to_snake_case(name) for name in current_parents] - target_parts = [self.to_snake_case(name) for name in target_parents] + current_parts = [ + self.get_module_identifier(message) for message in current_parents + ] + target_parts = [ + self.get_module_identifier(message) for message in target_parents + ] common = 0 for left, right in zip(current_parts, target_parts): if left != right: @@ -338,7 +839,7 @@ def generate_enum( """Generate a Rust enum.""" lines = [] - type_name = enum.name + type_name = self.get_type_identifier(enum) # Derive macros lines.append( @@ -352,8 +853,10 @@ def generate_enum( for i, value in enumerate(enum.values): if i == 0: lines.append(" #[default]") - stripped_name = self.strip_enum_prefix(enum.name, value.name) - lines.append(f" {self.to_pascal_case(stripped_name)} = {value.value},") + value_name = self._enum_value_identifier_cache[self._cache_key(enum)][ + self._cache_key(value) + ] + lines.append(f" {value_name} = {value.value},") lines.append("}") @@ -367,8 +870,7 @@ def generate_union( """Generate a Rust tagged union.""" lines: List[str] = [] - if self.to_pascal_case(union.name) != union.name: - lines.append("#[allow(non_camel_case_types)]") + union_name = self.get_type_identifier(union) comment = self.format_type_id_comment(union, "//") if comment: lines.append(comment) @@ -377,12 +879,12 @@ def generate_union( if self.union_supports_trait(union, trait, parent_stack): derives.append(trait) lines.append(f"#[derive({', '.join(derives)})]") - lines.append(f"pub enum {union.name} {{") + lines.append(f"pub enum {union_name} {{") lines.append(" #[fory(unknown)]") lines.append(" Unknown(::fory::UnknownCase),") for index, field in enumerate(union.fields): - variant_name = self.to_pascal_case(field.name) + variant_name = self.get_union_case_identifier(union, field) pointer_type = self.get_field_pointer_type(field) variant_type = self.generate_type( field.field_type, @@ -413,7 +915,7 @@ def generate_union( if union.fields: default_field = union.fields[0] - default_variant = self.to_pascal_case(default_field.name) + default_variant = self.get_union_case_identifier(union, default_field) default_pointer_type = self.get_field_pointer_type(default_field) default_type = self.generate_type( default_field.field_type, @@ -427,7 +929,7 @@ def generate_union( default_type = self.qualify_union_payload_type( default_field.field_type, default_type, default_variant ) - lines.append(f"impl ::std::default::Default for {union.name} {{") + lines.append(f"impl ::std::default::Default for {union_name} {{") lines.append(" fn default() -> Self {") lines.append( f" Self::{default_variant}(<{default_type} as ::fory::ForyDefault>::fory_default())" @@ -436,7 +938,7 @@ def generate_union( lines.append("}") lines.append("") - lines.extend(self.generate_bytes_impl(union.name)) + lines.extend(self.generate_bytes_impl(union_name)) return lines @@ -458,7 +960,7 @@ def generate_message( """Generate a Rust struct.""" lines = [] - type_name = self.to_pascal_case(message.name) + type_name = self.get_type_identifier(message) # Derive macros comment = self.format_type_id_comment(message, "//") @@ -785,7 +1287,7 @@ def rust_string_literal(self, value: str) -> str: def generate_debug_impl(self, message: Message) -> List[str]: """Generate a Debug impl that avoids recursive ref expansion.""" lines: List[str] = [] - type_name = self.to_pascal_case(message.name) + type_name = self.get_type_identifier(message) lines.append(f"impl ::std::fmt::Debug for {type_name} {{") lines.append( " fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {" @@ -802,7 +1304,7 @@ def generate_debug_impl(self, message: Message) -> List[str]: ) lineage = self._lineage_for_message(message) for i, field in enumerate(message.fields): - field_name = self.to_snake_case(field.name) + field_name = self.get_field_identifier(message, field) if i > 0: lines.append(' f.write_str(", ")?;') lines.append( @@ -842,7 +1344,7 @@ def generate_nested_module( lines: List[str] = [] ind = self.indent_str * indent - module_name = self.to_snake_case(message.name) + module_name = self.get_module_identifier(message) lines.append(f"{ind}pub mod {module_name} {{") lines.append(f"{ind}{self.indent_str}use super::*;") lines.append("") @@ -918,7 +1420,7 @@ def generate_field( parent_stack=parent_stack, pointer_type=pointer_type, ) - field_name = self.to_snake_case(field.name) + field_name = self.get_field_identifier(parent_stack[-1], field) lines.append(f"pub {field_name}: {rust_type},") @@ -1026,12 +1528,19 @@ def generate_type( return base_type elif isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) - named_type = self.schema.get_type(field_type.name) - if named_type is not None and self.is_imported_type(named_type): + named_type = self.resolve_named_type(field_type.name, parent_stack) + if named_type is None: + raise ValueError(f"Unknown type {field_type.name!r}") + type_name = self.resolve_nested_type_name( + field_type.name, + named_type, + parent_stack, + ) + if self.is_imported_type(named_type): module = self._module_name_for_type(named_type) - if module: - type_name = self._format_imported_type_name(field_type.name, module) + type_name = self._format_imported_type_name( + field_type.name, module, named_type + ) if ref: type_name = f"{pointer_type}<{type_name}>" if nullable: @@ -1104,34 +1613,43 @@ def generate_type( map_type = f"::std::option::Option<{map_type}>" return map_type - return "()" + raise TypeError(f"Unsupported Rust field type: {field_type!r}") def resolve_nested_type_name( self, type_name: str, + type_def: object, parent_stack: Optional[List[Message]] = None, ) -> str: """Resolve nested type names to module-qualified Rust identifiers.""" - current_parents = [msg.name for msg in (parent_stack or [])[:-1]] - if "." in type_name: - parts = type_name.split(".") - target_parents = parts[:-1] - base_name = parts[-1] + current_parents = (parent_stack or [])[:-1] + type_path = self.schema.resolve_type_name(type_name) + if "." in type_path: + parts = type_path.split(".") + schema = self._schema_for_node(type_def) + target_parents = self._resolve_message_path(schema, parts[:-1]) + base_name = self.get_type_identifier(type_def) + if not target_parents: + down = [self.to_rust_snake(name) for name in parts[:-1]] + return "::".join(down + [base_name]) return self.build_relative_type_name( - current_parents, target_parents, self.to_pascal_case(base_name) + current_parents, + target_parents, + base_name, ) + resolved_name = self.get_type_identifier(type_def) if not parent_stack: - return self.to_pascal_case(type_name) + return resolved_name for i in range(len(parent_stack) - 1, -1, -1): message = parent_stack[i] if message.get_nested_type(type_name) is not None: - target_parents = [msg.name for msg in parent_stack[: i + 1]] + target_parents = parent_stack[: i + 1] return self.build_relative_type_name( - current_parents, target_parents, self.to_pascal_case(type_name) + current_parents, target_parents, resolved_name ) - return self.to_pascal_case(type_name) + return resolved_name def field_uses_pointer(self, field: Field) -> bool: if field.ref: @@ -1224,7 +1742,7 @@ def generate_enum_registration( parent_stack: Optional[List[Message]], ): """Generate registration code for an enum.""" - type_name = self.get_type_path(enum.name, parent_stack) + type_name = self.get_type_path(enum, parent_stack) reg_name = self.get_registration_type_name(enum.name, parent_stack) if self.should_register_by_id(enum): @@ -1242,7 +1760,7 @@ def generate_message_registration( parent_stack: Optional[List[Message]], ): """Generate registration code for a message and its nested types.""" - type_name = self.get_type_path(self.to_pascal_case(message.name), parent_stack) + type_name = self.get_type_path(message, parent_stack) reg_name = self.get_registration_type_name(message.name, parent_stack) # Register nested enums first @@ -1278,7 +1796,7 @@ def generate_union_registration( parent_stack: Optional[List[Message]], ): """Generate registration code for a union.""" - type_name = self.get_type_path(union.name, parent_stack) + type_name = self.get_type_path(union, parent_stack) reg_name = self.get_registration_type_name(union.name, parent_stack) if self.should_register_by_id(union): diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py index 24360fa019..a03fb2b4d4 100644 --- a/compiler/fory_compiler/ir/ast.py +++ b/compiler/fory_compiler/ir/ast.py @@ -18,7 +18,7 @@ """AST node definitions for FDL.""" from dataclasses import dataclass, field -from typing import List, Optional, Union as TypingUnion +from typing import Dict, List, Optional, Union as TypingUnion from fory_compiler.ir.types import PrimitiveKind @@ -331,6 +331,9 @@ class Schema: source_file: Optional[str] = None source_format: Optional[str] = None resolved_import_files: List[str] = field(default_factory=list) + source_packages: Dict[str, Optional[str]] = field( + default_factory=dict + ) # Source file path -> the package name it declares def __repr__(self) -> str: opts = f", options={len(self.options)}" if self.options else "" diff --git a/compiler/fory_compiler/tests/test_cli_flatbuffers_options.py b/compiler/fory_compiler/tests/test_cli_flatbuffers_options.py index 6cac73bfd2..d36f854076 100644 --- a/compiler/fory_compiler/tests/test_cli_flatbuffers_options.py +++ b/compiler/fory_compiler/tests/test_cli_flatbuffers_options.py @@ -43,6 +43,7 @@ def test_cli_swift_namespace_style_works_for_flatbuffers(tmp_path: Path): fbs_path, {"swift": swift_out}, swift_namespace_style="flatten", + generated_outputs={}, ) assert ok is True @@ -64,6 +65,7 @@ def test_cli_go_nested_type_style_is_accepted_for_flatbuffers(tmp_path: Path): fbs_path, {"go": go_out}, go_nested_type_style="camelcase", + generated_outputs={}, ) assert ok is True diff --git a/compiler/fory_compiler/tests/test_fbs_frontend.py b/compiler/fory_compiler/tests/test_fbs_frontend.py index 0a94d1d198..208e103360 100644 --- a/compiler/fory_compiler/tests/test_fbs_frontend.py +++ b/compiler/fory_compiler/tests/test_fbs_frontend.py @@ -102,6 +102,10 @@ def test_fbs_union_translation(): assert union.name == "Event" assert [f.name for f in union.fields] == ["foo", "bar"] assert [f.number for f in union.fields] == [1, 2] + assert [(f.location.line, f.location.column) for f in union.fields] == [ + (3, 19), + (3, 24), + ] def test_fbs_fory_ref_attributes(): diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index ac706a2f15..1eaa3c93d8 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -23,7 +23,7 @@ import pytest -from fory_compiler.cli import resolve_imports +from fory_compiler.cli import main as foryc_main, resolve_imports from fory_compiler.frontend.fbs import FBSFrontend from fory_compiler.frontend.fdl.lexer import Lexer from fory_compiler.frontend.fdl.parser import Parser @@ -490,8 +490,8 @@ def test_generated_code_map_types_equivalent(): assert_all_languages_equal(schemas) rust_output = render_files(generate_files(schemas["fdl"], RustGenerator)) - assert "RcWeak" in rust_output - assert "Option" in rust_output + assert "::fory::RcWeak" in rust_output + assert "::std::option::Option" in rust_output cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator)) assert "SharedWeak" in cpp_output @@ -849,7 +849,7 @@ def test_generated_code_tree_ref_options_equivalent(): assert_all_languages_equal(schemas) rust_output = render_files(generate_files(schemas["fdl"], RustGenerator)) - assert "ArcWeak" in rust_output + assert "::fory::ArcWeak" in rust_output assert "#[derive(::fory::ForyStruct, Clone, PartialEq, Eq, Default)]" in rust_output cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator)) @@ -1255,3 +1255,113 @@ def test_rust_union_conflicting_payload_uses_self_path(): "Self::Dog(::fory_default())" in rust_output ) assert "Dog(Dog)," not in rust_output + + +def test_rust_escapes_keywords(): + schema = parse_fdl( + dedent( + """ + package demo; + + message type { + string type = 1; + string self = 2; + string crate = 3; + string extern = 4; + string raw = 5; + } + + message _1 { + string value = 1; + } + """ + ) + ) + rust_files = generate_files(schema, RustGenerator) + rust_output = render_files(rust_files) + + assert "demo.rs" in rust_files + assert "pub struct Type {" in rust_output + assert "pub r#type: ::std::string::String," in rust_output + assert "pub self_: ::std::string::String," in rust_output + assert "pub crate_: ::std::string::String," in rust_output + assert "pub r#extern: ::std::string::String," in rust_output + assert "pub raw: ::std::string::String," in rust_output + assert "pub struct _1 {" in rust_output + + +def test_rust_rejects_normalized_name_collisions(): + collision_cases = [ + """ + message foo_bar {} + + message FooBar {} + """, + """ + message Holder { + string fooBar = 1; + string foo_bar = 2; + } + """, + """ + message Holder { + string self = 1; + string self_ = 2; + } + """, + """ + union crate { + string self = 1; + string Self = 2; + } + """, + ] + + for source in collision_cases: + schema = parse_fdl(dedent(source)) + with pytest.raises(ValueError, match="Rust name collision"): + generate_files(schema, RustGenerator) + + +def test_rust_rejects_same_output_path_collisions( + tmp_path: Path, capsys: pytest.CaptureFixture[str] +): + first_fdl = tmp_path / "first.fdl" + second_fdl = tmp_path / "second.fdl" + rust_out = tmp_path / "rust" + first_fdl.write_text( + dedent( + """ + package foo.bar; + + message First { + string value = 1; + } + """ + ) + ) + second_fdl.write_text( + dedent( + """ + package foo_bar; + + message Second { + string value = 1; + } + """ + ) + ) + exit_code = foryc_main( + [ + str(first_fdl), + str(second_fdl), + "--rust_out", + str(rust_out), + "-I", + str(tmp_path), + ] + ) + captured = capsys.readouterr() + + assert exit_code == 1 + assert "Rust output path collision" in captured.err diff --git a/compiler/fory_compiler/tests/test_service_codegen.py b/compiler/fory_compiler/tests/test_service_codegen.py index 413fbf0131..2850974a4c 100644 --- a/compiler/fory_compiler/tests/test_service_codegen.py +++ b/compiler/fory_compiler/tests/test_service_codegen.py @@ -565,8 +565,18 @@ def test_grpc_method_name_collisions_fail(): else: raise AssertionError("Expected Python gRPC method name collision") + rust_generator = RustGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + rust_generator.generate() + except ValueError as e: + assert "Rust name collision" in str(e) + else: + raise AssertionError("Expected Rust gRPC method name collision") + -def test_python_grpc_method_keywords_are_safe_names(): +def test_java_python_grpc_method_keywords_are_safe_names(): schema = parse_fdl( dedent( """ @@ -617,6 +627,29 @@ def test_python_grpc_service_registration_collisions_fail(): raise AssertionError("Expected Python gRPC service registration collision") +def test_rust_grpc_service_module_collisions_fail(): + schema = parse_fdl( + dedent( + """ + package demo.collision; + + service FooBar {} + service FooBAR {} + """ + ) + ) + + generator = RustGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + generator.generate() + except ValueError as e: + assert "Rust name collision" in str(e) + else: + raise AssertionError("Expected Rust gRPC service module collision") + + def test_default_package_java_grpc_output_path_and_service_name(): schema = parse_fdl( dedent( @@ -699,7 +732,7 @@ def test_compile_service_schema_with_grpc_flag(tmp_path: Path): lang_dirs = {} for lang in ("java", "python", "rust", "go", "cpp", "csharp", "swift"): lang_dirs[lang] = tmp_path / lang - ok = compile_file(example_path, lang_dirs, grpc=True) + ok = compile_file(example_path, lang_dirs, grpc=True, generated_outputs={}) assert ok is True for lang, lang_dir in lang_dirs.items(): files = [p for p in lang_dir.rglob("*") if p.is_file()] diff --git a/compiler/fory_compiler/tests/test_service_example.py b/compiler/fory_compiler/tests/test_service_example.py index 8523826e3b..5a6b49c59a 100644 --- a/compiler/fory_compiler/tests/test_service_example.py +++ b/compiler/fory_compiler/tests/test_service_example.py @@ -31,6 +31,7 @@ def test_service_example_compiles_for_java_and_python(tmp_path: Path): "java": java_out, "python": python_out, }, + generated_outputs={}, ) assert ok is True