From e546d9e7c3e2eea7305f382a67b653d6a3198b60 Mon Sep 17 00:00:00 2001 From: Morgan Thomas Date: Mon, 22 Dec 2025 16:36:40 +0000 Subject: [PATCH] 111: import statements --- crates/lean_compiler/src/a_simplify_lang.rs | 64 ++- .../src/b_compile_intermediate.rs | 29 +- crates/lean_compiler/src/c_compile_final.rs | 13 +- crates/lean_compiler/src/grammar.pest | 9 +- crates/lean_compiler/src/ir/instruction.rs | 8 +- crates/lean_compiler/src/lang.rs | 15 +- crates/lean_compiler/src/lib.rs | 12 +- crates/lean_compiler/src/parser/mod.rs | 21 +- .../src/parser/parsers/function.rs | 1 + .../lean_compiler/src/parser/parsers/mod.rs | 29 +- .../src/parser/parsers/program.rs | 132 +++++- crates/lean_compiler/tests/bar.snark | 3 + .../lean_compiler/tests/circular_import.snark | 1 + crates/lean_compiler/tests/foo.snark | 1 + crates/lean_compiler/tests/test_compiler.rs | 429 ++++++++++++++++-- crates/lean_prover/tests/hash_chain.rs | 2 +- crates/lean_prover/tests/test_zkvm.rs | 2 +- crates/lean_vm/src/core/types.rs | 38 +- crates/lean_vm/src/diagnostics/error.rs | 4 +- crates/lean_vm/src/diagnostics/profiler.rs | 10 +- crates/lean_vm/src/diagnostics/stack_trace.rs | 105 +++-- crates/lean_vm/src/execution/context.rs | 6 +- crates/lean_vm/src/execution/runner.rs | 21 +- crates/lean_vm/src/isa/bytecode.rs | 7 +- crates/lean_vm/src/isa/hint.rs | 17 +- crates/rec_aggregation/src/whir_recursion.rs | 5 +- crates/rec_aggregation/src/xmss_aggregate.rs | 5 +- 27 files changed, 805 insertions(+), 184 deletions(-) create mode 100644 crates/lean_compiler/tests/bar.snark create mode 100644 crates/lean_compiler/tests/circular_import.snark create mode 100644 crates/lean_compiler/tests/foo.snark diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 17f2abec..e1c36521 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -6,7 +6,7 @@ use crate::{ Expression, Function, Line, Program, Scope, 
SimpleExpr, Var, }, }; -use lean_vm::{Boolean, BooleanExpr, SourceLineNumber, Table, TableT}; +use lean_vm::{Boolean, BooleanExpr, FileId, SourceLineNumber, SourceLocation, Table, TableT}; use std::{ collections::{BTreeMap, BTreeSet}, fmt::{Display, Formatter}, @@ -21,6 +21,7 @@ pub struct SimpleProgram { #[derive(Debug, Clone)] pub struct SimpleFunction { pub name: String, + pub file_id: FileId, pub arguments: Vec, pub n_returned_vars: usize, pub instructions: Vec, @@ -146,7 +147,7 @@ pub enum SimpleLine { }, // noop, debug purpose only LocationReport { - location: SourceLineNumber, + location: SourceLocation, }, DebugAssert(BooleanExpr, SourceLineNumber), } @@ -190,6 +191,7 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { let mut array_manager = ArrayManager::default(); let simplified_instructions = simplify_lines( &program.functions, + func.file_id, func.n_returned_vars, &func.body, &mut counters, @@ -209,6 +211,7 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { .collect::>(); let simplified_function = SimpleFunction { name: name.clone(), + file_id: func.file_id, arguments, n_returned_vars: func.n_returned_vars, instructions: simplified_instructions, @@ -600,6 +603,7 @@ impl ArrayManager { #[allow(clippy::too_many_arguments)] fn simplify_lines( functions: &BTreeMap, + file_id: FileId, n_returned_vars: usize, lines: &[Line], counters: &mut Counters, @@ -622,6 +626,7 @@ fn simplify_lines( assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); simple_arms.push(simplify_lines( functions, + file_id, n_returned_vars, statements, counters, @@ -812,6 +817,7 @@ fn simplify_lines( let mut array_manager_then = array_manager.clone(); let then_branch_simplified = simplify_lines( functions, + file_id, n_returned_vars, then_branch, counters, @@ -826,6 +832,7 @@ fn simplify_lines( let else_branch_simplified = simplify_lines( functions, + file_id, n_returned_vars, else_branch, counters, @@ -874,6 +881,7 @@ fn 
simplify_lines( array_manager.valid.clear(); let simplified_body = simplify_lines( functions, + file_id, 0, body, counters, @@ -938,7 +946,10 @@ fn simplify_lines( // Create recursive function body let recursive_func = create_recursive_function( func_name.clone(), - *line_number, + SourceLocation { + line_number: *line_number, + file_id, + }, func_args, iterator.clone(), end_simplified, @@ -1122,7 +1133,12 @@ fn simplify_lines( res.push(SimpleLine::Panic); } Line::LocationReport { location } => { - res.push(SimpleLine::LocationReport { location: *location }); + res.push(SimpleLine::LocationReport { + location: SourceLocation { + line_number: *location, + file_id, + }, + }); } } } @@ -1161,10 +1177,7 @@ fn simplify_expr( ); } } - panic!( - "Const array '{}' can only be accessed with compile-time constant indices", - array_var - ); + panic!("Const array '{array_var}' can only be accessed with compile-time constant indices",); } if let SimpleExpr::Var(array_var) = array @@ -1672,7 +1685,7 @@ fn handle_array_assignment( fn create_recursive_function( name: String, - line_number: SourceLineNumber, + location: SourceLocation, args: Vec, iterator: Var, end: SimpleExpr, @@ -1696,7 +1709,7 @@ fn create_recursive_function( function_name: name.clone(), args: recursive_args, return_data: vec![], - line_number, + line_number: location.line_number, }); body.push(SimpleLine::FunctionRet { return_data: vec![] }); @@ -1713,12 +1726,13 @@ fn create_recursive_function( condition: diff_var.into(), then_branch: body, else_branch: vec![SimpleLine::FunctionRet { return_data: vec![] }], - line_number, + line_number: location.line_number, }, ]; SimpleFunction { name, + file_id: location.file_id, arguments: args, n_returned_vars: 0, instructions, @@ -2109,6 +2123,7 @@ fn handle_const_arguments(program: &mut Program) -> bool { for func in program.functions.values_mut() { if !func.has_const_arguments() { any_changes |= handle_const_arguments_helper( + func.file_id, &mut func.body, 
&constant_functions, &mut new_functions, @@ -2133,6 +2148,7 @@ fn handle_const_arguments(program: &mut Program) -> bool { if let Some(func) = new_functions.get_mut(&name) { let initial_count = additional_functions.len(); handle_const_arguments_helper( + func.file_id, &mut func.body, &constant_functions, &mut additional_functions, @@ -2169,6 +2185,7 @@ fn handle_const_arguments(program: &mut Program) -> bool { } fn handle_const_arguments_helper( + file_id: FileId, lines: &mut [Line], constant_functions: &BTreeMap, new_functions: &mut BTreeMap, @@ -2228,6 +2245,7 @@ fn handle_const_arguments_helper( const_funct_name.clone(), Function { name: const_funct_name, + file_id, arguments: func .arguments .iter() @@ -2247,15 +2265,29 @@ fn handle_const_arguments_helper( else_branch, .. } => { - changed |= handle_const_arguments_helper(then_branch, constant_functions, new_functions, const_arrays); - changed |= handle_const_arguments_helper(else_branch, constant_functions, new_functions, const_arrays); + changed |= handle_const_arguments_helper( + file_id, + then_branch, + constant_functions, + new_functions, + const_arrays, + ); + changed |= handle_const_arguments_helper( + file_id, + else_branch, + constant_functions, + new_functions, + const_arrays, + ); } Line::ForLoop { body, unroll: _, .. } => { - changed |= handle_const_arguments_helper(body, constant_functions, new_functions, const_arrays); + // TODO we should unroll before const arguments handling + handle_const_arguments_helper(file_id, body, constant_functions, new_functions, const_arrays); } Line::Match { arms, .. } => { for (_, arm) in arms { - changed |= handle_const_arguments_helper(arm, constant_functions, new_functions, const_arrays); + changed |= + handle_const_arguments_helper(file_id, arm, constant_functions, new_functions, const_arrays); } } _ => {} @@ -2565,7 +2597,7 @@ impl SimpleLine { Self::Panic => "panic".to_string(), Self::LocationReport { .. 
} => Default::default(), Self::DebugAssert(bool, _) => { - format!("debug_assert({})", bool) + format!("debug_assert({bool})") } }; format!("{spaces}{line_str}") diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index a3c944af..02a69811 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -145,6 +145,7 @@ fn compile_function( compiler.args_count = function.arguments.len(); compile_lines( + function.file_id, &Label::function(function.name.clone()), &function.instructions, compiler, @@ -153,6 +154,7 @@ fn compile_function( } fn compile_lines( + file_id: FileId, function_name: &Label, lines: &[SimpleLine], compiler: &mut Compiler, @@ -220,7 +222,8 @@ fn compile_lines( for arm in arms.iter() { compiler.stack_pos = saved_stack_pos; compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let arm_instructions = compile_lines(function_name, arm, compiler, Some(end_label.clone()))?; + let arm_instructions = + compile_lines(file_id, function_name, arm, compiler, Some(end_label.clone()))?; compiled_arms.push(arm_instructions); compiler.stack_frame_layout.scopes.pop(); new_stack_pos = new_stack_pos.max(compiler.stack_pos); @@ -257,7 +260,7 @@ fn compile_lines( updated_fp: None, }); - let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?; + let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?; compiler.bytecode.insert(end_label, remaining); compiler.stack_frame_layout.scopes.pop(); @@ -349,14 +352,16 @@ fn compile_lines( let saved_stack_pos = compiler.stack_pos; compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let then_instructions = compile_lines(function_name, then_branch, compiler, Some(end_label.clone()))?; + let then_instructions = + compile_lines(file_id, function_name, then_branch, compiler, Some(end_label.clone()))?; let then_stack_pos 
= compiler.stack_pos; compiler.stack_pos = saved_stack_pos; compiler.stack_frame_layout.scopes.pop(); compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let else_instructions = compile_lines(function_name, else_branch, compiler, Some(end_label.clone()))?; + let else_instructions = + compile_lines(file_id, function_name, else_branch, compiler, Some(end_label.clone()))?; compiler.bytecode.insert(if_label, then_instructions); compiler.bytecode.insert(else_label, else_instructions); @@ -364,7 +369,7 @@ fn compile_lines( compiler.stack_frame_layout.scopes.pop(); compiler.stack_pos = compiler.stack_pos.max(then_stack_pos); - let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?; + let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?; compiler.bytecode.insert(end_label, remaining); // It is not necessary to update compiler.stack_size here because the preceding call to // compile_lines should have done so. @@ -438,7 +443,13 @@ fn compile_lines( }); } - instructions.extend(compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?); + instructions.extend(compile_lines( + file_id, + function_name, + &lines[i + 1..], + compiler, + final_jump, + )?); instructions }; @@ -584,7 +595,11 @@ fn compile_lines( left: IntermediateValue::from_simple_expr(&boolean.left, compiler), right: IntermediateValue::from_simple_expr(&boolean.right, compiler), }; - instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, *line_number)); + let location = SourceLocation { + file_id, + line_number: *line_number, + }; + instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, location)); } } } diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 44baab53..1d3ec630 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -34,8 +34,9 @@ struct Compiler { pub fn 
compile_to_low_level_bytecode( mut intermediate_bytecode: IntermediateBytecode, - program: String, - function_locations: BTreeMap, + function_locations: BTreeMap, + source_code: BTreeMap, + filepaths: BTreeMap, ) -> Result { intermediate_bytecode.bytecode.insert( Label::EndProgram, @@ -45,7 +46,10 @@ pub fn compile_to_low_level_bytecode( }], ); - let starting_frame_memory = *intermediate_bytecode.memory_size_per_function.get("main").unwrap(); + let starting_frame_memory = *intermediate_bytecode + .memory_size_per_function + .get("main") + .expect("Missing main function"); let mut hints = BTreeMap::new(); let mut label_to_pc = BTreeMap::new(); @@ -145,8 +149,9 @@ pub fn compile_to_low_level_bytecode( instructions, hints, starting_frame_memory, - program, function_locations, + source_code, + filepaths, }) } diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 222cc9ef..e5a83432 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -1,7 +1,10 @@ WHITESPACE = _{ " " | "\t" | "\n" | "\r" } // Program structure -program = { SOI ~ constant_declaration* ~ function+ ~ EOI } +program = { SOI ~ import_statement* ~ constant_declaration* ~ function* ~ EOI } + +// Imports +import_statement = { "import" ~ filepath ~ ";" } // Constants constant_declaration = { "const" ~ identifier ~ "=" ~ (array_literal | expression) ~ ";" } @@ -108,4 +111,6 @@ constant_value = { number | "public_input_start" | "pointer_to_zero_vector" | "p // Lexical elements identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } -number = @{ ASCII_DIGIT+ } \ No newline at end of file +number = @{ ASCII_DIGIT+ } +filepath = { "\"" ~ filepath_character* ~ "\"" } +filepath_character = { ASCII_ALPHANUMERIC | "-" | "_" | " " | "." 
| "+" | "/" } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index ca5c0a32..b9d49619 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -1,7 +1,7 @@ use super::operation::HighLevelOperation; use super::value::{IntermediaryMemOrFpOrConstant, IntermediateValue}; use crate::lang::ConstExpression; -use lean_vm::{BooleanExpr, Operation, SourceLineNumber, Table, TableT}; +use lean_vm::{BooleanExpr, Operation, SourceLocation, Table, TableT}; use std::fmt::{Display, Formatter}; /// Core instruction type for the intermediate representation. @@ -69,9 +69,9 @@ pub enum IntermediateInstruction { }, // noop, debug purpose only LocationReport { - location: SourceLineNumber, + location: SourceLocation, }, - DebugAssert(BooleanExpr, SourceLineNumber), + DebugAssert(BooleanExpr, SourceLocation), } impl IntermediateInstruction { @@ -220,7 +220,7 @@ impl Display for IntermediateInstruction { } Self::LocationReport { .. 
} => Ok(()), Self::DebugAssert(boolean_expr, _) => { - write!(f, "debug_assert {}", boolean_expr) + write!(f, "debug_assert {boolean_expr}") } } } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index fe5ae910..0dd60e8b 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -6,16 +6,21 @@ use std::fmt::{Display, Formatter}; use utils::ToUsize; use crate::{F, ir::HighLevelOperation}; +pub use lean_vm::{FileId, FunctionName, SourceLocation}; #[derive(Debug, Clone)] pub struct Program { - pub functions: BTreeMap, + pub functions: BTreeMap, pub const_arrays: BTreeMap>, + pub function_locations: BTreeMap, + pub source_code: BTreeMap, + pub filepaths: BTreeMap, } #[derive(Debug, Clone)] pub struct Function { pub name: String, + pub file_id: FileId, pub arguments: Vec<(Var, bool)>, // (name, is_const) pub inlined: bool, pub n_returned_vars: usize, @@ -369,8 +374,8 @@ pub enum AssignmentTarget { impl Display for AssignmentTarget { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Var(var) => write!(f, "{}", var), - Self::ArrayAccess { array, index } => write!(f, "{}[{}]", array, index), + Self::Var(var) => write!(f, "{var}"), + Self::ArrayAccess { array, index } => write!(f, "{array}[{index}]"), } } } @@ -500,7 +505,7 @@ impl Display for Expression { } Self::MathExpr(math_expr, args) => { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); - write!(f, "{}({})", math_expr, args_str) + write!(f, "{math_expr}({args_str})") } } } @@ -714,7 +719,7 @@ impl Display for ConstExpression { } Self::MathExpr(math_expr, args) => { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); - write!(f, "{}({})", math_expr, args_str) + write!(f, "{math_expr}({args_str})") } } } diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 4be13793..9840c98c 100644 --- a/crates/lean_compiler/src/lib.rs +++ 
b/crates/lean_compiler/src/lib.rs @@ -12,9 +12,12 @@ pub mod ir; mod lang; mod parser; -pub fn compile_program(program: String) -> Bytecode { - let (parsed_program, function_locations) = parse_program(&program).unwrap(); +pub fn compile_program(filepath: &str, program: String) -> Bytecode { + let parsed_program = parse_program(filepath, &program).unwrap(); // println!("Parsed program: {}", parsed_program.to_string()); + let function_locations = parsed_program.function_locations.clone(); + let source_code = parsed_program.source_code.clone(); + let filepaths = parsed_program.filepaths.clone(); let simple_program = simplify_program(parsed_program); // println!("Simplified program: {}", simple_program); let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program).unwrap(); @@ -25,18 +28,19 @@ pub fn compile_program(program: String) -> Bytecode { // println!("{name}: {loc}"); // } /* let compiled = */ - compile_to_low_level_bytecode(intermediate_bytecode, program, function_locations).unwrap() // ; + compile_to_low_level_bytecode(intermediate_bytecode, function_locations, source_code, filepaths).unwrap() // ; // println!("\n\nCompiled Program:\n\n{compiled}"); // compiled } pub fn compile_and_run( + filepath: &str, program: String, (public_input, private_input): (&[F], &[F]), no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory profiler: bool, ) { - let bytecode = compile_program(program); + let bytecode = compile_program(filepath, program); let summary = execute_bytecode( &bytecode, (public_input, private_input), diff --git a/crates/lean_compiler/src/parser/mod.rs b/crates/lean_compiler/src/parser/mod.rs index 19a25b07..c1a4b9a0 100644 --- a/crates/lean_compiler/src/parser/mod.rs +++ b/crates/lean_compiler/src/parser/mod.rs @@ -5,23 +5,4 @@ mod grammar; mod lexer; mod parsers; -pub use grammar::parse_source; -pub use parsers::program::ProgramParser; -pub use parsers::{Parse, ParseContext}; - -use crate::lang::Program; -use 
crate::parser::error::ParseError; -use std::collections::BTreeMap; - -/// Main entry point for parsing Lean programs. -pub fn parse_program(input: &str) -> Result<(Program, BTreeMap), ParseError> { - // Preprocess source to remove comments - let processed_input = lexer::preprocess_source(input); - - // Parse grammar into AST nodes - let program_pair = parse_source(&processed_input)?; - - // Parse into semantic structures - let mut ctx = ParseContext::new(); - ProgramParser.parse(program_pair, &mut ctx) -} +pub use parsers::program::parse_program; diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 739b0d89..219a0480 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -55,6 +55,7 @@ impl Parse for FunctionParser { Ok(Function { name, + file_id: ctx.current_file_id, arguments, inlined, n_returned_vars, diff --git a/crates/lean_compiler/src/parser/parsers/mod.rs b/crates/lean_compiler/src/parser/parsers/mod.rs index 9adaaa6c..3ac446d3 100644 --- a/crates/lean_compiler/src/parser/parsers/mod.rs +++ b/crates/lean_compiler/src/parser/parsers/mod.rs @@ -1,8 +1,9 @@ +use crate::lang::FileId; use crate::parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, }; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; pub mod expression; pub mod function; @@ -26,14 +27,29 @@ pub struct ParseContext { pub const_arrays: BTreeMap>, /// Counter for generating unique trash variable names pub trash_var_count: usize, + /// Filepath of the file we are currently parsing + pub current_filepath: String, + /// Source code of the file we are currently parsing + pub current_source_code: String, + /// File ID of the file we are currently parsing + pub current_file_id: FileId, + /// Absolute filepaths imported so far (also includes the root filepath) + pub imported_filepaths: BTreeSet, + /// Next unused 
file ID + pub next_file_id: usize, } impl ParseContext { - pub const fn new() -> Self { + pub fn new(current_filepath: &str, current_source_code: &str) -> Self { Self { constants: BTreeMap::new(), const_arrays: BTreeMap::new(), trash_var_count: 0, + current_filepath: current_filepath.to_string(), + current_file_id: 0, + imported_filepaths: BTreeSet::new(), + current_source_code: current_source_code.to_string(), + next_file_id: 1, } } @@ -78,11 +94,18 @@ impl ParseContext { self.trash_var_count += 1; format!("@trash_{}", self.trash_var_count) } + + /// Returns a fresh file id. + pub fn get_next_file_id(&mut self) -> FileId { + let file_id = self.next_file_id; + self.next_file_id += 1; + file_id + } } impl Default for ParseContext { fn default() -> Self { - Self::new() + Self::new("", "") } } diff --git a/crates/lean_compiler/src/parser/parsers/program.rs b/crates/lean_compiler/src/parser/parsers/program.rs index 19a46128..a1500894 100644 --- a/crates/lean_compiler/src/parser/parsers/program.rs +++ b/crates/lean_compiler/src/parser/parsers/program.rs @@ -1,36 +1,82 @@ use super::function::FunctionParser; use super::literal::ConstantDeclarationParser; -use super::{Parse, ParseContext, ParsedConstant}; use crate::{ - lang::Program, + lang::{Program, SourceLocation}, parser::{ - error::{ParseResult, SemanticError}, - grammar::{ParsePair, Rule}, + error::{ParseError, ParseResult, SemanticError}, + grammar::{ParsePair, Rule, parse_source}, + lexer, + parsers::{Parse, ParseContext, ParsedConstant, next_inner_pair}, }, }; use std::collections::BTreeMap; +use std::path::Path; /// Parser for complete programs. 
pub struct ProgramParser; -impl Parse<(Program, BTreeMap)> for ProgramParser { - fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult<(Program, BTreeMap)> { - let mut ctx = ParseContext::new(); +impl Parse for ProgramParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut functions = BTreeMap::new(); let mut function_locations = BTreeMap::new(); + let mut source_code = BTreeMap::new(); + let mut filepaths = BTreeMap::new(); + let file_id = ctx.get_next_file_id(); + ctx.current_file_id = file_id; + filepaths.insert(file_id, ctx.current_filepath.clone()); + source_code.insert(file_id, ctx.current_source_code.clone()); for item in pair.into_inner() { match item.as_rule() { Rule::constant_declaration => { - let (name, value) = ConstantDeclarationParser.parse(item, &mut ctx)?; + let (name, value) = ConstantDeclarationParser.parse(item, ctx)?; match value { ParsedConstant::Scalar(v) => ctx.add_constant(name, v)?, ParsedConstant::Array(arr) => ctx.add_const_array(name, arr)?, } } + Rule::import_statement => { + // Visit the imported file and parse it into the context + // and program; also keep track of which files have been + // imported and do not import the same file twice. 
+ let filepath = ImportStatementParser.parse(item, ctx)?; + let filepath = Path::new(&ctx.current_filepath) + .parent() + .expect("Empty filepath") + .join(filepath) + .to_str() + .expect("Invalid UTF-8 in filepath") + .to_string(); + if !ctx.imported_filepaths.contains(&filepath) { + let saved_filepath = ctx.current_filepath.clone(); + let saved_file_id = ctx.current_file_id; + ctx.current_filepath = filepath.clone(); + ctx.imported_filepaths.insert(filepath.clone()); + let file_id = ctx.get_next_file_id(); + ctx.current_file_id = file_id; + filepaths.insert(file_id, filepath.clone()); + let input = std::fs::read_to_string(filepath.clone()).map_err(|_| { + SemanticError::with_context( + format!("Imported file not found: {filepath}"), + "import declaration", + ) + })?; + source_code.insert(file_id, input.clone()); + let subprogram = parse_program_helper(filepath.as_str(), input.as_str(), ctx)?; + functions.extend(subprogram.functions); + function_locations.extend(subprogram.function_locations); + source_code.extend(subprogram.source_code); + filepaths.extend(subprogram.filepaths); + ctx.current_filepath = saved_filepath; + ctx.current_file_id = saved_file_id; + // It is unnecessary to save and restore current_source_code because it will not + // be referenced again for the same file. + } + } Rule::function => { - let location = item.line_col().0; - let function = FunctionParser.parse(item, &mut ctx)?; + let line_number = item.line_col().0; + let location = SourceLocation { file_id, line_number }; + let function = FunctionParser.parse(item, ctx)?; let name = function.name.clone(); function_locations.insert(location, name.clone()); @@ -48,12 +94,66 @@ impl Parse<(Program, BTreeMap)> for ProgramParser { } } - Ok(( - Program { - functions, - const_arrays: ctx.const_arrays, - }, + Ok(Program { + functions, + const_arrays: ctx.const_arrays.clone(), function_locations, - )) + filepaths, + source_code, + }) + } +} + +/// Parser for import statements. 
+pub struct ImportStatementParser; + +impl Parse for ImportStatementParser { + fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult { + let mut inner = pair.into_inner(); + let item = next_inner_pair(&mut inner, "filepath")?; + match item.as_rule() { + Rule::filepath => { + let inner = item.into_inner(); + let mut filepath = String::new(); + for item in inner { + match item.as_rule() { + Rule::filepath_character => { + filepath.push_str(item.as_str()); + } + _ => { + return Err(SemanticError::with_context( + format!("Expected a filepath character, got: {}", item.as_str()), + "filepath character", + ) + .into()); + } + } + } + Ok(filepath) + } + _ => Err( + SemanticError::with_context(format!("Expected a filepath, got: {}", item.as_str()), "filepath").into(), + ), + } } } + +fn parse_program_helper(filepath: &str, input: &str, ctx: &mut ParseContext) -> Result { + // Preprocess source to remove comments + let processed_input = lexer::preprocess_source(input); + + // Parse grammar into AST nodes + let program_pair = parse_source(&processed_input)?; + + // Parse into semantic structures + ctx.current_filepath = filepath.to_string(); + ctx.current_source_code = input.to_string(); + ctx.imported_filepaths.insert(filepath.to_string()); + ProgramParser.parse(program_pair, ctx) +} + +pub fn parse_program(filepath: &str, input: &str) -> Result { + let mut ctx = ParseContext::new(filepath, input); + ctx.imported_filepaths.insert(filepath.to_string()); + parse_program_helper(filepath, input, &mut ctx) +} diff --git a/crates/lean_compiler/tests/bar.snark b/crates/lean_compiler/tests/bar.snark new file mode 100644 index 00000000..563193a7 --- /dev/null +++ b/crates/lean_compiler/tests/bar.snark @@ -0,0 +1,3 @@ +fn bar(x) -> 1 { + return x * 2; +} diff --git a/crates/lean_compiler/tests/circular_import.snark b/crates/lean_compiler/tests/circular_import.snark new file mode 100644 index 00000000..e7df5cad --- /dev/null +++ 
b/crates/lean_compiler/tests/circular_import.snark @@ -0,0 +1 @@ +import "circular_import.snark"; diff --git a/crates/lean_compiler/tests/foo.snark b/crates/lean_compiler/tests/foo.snark new file mode 100644 index 00000000..41987a0a --- /dev/null +++ b/crates/lean_compiler/tests/foo.snark @@ -0,0 +1 @@ +const FOO = 3; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 9dc6eeac..f69ec23a 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -21,7 +21,13 @@ fn test_duplicate_function_name() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -35,7 +41,13 @@ fn test_duplicate_constant_name() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -51,7 +63,13 @@ fn test_wrong_n_returned_vars_1() { return 0; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -67,7 +85,13 @@ fn test_wrong_n_returned_vars_2() { return 0, 1; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -86,7 +110,13 @@ fn test_no_return() { return 0; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -106,7 +136,13 @@ fn test_assumed_return() { } } "#; - 
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -127,7 +163,13 @@ fn test_fibonacci_program() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -145,7 +187,13 @@ fn test_edge_case_0() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -158,7 +206,13 @@ fn test_edge_case_1() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -176,7 +230,13 @@ fn test_edge_case_2() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -192,7 +252,13 @@ fn test_decompose_bits() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -208,7 +274,13 @@ fn test_unroll() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -220,7 +292,13 @@ fn test_rev_unroll() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + 
program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -240,7 +318,13 @@ fn test_mini_program_0() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -279,7 +363,13 @@ fn test_mini_program_1() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -307,7 +397,13 @@ fn test_mini_program_2() { return sum, product; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -335,6 +431,7 @@ fn test_mini_program_3() { "#; let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); compile_and_run( + "", program.to_string(), (&public_input, &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, @@ -362,6 +459,7 @@ fn test_mini_program_4() { "#; let public_input: [F; 24] = (0..24).map(F::new).collect::>().try_into().unwrap(); compile_and_run( + "", program.to_string(), (&public_input, &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, @@ -445,7 +543,13 @@ fn test_inlined() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -465,7 +569,13 @@ fn test_inlined_2() { } } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -517,7 +627,13 @@ fn test_match() { return x * x * x * x * x * x; } "#; - 
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -540,7 +656,13 @@ fn test_match_shrink() { return x * x; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } // #[test] @@ -581,7 +703,13 @@ fn test_const_functions_calling_const_functions() { } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -604,7 +732,13 @@ fn test_inline_functions_calling_inline_functions() { } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -631,7 +765,13 @@ fn test_nested_inline_functions() { } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -654,7 +794,13 @@ fn test_const_and_nonconst_malloc_sharing_name() { } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -669,7 +815,13 @@ fn test_debug_assert_eq() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[should_panic] @@ -683,7 +835,13 @@ fn test_debug_assert_eq_fail() { return; } "#; - 
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[should_panic] @@ -697,7 +855,13 @@ fn test_debug_assert_not_eq_fail() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[should_panic] @@ -711,7 +875,13 @@ fn test_debug_assert_lt_fail() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -727,7 +897,13 @@ fn test_next_multiple_of() { return next_multiple_of(n, n) * 2; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -764,7 +940,13 @@ fn test_const_array() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -782,7 +964,13 @@ fn test_const_malloc_end_iterator_loop() { return; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -812,7 +1000,13 @@ fn test_array_return_targets() { return 42, 99; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -836,7 +1030,13 @@ fn test_array_return_targets_with_expressions() { return 
n, n * 2; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -879,7 +1079,13 @@ fn intertwined_unrolled_loops_and_const_function_arguments() { return buff[4]; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } #[test] @@ -902,5 +1108,158 @@ fn test_const_fibonacci() { return a + b; } "#; - compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +#[should_panic] +fn test_undefined_import() { + let program = r#" + import "asdfasdfadsfasdf.snark"; + + fn main() { + return; + } + "#; + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +#[should_panic] +fn test_imported_function_name_clash() { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let self_path = format!("{manifest_dir}/tests/test_compiler.rs"); + let program = r#" + import "bar.snark"; + import "foo.snark"; + + fn bar() { + return; + } + + fn main() { + return; + } + "#; + compile_and_run( + self_path.as_str(), + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +#[should_panic] +fn test_imported_constant_name_clash() { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let self_path = format!("{manifest_dir}/tests/test_compiler.rs"); + let program = r#" + import "bar.snark"; + import "foo.snark"; + + const FOO = 5; + + fn main() { + return; + } + "#; + compile_and_run( + self_path.as_str(), + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +fn test_double_import_tolerance() 
{ + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let self_path = format!("{manifest_dir}/tests/test_compiler.rs"); + let program = r#" + import "foo.snark"; + import "foo.snark"; + + fn main() { + return; + } + "#; + compile_and_run( + self_path.as_str(), + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +fn test_circular_import_tolerance() { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let self_path = format!("{manifest_dir}/tests/test_compiler.rs"); + let program = r#" + import "circular_import.snark"; + + fn main() { + return; + } + "#; + compile_and_run( + self_path.as_str(), + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +#[should_panic] +fn test_no_main() { + let program = r#" + "#; + compile_and_run( + "", + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + +#[test] +fn test_imports() { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let self_path = format!("{manifest_dir}/tests/test_compiler.rs"); + let program = r#" + import "bar.snark"; + import "foo.snark"; + + fn main() { + x = bar(FOO); + assert x == 6; + return; + } + "#; + compile_and_run( + self_path.as_str(), + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); } diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs index bd8821e0..c56fcf45 100644 --- a/crates/lean_prover/tests/hash_chain.rs +++ b/crates/lean_prover/tests/hash_chain.rs @@ -56,7 +56,7 @@ fn benchmark_poseidon_chain() { let private_input = vec![]; utils::init_tracing(); - let bytecode = compile_program(program_str); + let bytecode = compile_program("", program_str); let no_vec_runtime_memory = execute_bytecode( &bytecode, (&public_input, &private_input), diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs index 0c1239e1..06b31f39 100644 --- a/crates/lean_prover/tests/test_zkvm.rs 
+++ b/crates/lean_prover/tests/test_zkvm.rs @@ -178,7 +178,7 @@ fn test_zk_vm_helper( merkle_path_hints: VecDeque>, ) { utils::init_tracing(); - let bytecode = compile_program(program_str.to_string()); + let bytecode = compile_program("", program_str.to_string()); let time = std::time::Instant::now(); let (proof, summary) = prove_execution( &bytecode, diff --git a/crates/lean_vm/src/core/types.rs b/crates/lean_vm/src/core/types.rs index 8538abb6..2948f3b5 100644 --- a/crates/lean_vm/src/core/types.rs +++ b/crates/lean_vm/src/core/types.rs @@ -1,4 +1,6 @@ +use derive_more::Display; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; +use std::cmp::Ordering; /// Base field type for VM operations pub type F = KoalaBear; @@ -6,7 +8,7 @@ pub type F = KoalaBear; /// Extension field type for VM operations pub type EF = QuinticExtensionFieldKB; -/// Location in source code for debugging +/// Line number in source code for debugging pub type SourceLineNumber = usize; /// Bytecode address (i.e., a value of the program counter) @@ -14,3 +16,37 @@ pub type CodeAddress = usize; /// Memory address pub type MemoryAddress = usize; + +/// Source code function name +pub type FunctionName = String; + +/// Unique identifier for a file in a compilation +pub type FileId = usize; + +/// Location in source code +#[derive(Display, Hash, PartialEq, Eq, Debug, Clone, Copy)] +#[display("{}:{}", file_id, line_number)] +pub struct SourceLocation { + pub file_id: FileId, + pub line_number: SourceLineNumber, +} + +fn cmp_source_location(a: &SourceLocation, b: &SourceLocation) -> Ordering { + match a.file_id.cmp(&b.file_id) { + Ordering::Less => Ordering::Less, + Ordering::Greater => Ordering::Greater, + Ordering::Equal => a.line_number.cmp(&b.line_number), + } +} + +impl PartialOrd for SourceLocation { + fn partial_cmp(&self, other: &SourceLocation) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SourceLocation { + fn cmp(&self, other: &SourceLocation) -> Ordering { + 
cmp_source_location(self, other) + } +} diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 78885a20..bb1d62bc 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -1,4 +1,4 @@ -use crate::core::F; +use crate::core::{F, SourceLocation}; use crate::diagnostics::profiler::MemoryProfile; use crate::execution::Memory; use crate::{TableTrace, error}; @@ -32,7 +32,7 @@ pub enum RunnerError { PCOutOfBounds, #[error("DebugAssert failed: {0} at line {1}")] - DebugAssertFailed(String, usize), + DebugAssertFailed(String, SourceLocation), } pub type VMResult = Result; diff --git a/crates/lean_vm/src/diagnostics/profiler.rs b/crates/lean_vm/src/diagnostics/profiler.rs index 3451328b..6883b90a 100644 --- a/crates/lean_vm/src/diagnostics/profiler.rs +++ b/crates/lean_vm/src/diagnostics/profiler.rs @@ -3,13 +3,13 @@ use std::ops::Range; use utils::pretty_integer; -use crate::core::{Label, MemoryAddress}; -use crate::stack_trace::find_function_for_line; +use crate::core::{Label, MemoryAddress, SourceLocation}; +use crate::stack_trace::find_function_for_location; use crate::{ExecutionHistory, NONRESERVED_PROGRAM_INPUT_START}; pub(crate) fn profiling_report( instructions: &ExecutionHistory, - function_locations: &BTreeMap, + function_locations: &BTreeMap, ) -> String { #[derive(Default, Clone)] struct FunctionStats { @@ -22,8 +22,8 @@ pub(crate) fn profiling_report( let mut call_stack: Vec = Vec::new(); let mut prev_function_name = String::new(); - for (&line_num, &cycle_count) in instructions.lines.iter().zip(&instructions.lines_cycles) { - let (_, current_function_name) = find_function_for_line(line_num, function_locations); + for (&location, &cycle_count) in instructions.lines.iter().zip(&instructions.lines_cycles) { + let (_, current_function_name) = find_function_for_location(location, function_locations); if prev_function_name != current_function_name { if let Some(pos) = 
call_stack.iter().position(|f| f == &current_function_name) { diff --git a/crates/lean_vm/src/diagnostics/stack_trace.rs b/crates/lean_vm/src/diagnostics/stack_trace.rs index c3b4487e..77ad976a 100644 --- a/crates/lean_vm/src/diagnostics/stack_trace.rs +++ b/crates/lean_vm/src/diagnostics/stack_trace.rs @@ -2,30 +2,46 @@ use std::collections::BTreeMap; use colored::Colorize; -use crate::SourceLineNumber; +use crate::{FileId, FunctionName, SourceLineNumber, SourceLocation}; const STACK_TRACE_MAX_LINES_PER_FUNCTION: usize = 5; pub(crate) fn pretty_stack_trace( - source_code: &str, - instructions: &[SourceLineNumber], // SourceLineNumber = usize - function_locations: &BTreeMap, + source_code: &BTreeMap, + instructions: &[SourceLocation], + function_locations: &BTreeMap, + filepaths: &BTreeMap, last_pc: usize, ) -> String { - let source_lines: Vec<&str> = source_code.lines().collect(); + let mut source_locations: BTreeMap = BTreeMap::new(); + for (f_id, src) in source_code.iter() { + for (i, line) in src.lines().enumerate() { + source_locations.insert( + SourceLocation { + file_id: *f_id, + line_number: i, + }, + line, + ); + } + } let mut result = String::new(); - let mut call_stack: Vec<(usize, String)> = Vec::new(); // (line_number, function_name) - let mut prev_function_line = usize::MAX; + let mut call_stack: Vec<(SourceLocation, String)> = Vec::new(); + let mut prev_function_location: Option = None; let mut skipped_lines: usize = 0; // Track skipped lines for current function result.push_str("╔═════════════════════════════════════════════════════════════════════════╗\n"); result.push_str("║ STACK TRACE ║\n"); result.push_str("╚═════════════════════════════════════════════════════════════════════════╝\n\n"); - for (idx, &line_num) in instructions.iter().enumerate() { - let (current_function_line, current_function_name) = find_function_for_line(line_num, function_locations); + for (idx, &location) in instructions.iter().enumerate() { + let (current_function_location, 
current_function_name) = + find_function_for_location(location, function_locations); + let current_filepath = filepaths + .get(&current_function_location.file_id) + .expect("Undefined FileId"); - if prev_function_line != current_function_line { + if prev_function_location != Some(current_function_location) { assert_eq!(skipped_lines, 0); // Check if we're returning to a previous function or calling a new one @@ -39,13 +55,13 @@ pub(crate) fn pretty_stack_trace( skipped_lines = 0; } else { // Add the new function to the stack - call_stack.push((line_num, current_function_name.clone())); + call_stack.push((location, current_function_name.clone())); let indent = "│ ".repeat(call_stack.len() - 1); result.push_str(&format!( - "{}├─ {} (line {})\n", + "{}├─ {} ({current_filepath}:{})\n", indent, current_function_name.blue(), - current_function_line + current_function_location.line_number, )); skipped_lines = 0; } @@ -57,8 +73,12 @@ pub(crate) fn pretty_stack_trace( true } else { // Count remaining lines in this function - let remaining_in_function = - count_remaining_lines_in_function(idx, instructions, function_locations, current_function_line); + let remaining_in_function = count_remaining_lines_in_function( + idx, + instructions, + function_locations, + current_function_location.line_number, + ); remaining_in_function < STACK_TRACE_MAX_LINES_PER_FUNCTION }; @@ -72,23 +92,30 @@ pub(crate) fn pretty_stack_trace( } let indent = "│ ".repeat(call_stack.len()); - let code_line = source_lines.get(line_num.saturating_sub(1)).unwrap().trim(); + let location = SourceLocation { + file_id: location.file_id, + line_number: location.line_number.saturating_sub(1), + }; + let code_line = source_locations.get(&location).unwrap().trim(); if idx == instructions.len() - 1 { result.push_str(&format!( "{}├─ {} {}\n", indent, - format!("line {line_num}:").red(), + format!("{current_filepath}:{}:", location.line_number).red(), code_line )); } else { - result.push_str(&format!("{indent}├─ line 
{line_num}: {code_line}\n")); + result.push_str(&format!( + "{indent}├─ {current_filepath}:{}: {code_line}\n", + location.line_number + )); } } else { skipped_lines += 1; } - prev_function_line = current_function_line; + prev_function_location = Some(current_function_location); } // Add summary @@ -97,11 +124,25 @@ pub(crate) fn pretty_stack_trace( if !call_stack.is_empty() { result.push_str("\nCall stack:\n"); - for (i, (line, func)) in call_stack.iter().enumerate() { + for (i, (location, func)) in call_stack.iter().enumerate() { + let filepath = filepaths.get(&location.file_id).expect("Undefined FileId"); if i + 1 == call_stack.len() { - result.push_str(&format!(" {}. {} (line {}, pc {})\n", i + 1, func, line, last_pc)); + result.push_str(&format!( + " {}. {} ({}:{}, pc {})\n", + i + 1, + func, + filepath, + location.line_number, + last_pc + )); } else { - result.push_str(&format!(" {}. {} (line {})\n", i + 1, func, line)); + result.push_str(&format!( + " {}. {} ({}:{})\n", + i + 1, + func, + filepath, + location.line_number + )); } } } @@ -109,25 +150,27 @@ pub(crate) fn pretty_stack_trace( result } -pub(crate) fn find_function_for_line(line_num: usize, function_locations: &BTreeMap) -> (usize, String) { +pub(crate) fn find_function_for_location( + location: SourceLocation, + function_locations: &BTreeMap, +) -> (SourceLocation, String) { function_locations - .range(..=line_num) + .range(..=location) .next_back() - .map(|(line, func_name)| (*line, func_name.clone())) - .unwrap() + .map(|(location, func_name)| (*location, func_name.clone())) + .unwrap_or_else(|| panic!("Did not find function for location: {location}")) } fn count_remaining_lines_in_function( current_idx: usize, - instructions: &[SourceLineNumber], - function_locations: &BTreeMap, + instructions: &[SourceLocation], + function_locations: &BTreeMap, current_function_line: usize, ) -> usize { let mut count = 0; - for &instruction in instructions.iter().skip(current_idx + 1) { - let line_num = 
instruction; - let func_line = find_function_for_line(line_num, function_locations).0; + for &location in instructions.iter().skip(current_idx + 1) { + let func_line = find_function_for_location(location, function_locations).0.line_number; if func_line != current_function_line { break; diff --git a/crates/lean_vm/src/execution/context.rs b/crates/lean_vm/src/execution/context.rs index b30be5eb..62abab4e 100644 --- a/crates/lean_vm/src/execution/context.rs +++ b/crates/lean_vm/src/execution/context.rs @@ -1,9 +1,9 @@ -use crate::core::SourceLineNumber; +use crate::core::SourceLocation; use std::collections::BTreeMap; #[derive(Debug, Clone, Default)] pub struct ExecutionHistory { - pub lines: Vec, + pub lines: Vec, pub lines_cycles: Vec, // for each line, how many cycles it took } @@ -12,7 +12,7 @@ impl ExecutionHistory { Self::default() } - pub fn add_line(&mut self, location: SourceLineNumber, cycles: usize) { + pub fn add_line(&mut self, location: SourceLocation, cycles: usize) { self.lines.push(location); self.lines_cycles.push(cycles); } diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 65f4eeeb..e3e89644 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -1,16 +1,15 @@ //! 
VM execution runner use crate::core::{ - DIMENSION, F, NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, POSEIDON_16_NULL_HASH_PTR, POSEIDON_24_NULL_HASH_PTR, - VECTOR_LEN, ZERO_VEC_PTR, + DIMENSION, F, FileId, NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, POSEIDON_16_NULL_HASH_PTR, + POSEIDON_24_NULL_HASH_PTR, VECTOR_LEN, ZERO_VEC_PTR, }; use crate::diagnostics::{ExecutionResult, MemoryProfile, RunnerError, memory_profiling_report}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::Bytecode; use crate::isa::instruction::InstructionContext; use crate::{ - ALL_TABLES, CodeAddress, ENDING_PC, HintExecutionContext, N_TABLES, STARTING_PC, SourceLineNumber, Table, - TableTrace, + ALL_TABLES, CodeAddress, ENDING_PC, HintExecutionContext, N_TABLES, STARTING_PC, SourceLocation, Table, TableTrace, }; use multilinear_toolkit::prelude::*; use std::collections::{BTreeMap, BTreeSet, VecDeque}; @@ -77,9 +76,10 @@ pub fn execute_bytecode( println!( "\n{}", crate::diagnostics::pretty_stack_trace( - &bytecode.program, + &bytecode.source_code, latest_instructions, &bytecode.function_locations, + &bytecode.filepaths, last_pc ) ); @@ -92,7 +92,7 @@ pub fn execute_bytecode( panic!("Error during bytecode execution: {err}"); }); if profiling { - print_line_cycle_counts(instruction_history); + print_line_cycle_counts(instruction_history, &bytecode.filepaths); print_instruction_cycle_counts(bytecode, result.pcs.clone()); if let Some(ref mem_profile) = result.memory_profile { print!("{}", memory_profiling_report(mem_profile)); @@ -101,17 +101,18 @@ pub fn execute_bytecode( result } -fn print_line_cycle_counts(history: ExecutionHistory) { +fn print_line_cycle_counts(history: ExecutionHistory, filepaths: &BTreeMap) { println!("Line by line cycle counts"); println!("=========================\n"); - let mut gross_cycle_counts: BTreeMap = BTreeMap::new(); + let mut gross_cycle_counts: BTreeMap = BTreeMap::new(); for (line, cycle_count) in 
history.lines.iter().zip(history.lines_cycles.iter()) { let prev_count = gross_cycle_counts.get(line).unwrap_or(&0); gross_cycle_counts.insert(*line, *prev_count + cycle_count); } - for (line, cycle_count) in gross_cycle_counts.iter() { - println!("line {line}: {cycle_count} cycles"); + for (location, cycle_count) in gross_cycle_counts.iter() { + let filepath = filepaths.get(&location.file_id).expect("Unmapped FileId"); + println!("{filepath}:{}: {cycle_count} cycles", location.line_number); } println!(); } diff --git a/crates/lean_vm/src/isa/bytecode.rs b/crates/lean_vm/src/isa/bytecode.rs index b557f544..9529184c 100644 --- a/crates/lean_vm/src/isa/bytecode.rs +++ b/crates/lean_vm/src/isa/bytecode.rs @@ -1,6 +1,6 @@ //! Bytecode representation and management -use crate::{CodeAddress, Hint}; +use crate::{CodeAddress, FileId, FunctionName, Hint, SourceLocation}; use super::Instruction; use std::collections::BTreeMap; @@ -13,8 +13,9 @@ pub struct Bytecode { pub hints: BTreeMap>, // pc -> hints pub starting_frame_memory: usize, // debug - pub program: String, - pub function_locations: BTreeMap, + pub function_locations: BTreeMap, + pub filepaths: BTreeMap, + pub source_code: BTreeMap, } impl Display for Bytecode { diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index f8e9476a..b17fe421 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -1,4 +1,4 @@ -use crate::core::{F, LOG_VECTOR_LEN, Label, SourceLineNumber, VECTOR_LEN}; +use crate::core::{F, LOG_VECTOR_LEN, Label, SourceLocation, VECTOR_LEN}; use crate::diagnostics::{MemoryObject, MemoryObjectType, MemoryProfile, RunnerError}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::operands::MemOrConstant; @@ -61,7 +61,7 @@ pub enum Hint { /// Report source code location for debugging LocationReport { /// Source code location - location: SourceLineNumber, + location: SourceLocation, }, /// Jump destination label (for debugging purposes) Label 
{ @@ -73,7 +73,7 @@ pub enum Hint { size: usize, }, /// Assert a boolean expression for debugging purposes - DebugAssert(BooleanExpr, SourceLineNumber), + DebugAssert(BooleanExpr, SourceLocation), } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -258,7 +258,7 @@ impl Hint { ); } } - Self::DebugAssert(bool_expr, line_number) => { + Self::DebugAssert(bool_expr, location) => { let left = bool_expr.left.read_value(ctx.memory, ctx.fp)?; let right = bool_expr.right.read_value(ctx.memory, ctx.fp)?; let condition_holds = match bool_expr.kind { @@ -269,7 +269,7 @@ impl Hint { if !condition_holds { return Err(RunnerError::DebugAssertFailed( format!("{} {} {}", left, bool_expr.kind, right), - *line_number, + *location, )); } } @@ -337,8 +337,11 @@ impl Display for Hint { Self::Inverse { arg, res_offset } => { write!(f, "m[fp + {res_offset}] = inverse({arg})") } - Self::LocationReport { location: line_number } => { - write!(f, "source line number: {line_number}") + Self::LocationReport { + location: SourceLocation { file_id, line_number }, + } => { + // TODO: make a pretty-print method which shows the filepath instead of file_id + write!(f, "source location: {file_id}:{line_number}") } Self::Label { label } => { write!(f, "label: {label}") diff --git a/crates/rec_aggregation/src/whir_recursion.rs b/crates/rec_aggregation/src/whir_recursion.rs index 7c9615d5..a9225ddd 100644 --- a/crates/rec_aggregation/src/whir_recursion.rs +++ b/crates/rec_aggregation/src/whir_recursion.rs @@ -19,7 +19,8 @@ const NUM_VARIABLES: usize = 25; pub fn run_whir_recursion_benchmark(tracing: bool, n_recursions: usize) { let src_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("whir_recursion.snark"); - let mut program_str = std::fs::read_to_string(src_file).unwrap(); + let mut program_str = std::fs::read_to_string(src_file.clone()).unwrap(); + let filepath_str = src_file.to_str().unwrap(); let recursion_config_builder = WhirConfigBuilder { max_num_variables_to_send_coeffs: 6, 
security_level: 128, @@ -128,7 +129,7 @@ pub fn run_whir_recursion_benchmark(tracing: bool, n_recursions: usize) { utils::init_tracing(); } - let bytecode = compile_program(program_str); + let bytecode = compile_program(filepath_str, program_str); let mut merkle_path_hints = VecDeque::new(); for _ in 0..n_recursions { diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index c92e2ad0..e79a2a85 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -98,8 +98,9 @@ impl XmssAggregationProgram { #[instrument(skip_all)] fn compile_xmss_aggregation_program() -> XmssAggregationProgram { let src_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("xmss_aggregate.snark"); - let program_str = std::fs::read_to_string(src_file).unwrap(); - let bytecode = compile_program(program_str); + let program_str = std::fs::read_to_string(src_file.clone()).unwrap(); + let filepath_str = src_file.to_str().unwrap(); + let bytecode = compile_program(filepath_str, program_str); let default_no_vec_mem = exec_phony_xmss(&bytecode, &[]).no_vec_runtime_memory; let mut no_vec_mem_per_log_lifetime = vec![]; for log_lifetime in XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME {