Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
},
};
use lean_vm::{Boolean, BooleanExpr, SourceLineNumber, Table, TableT};
use lean_vm::{Boolean, BooleanExpr, FileId, SourceLineNumber, SourceLocation, Table, TableT};
use std::{
collections::{BTreeMap, BTreeSet},
fmt::{Display, Formatter},
Expand All @@ -21,6 +21,7 @@ pub struct SimpleProgram {
#[derive(Debug, Clone)]
pub struct SimpleFunction {
pub name: String,
pub file_id: FileId,
pub arguments: Vec<Var>,
pub n_returned_vars: usize,
pub instructions: Vec<SimpleLine>,
Expand Down Expand Up @@ -146,7 +147,7 @@ pub enum SimpleLine {
},
// noop, debug purpose only
LocationReport {
location: SourceLineNumber,
location: SourceLocation,
},
DebugAssert(BooleanExpr<SimpleExpr>, SourceLineNumber),
}
Expand Down Expand Up @@ -190,6 +191,7 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram {
let mut array_manager = ArrayManager::default();
let simplified_instructions = simplify_lines(
&program.functions,
func.file_id,
func.n_returned_vars,
&func.body,
&mut counters,
Expand All @@ -209,6 +211,7 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram {
.collect::<Vec<_>>();
let simplified_function = SimpleFunction {
name: name.clone(),
file_id: func.file_id,
arguments,
n_returned_vars: func.n_returned_vars,
instructions: simplified_instructions,
Expand Down Expand Up @@ -600,6 +603,7 @@ impl ArrayManager {
#[allow(clippy::too_many_arguments)]
fn simplify_lines(
functions: &BTreeMap<String, Function>,
file_id: FileId,
n_returned_vars: usize,
lines: &[Line],
counters: &mut Counters,
Expand All @@ -622,6 +626,7 @@ fn simplify_lines(
assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0");
simple_arms.push(simplify_lines(
functions,
file_id,
n_returned_vars,
statements,
counters,
Expand Down Expand Up @@ -812,6 +817,7 @@ fn simplify_lines(
let mut array_manager_then = array_manager.clone();
let then_branch_simplified = simplify_lines(
functions,
file_id,
n_returned_vars,
then_branch,
counters,
Expand All @@ -826,6 +832,7 @@ fn simplify_lines(

let else_branch_simplified = simplify_lines(
functions,
file_id,
n_returned_vars,
else_branch,
counters,
Expand Down Expand Up @@ -874,6 +881,7 @@ fn simplify_lines(
array_manager.valid.clear();
let simplified_body = simplify_lines(
functions,
file_id,
0,
body,
counters,
Expand Down Expand Up @@ -938,7 +946,10 @@ fn simplify_lines(
// Create recursive function body
let recursive_func = create_recursive_function(
func_name.clone(),
*line_number,
SourceLocation {
line_number: *line_number,
file_id,
},
func_args,
iterator.clone(),
end_simplified,
Expand Down Expand Up @@ -1122,7 +1133,12 @@ fn simplify_lines(
res.push(SimpleLine::Panic);
}
Line::LocationReport { location } => {
res.push(SimpleLine::LocationReport { location: *location });
res.push(SimpleLine::LocationReport {
location: SourceLocation {
line_number: *location,
file_id,
},
});
}
}
}
Expand Down Expand Up @@ -1161,10 +1177,7 @@ fn simplify_expr(
);
}
}
panic!(
"Const array '{}' can only be accessed with compile-time constant indices",
array_var
);
panic!("Const array '{array_var}' can only be accessed with compile-time constant indices",);
}

if let SimpleExpr::Var(array_var) = array
Expand Down Expand Up @@ -1672,7 +1685,7 @@ fn handle_array_assignment(

fn create_recursive_function(
name: String,
line_number: SourceLineNumber,
location: SourceLocation,
args: Vec<Var>,
iterator: Var,
end: SimpleExpr,
Expand All @@ -1696,7 +1709,7 @@ fn create_recursive_function(
function_name: name.clone(),
args: recursive_args,
return_data: vec![],
line_number,
line_number: location.line_number,
});
body.push(SimpleLine::FunctionRet { return_data: vec![] });

Expand All @@ -1713,12 +1726,13 @@ fn create_recursive_function(
condition: diff_var.into(),
then_branch: body,
else_branch: vec![SimpleLine::FunctionRet { return_data: vec![] }],
line_number,
line_number: location.line_number,
},
];

SimpleFunction {
name,
file_id: location.file_id,
arguments: args,
n_returned_vars: 0,
instructions,
Expand Down Expand Up @@ -2109,6 +2123,7 @@ fn handle_const_arguments(program: &mut Program) -> bool {
for func in program.functions.values_mut() {
if !func.has_const_arguments() {
any_changes |= handle_const_arguments_helper(
func.file_id,
&mut func.body,
&constant_functions,
&mut new_functions,
Expand All @@ -2133,6 +2148,7 @@ fn handle_const_arguments(program: &mut Program) -> bool {
if let Some(func) = new_functions.get_mut(&name) {
let initial_count = additional_functions.len();
handle_const_arguments_helper(
func.file_id,
&mut func.body,
&constant_functions,
&mut additional_functions,
Expand Down Expand Up @@ -2169,6 +2185,7 @@ fn handle_const_arguments(program: &mut Program) -> bool {
}

fn handle_const_arguments_helper(
file_id: FileId,
lines: &mut [Line],
constant_functions: &BTreeMap<String, Function>,
new_functions: &mut BTreeMap<String, Function>,
Expand Down Expand Up @@ -2228,6 +2245,7 @@ fn handle_const_arguments_helper(
const_funct_name.clone(),
Function {
name: const_funct_name,
file_id,
arguments: func
.arguments
.iter()
Expand All @@ -2247,15 +2265,29 @@ fn handle_const_arguments_helper(
else_branch,
..
} => {
changed |= handle_const_arguments_helper(then_branch, constant_functions, new_functions, const_arrays);
changed |= handle_const_arguments_helper(else_branch, constant_functions, new_functions, const_arrays);
changed |= handle_const_arguments_helper(
file_id,
then_branch,
constant_functions,
new_functions,
const_arrays,
);
changed |= handle_const_arguments_helper(
file_id,
else_branch,
constant_functions,
new_functions,
const_arrays,
);
}
Line::ForLoop { body, unroll: _, .. } => {
changed |= handle_const_arguments_helper(body, constant_functions, new_functions, const_arrays);
// TODO we should unroll before const arguments handling
handle_const_arguments_helper(file_id, body, constant_functions, new_functions, const_arrays);
}
Line::Match { arms, .. } => {
for (_, arm) in arms {
changed |= handle_const_arguments_helper(arm, constant_functions, new_functions, const_arrays);
changed |=
handle_const_arguments_helper(file_id, arm, constant_functions, new_functions, const_arrays);
}
}
_ => {}
Expand Down Expand Up @@ -2565,7 +2597,7 @@ impl SimpleLine {
Self::Panic => "panic".to_string(),
Self::LocationReport { .. } => Default::default(),
Self::DebugAssert(bool, _) => {
format!("debug_assert({})", bool)
format!("debug_assert({bool})")
}
};
format!("{spaces}{line_str}")
Expand Down
29 changes: 22 additions & 7 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ fn compile_function(
compiler.args_count = function.arguments.len();

compile_lines(
function.file_id,
&Label::function(function.name.clone()),
&function.instructions,
compiler,
Expand All @@ -153,6 +154,7 @@ fn compile_function(
}

fn compile_lines(
file_id: FileId,
function_name: &Label,
lines: &[SimpleLine],
compiler: &mut Compiler,
Expand Down Expand Up @@ -220,7 +222,8 @@ fn compile_lines(
for arm in arms.iter() {
compiler.stack_pos = saved_stack_pos;
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
let arm_instructions = compile_lines(function_name, arm, compiler, Some(end_label.clone()))?;
let arm_instructions =
compile_lines(file_id, function_name, arm, compiler, Some(end_label.clone()))?;
compiled_arms.push(arm_instructions);
compiler.stack_frame_layout.scopes.pop();
new_stack_pos = new_stack_pos.max(compiler.stack_pos);
Expand Down Expand Up @@ -257,7 +260,7 @@ fn compile_lines(
updated_fp: None,
});

let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?;
let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?;
compiler.bytecode.insert(end_label, remaining);

compiler.stack_frame_layout.scopes.pop();
Expand Down Expand Up @@ -349,22 +352,24 @@ fn compile_lines(
let saved_stack_pos = compiler.stack_pos;

compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
let then_instructions = compile_lines(function_name, then_branch, compiler, Some(end_label.clone()))?;
let then_instructions =
compile_lines(file_id, function_name, then_branch, compiler, Some(end_label.clone()))?;

let then_stack_pos = compiler.stack_pos;
compiler.stack_pos = saved_stack_pos;
compiler.stack_frame_layout.scopes.pop();
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());

let else_instructions = compile_lines(function_name, else_branch, compiler, Some(end_label.clone()))?;
let else_instructions =
compile_lines(file_id, function_name, else_branch, compiler, Some(end_label.clone()))?;

compiler.bytecode.insert(if_label, then_instructions);
compiler.bytecode.insert(else_label, else_instructions);

compiler.stack_frame_layout.scopes.pop();
compiler.stack_pos = compiler.stack_pos.max(then_stack_pos);

let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?;
let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?;
compiler.bytecode.insert(end_label, remaining);
// It is not necessary to update compiler.stack_size here because the preceding call to
// compile_lines should have done so.
Expand Down Expand Up @@ -438,7 +443,13 @@ fn compile_lines(
});
}

instructions.extend(compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?);
instructions.extend(compile_lines(
file_id,
function_name,
&lines[i + 1..],
compiler,
final_jump,
)?);

instructions
};
Expand Down Expand Up @@ -584,7 +595,11 @@ fn compile_lines(
left: IntermediateValue::from_simple_expr(&boolean.left, compiler),
right: IntermediateValue::from_simple_expr(&boolean.right, compiler),
};
instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, *line_number));
let location = SourceLocation {
file_id,
line_number: *line_number,
};
instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, location));
}
}
}
Expand Down
13 changes: 9 additions & 4 deletions crates/lean_compiler/src/c_compile_final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ struct Compiler {

pub fn compile_to_low_level_bytecode(
mut intermediate_bytecode: IntermediateBytecode,
program: String,
function_locations: BTreeMap<usize, String>,
function_locations: BTreeMap<SourceLocation, FunctionName>,
source_code: BTreeMap<FileId, String>,
filepaths: BTreeMap<FileId, String>,
) -> Result<Bytecode, String> {
intermediate_bytecode.bytecode.insert(
Label::EndProgram,
Expand All @@ -45,7 +46,10 @@ pub fn compile_to_low_level_bytecode(
}],
);

let starting_frame_memory = *intermediate_bytecode.memory_size_per_function.get("main").unwrap();
let starting_frame_memory = *intermediate_bytecode
.memory_size_per_function
.get("main")
.expect("Missing main function");

let mut hints = BTreeMap::new();
let mut label_to_pc = BTreeMap::new();
Expand Down Expand Up @@ -145,8 +149,9 @@ pub fn compile_to_low_level_bytecode(
instructions,
hints,
starting_frame_memory,
program,
function_locations,
source_code,
filepaths,
})
}

Expand Down
9 changes: 7 additions & 2 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
WHITESPACE = _{ " " | "\t" | "\n" | "\r" }

// Program structure
program = { SOI ~ constant_declaration* ~ function+ ~ EOI }
program = { SOI ~ import_statement* ~ constant_declaration* ~ function* ~ EOI }

// Imports
import_statement = { "import" ~ filepath ~ ";" }

// Constants
constant_declaration = { "const" ~ identifier ~ "=" ~ (array_literal | expression) ~ ";" }
Expand Down Expand Up @@ -108,4 +111,6 @@ constant_value = { number | "public_input_start" | "pointer_to_zero_vector" | "p

// Lexical elements
identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
number = @{ ASCII_DIGIT+ }
number = @{ ASCII_DIGIT+ }
filepath = { "\"" ~ filepath_character* ~ "\"" }
filepath_character = { ASCII_ALPHANUMERIC | "-" | "_" | " " | "." | "+" | "/" }
Loading