diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs deleted file mode 100644 index 4559ebb8..00000000 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ /dev/null @@ -1,2218 +0,0 @@ -use crate::{ - Counter, F, - ir::HighLevelOperation, - lang::{ - Boolean, ConstExpression, ConstMallocLabel, Expression, Function, Line, Program, - SimpleExpr, Var, - }, - precompiles::Precompile, -}; -use lean_vm::SourceLineNumber; -use std::{ - collections::{BTreeMap, BTreeSet}, - fmt::{Display, Formatter}, -}; -use utils::ToUsize; - -#[derive(Debug, Clone)] -pub struct SimpleProgram { - pub functions: BTreeMap, -} - -#[derive(Debug, Clone)] -pub struct SimpleFunction { - pub name: String, - pub arguments: Vec, - pub n_returned_vars: usize, - pub instructions: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum VarOrConstMallocAccess { - Var(Var), - ConstMallocAccess { - malloc_label: ConstMallocLabel, - offset: ConstExpression, - }, -} - -impl From for SimpleExpr { - fn from(var_or_const: VarOrConstMallocAccess) -> Self { - match var_or_const { - VarOrConstMallocAccess::Var(var) => Self::Var(var), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => Self::ConstMallocAccess { - malloc_label, - offset, - }, - } - } -} - -impl TryInto for SimpleExpr { - type Error = (); - - fn try_into(self) -> Result { - match self { - Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), - Self::ConstMallocAccess { - malloc_label, - offset, - } => Ok(VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - }), - _ => Err(()), - } - } -} - -impl From for VarOrConstMallocAccess { - fn from(var: Var) -> Self { - Self::Var(var) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum SimpleLine { - Match { - value: SimpleExpr, - arms: Vec>, // patterns = 0, 1, ... - }, - Assignment { - var: VarOrConstMallocAccess, - operation: HighLevelOperation, - arg0: SimpleExpr, - arg1: SimpleExpr, - }, - RawAccess { - res: SimpleExpr, - index: SimpleExpr, - shift: ConstExpression, - }, // res = memory[index + shift] - IfNotZero { - condition: SimpleExpr, - then_branch: Vec, - else_branch: Vec, - }, - FunctionCall { - function_name: String, - args: Vec, - return_data: Vec, - }, - FunctionRet { - return_data: Vec, - }, - Precompile { - precompile: Precompile, - args: Vec, - }, - Panic, - // Hints - DecomposeBits { - var: Var, // a pointer to 31 * len(to_decompose) field elements, containing the bits of "to_decompose" - to_decompose: Vec, - label: ConstMallocLabel, - }, - CounterHint { - var: Var, - }, - Print { - line_info: String, - content: Vec, - }, - HintMAlloc { - var: Var, - size: SimpleExpr, - vectorized: bool, - vectorized_len: SimpleExpr, - }, - ConstMalloc { - // always not vectorized - var: Var, - size: ConstExpression, - label: ConstMallocLabel, - }, - // noop, debug purpose only - LocationReport { - location: SourceLineNumber, - }, -} - -pub fn simplify_program(mut program: Program) -> SimpleProgram { - handle_inlined_functions(&mut program); - handle_const_arguments(&mut program); - let mut new_functions = BTreeMap::new(); - let mut counters = Counters::default(); - let mut const_malloc = ConstMalloc::default(); - for (name, func) in &program.functions { - let mut array_manager = ArrayManager::default(); - let simplified_instructions = simplify_lines( - &func.body, - &mut counters, - &mut new_functions, - false, - &mut array_manager, - &mut const_malloc, - ); - let arguments = func - .arguments - .iter() - .map(|(v, is_const)| { - assert!(!is_const,); - v.clone() - }) - .collect::>(); - new_functions.insert( - name.clone(), - SimpleFunction { - name: name.clone(), - arguments, - n_returned_vars: func.n_returned_vars, - instructions: simplified_instructions, - }, - ); - const_malloc.map.clear(); - } - SimpleProgram { - functions: new_functions, - } -} - -#[derive(Debug, Clone, Default)] -struct Counters { - aux_vars: usize, - loops: usize, - unrolls: usize, -} - -#[derive(Debug, Clone, Default)] -struct ArrayManager { - counter: usize, - aux_vars: BTreeMap<(SimpleExpr, Expression), Var>, // (array, index) -> aux_var - valid: BTreeSet, // currently valid aux vars -} - -#[derive(Debug, Clone, Default)] -pub struct ConstMalloc { - counter: usize, - map: BTreeMap, - forbidden_vars: BTreeSet, // vars shared between branches of an if/else -} - -impl ArrayManager { - fn get_aux_var(&mut self, array: &SimpleExpr, index: &Expression) -> Var { - if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { - return var.clone(); - } - let new_var = format!("@arr_aux_{}", self.counter); - self.counter += 1; - self.aux_vars - .insert((array.clone(), index.clone()), new_var.clone()); - new_var - } -} - -fn simplify_lines( - lines: &[Line], - counters: &mut Counters, - new_functions: &mut BTreeMap, - in_a_loop: bool, - array_manager: &mut ArrayManager, - const_malloc: &mut ConstMalloc, -) -> Vec { - let mut res = Vec::new(); - for line in lines { - match line { - Line::Match { value, arms } => { - let simple_value = - simplify_expr(value, &mut res, counters, array_manager, const_malloc); - let mut simple_arms = vec![]; - for (i, (pattern, statements)) in arms.iter().enumerate() { - assert_eq!( - *pattern, i, - "match patterns should be consecutive, starting from 0" - ); - simple_arms.push(simplify_lines( - statements, - counters, - new_functions, - in_a_loop, - array_manager, - const_malloc, - )); - } - res.push(SimpleLine::Match { - value: simple_value, - arms: simple_arms, - }); - } - Line::Assignment { var, value } => match value { - Expression::Value(value) => { - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: value.clone(), - arg1: SimpleExpr::zero(), - }); - } - Expression::ArrayAccess { array, index } => { - handle_array_assignment( - counters, - &mut res, - array.clone(), - index, - ArrayAccessType::VarIsAssigned(var.clone()), - array_manager, - const_malloc, - ); - } - Expression::Binary { - left, - operation, - right, - } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: *operation, - arg0: left, - arg1: right, - }); - } - Expression::Log2Ceil { .. } => unreachable!(), - }, - Line::ArrayAssign { - array, - index, - value, - } => { - handle_array_assignment( - counters, - &mut res, - array.clone(), - index, - ArrayAccessType::ArrayIsAssigned(value.clone()), - array_manager, - const_malloc, - ); - } - Line::Assert(boolean) => match boolean { - Boolean::Different { left, right } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - let diff_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: left, - arg1: right, - }); - res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: vec![], - else_branch: vec![SimpleLine::Panic], - }); - } - Boolean::Equal { left, right } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - let (var, other) = if let Ok(left) = left.clone().try_into() { - (left, right) - } else if let Ok(right) = right.clone().try_into() { - (right, left) - } else { - unreachable!("Weird: {:?}, {:?}", left, right) - }; - res.push(SimpleLine::Assignment { - var, - operation: HighLevelOperation::Add, - arg0: other, - arg1: SimpleExpr::zero(), - }); - } - }, - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - // Transform if a == b then X else Y into if a != b then Y else X - - let (left, right, then_branch, else_branch) = match condition { - Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched - Boolean::Different { left, right } => (left, right, then_branch, else_branch), - }; - - let left_simplified = - simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right_simplified = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - - let diff_var = format!("@diff_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: left_simplified, - arg1: right_simplified, - }); - - let forbidden_vars_before = const_malloc.forbidden_vars.clone(); - - let then_internal_vars = find_variable_usage(then_branch).0; - let else_internal_vars = find_variable_usage(else_branch).0; - let new_forbidden_vars = then_internal_vars - .intersection(&else_internal_vars) - .cloned() - .collect::>(); - - const_malloc.forbidden_vars.extend(new_forbidden_vars); - - let mut array_manager_then = array_manager.clone(); - let then_branch_simplified = simplify_lines( - then_branch, - counters, - new_functions, - in_a_loop, - &mut array_manager_then, - const_malloc, - ); - let mut array_manager_else = array_manager_then.clone(); - array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch - - let else_branch_simplified = simplify_lines( - else_branch, - counters, - new_functions, - in_a_loop, - &mut array_manager_else, - const_malloc, - ); - - const_malloc.forbidden_vars = forbidden_vars_before; - - *array_manager = array_manager_else.clone(); - // keep the intersection both branches - array_manager.valid = array_manager - .valid - .intersection(&array_manager_then.valid) - .cloned() - .collect(); - - res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: then_branch_simplified, - else_branch: else_branch_simplified, - }); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev, - unroll, - } => { - if *unroll { - let (internal_variables, _) = find_variable_usage(body); - let mut unrolled_lines = Vec::new(); - let start_evaluated = start.naive_eval().unwrap().to_usize(); - let end_evaluated = end.naive_eval().unwrap().to_usize(); - let unroll_index = counters.unrolls; - counters.unrolls += 1; - - let mut range = (start_evaluated..end_evaluated).collect::>(); - if *rev { - range.reverse(); - } - - for i in range { - let mut body_copy = body.clone(); - replace_vars_for_unroll( - &mut body_copy, - iterator, - unroll_index, - i, - &internal_variables, - ); - unrolled_lines.extend(simplify_lines( - &body_copy, - counters, - new_functions, - in_a_loop, - array_manager, - const_malloc, - )); - } - res.extend(unrolled_lines); - continue; - } - - if *rev { - unimplemented!("Reverse for non-unrolled loops are not implemented yet"); - } - - let mut loop_const_malloc = ConstMalloc { - counter: const_malloc.counter, - ..ConstMalloc::default() - }; - let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); - array_manager.valid.clear(); - let simplified_body = simplify_lines( - body, - counters, - new_functions, - true, - array_manager, - &mut loop_const_malloc, - ); - const_malloc.counter = loop_const_malloc.counter; - array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars - - let func_name = format!("@loop_{}", counters.loops); - counters.loops += 1; - - // Find variables used inside loop but defined outside - let (_, mut external_vars) = find_variable_usage(body); - - // Include variables in start/end - for expr in [start, end] { - for var in vars_in_expression(expr) { - external_vars.insert(var); - } - } - external_vars.remove(iterator); // Iterator is internal to loop - - let mut external_vars: Vec<_> = external_vars.into_iter().collect(); - - let start_simplified = - simplify_expr(start, &mut res, counters, array_manager, const_malloc); - let end_simplified = - simplify_expr(end, &mut res, counters, array_manager, const_malloc); - - for (simplified, original) in [ - (start_simplified.clone(), start.clone()), - (end_simplified.clone(), end.clone()), - ] { - if !matches!(original, Expression::Value(_)) { - // the simplified var is auxiliary - if let SimpleExpr::Var(var) = simplified { - external_vars.push(var); - } - } - } - - // Create function arguments: iterator + external variables - let mut func_args = vec![iterator.clone()]; - func_args.extend(external_vars.clone()); - - // Create recursive function body - let recursive_func = create_recursive_function( - func_name.clone(), - func_args, - iterator.clone(), - end_simplified, - simplified_body, - &external_vars, - ); - new_functions.insert(func_name.clone(), recursive_func); - - // Replace loop with initial function call - let mut call_args = vec![start_simplified]; - call_args.extend(external_vars.iter().map(|v| v.clone().into())); - - res.push(SimpleLine::FunctionCall { - function_name: func_name, - args: call_args, - return_data: vec![], - }); - } - Line::FunctionCall { - function_name, - args, - return_data, - } => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::FunctionCall { - function_name: function_name.clone(), - args: simplified_args, - return_data: return_data.clone(), - }); - } - Line::FunctionRet { return_data } => { - assert!( - !in_a_loop, - "Function return inside a loop is not currently supported" - ); - let simplified_return_data = return_data - .iter() - .map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::FunctionRet { - return_data: simplified_return_data, - }); - } - Line::Precompile { precompile, args } => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::Precompile { - precompile: precompile.clone(), - args: simplified_args, - }); - } - Line::Print { line_info, content } => { - let simplified_content = content - .iter() - .map(|var| simplify_expr(var, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::Print { - line_info: line_info.clone(), - content: simplified_content, - }); - } - Line::Break => { - assert!(in_a_loop, "Break statement outside of a loop"); - res.push(SimpleLine::FunctionRet { - return_data: vec![], - }); - } - Line::MAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - let simplified_size = - simplify_expr(size, &mut res, counters, array_manager, const_malloc); - let simplified_vectorized_len = simplify_expr( - vectorized_len, - &mut res, - counters, - array_manager, - const_malloc, - ); - if simplified_size.is_constant() - && !*vectorized - && const_malloc.forbidden_vars.contains(var) - { - println!( - "TODO: Optimization missed: Requires to align const malloc in if/else branches" - ); - } - match simplified_size { - SimpleExpr::Constant(const_size) - if !*vectorized && !const_malloc.forbidden_vars.contains(var) => - { - // TODO do this optimization even if we are in an if/else branch - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::ConstMalloc { - var: var.clone(), - size: const_size, - label, - }); - } - _ => { - res.push(SimpleLine::HintMAlloc { - var: var.clone(), - size: simplified_size, - vectorized: *vectorized, - vectorized_len: simplified_vectorized_len, - }); - } - } - } - Line::DecomposeBits { var, to_decompose } => { - assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); - let simplified_to_decompose = to_decompose - .iter() - .map(|expr| { - simplify_expr(expr, &mut res, counters, array_manager, const_malloc) - }) - .collect::>(); - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::DecomposeBits { - var: var.clone(), - to_decompose: simplified_to_decompose, - label, - }); - } - Line::CounterHint { var } => { - res.push(SimpleLine::CounterHint { var: var.clone() }); - } - Line::Panic => { - res.push(SimpleLine::Panic); - } - Line::LocationReport { location } => { - res.push(SimpleLine::LocationReport { - location: *location, - }); - } - } - } - - res -} - -fn simplify_expr( - expr: &Expression, - lines: &mut Vec, - counters: &mut Counters, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, -) -> SimpleExpr { - match expr { - Expression::Value(value) => value.simplify_if_const(), - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array - && let Some(label) = const_malloc.map.get(array_var) - && let Ok(mut offset) = ConstExpression::try_from(*index.clone()) - { - offset = offset.try_naive_simplification(); - return SimpleExpr::ConstMallocAccess { - malloc_label: *label, - offset, - }; - } - - let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index] - - if !array_manager.valid.insert(aux_arr.clone()) { - return SimpleExpr::Var(aux_arr); - } - - handle_array_assignment( - counters, - lines, - array.clone(), - index, - ArrayAccessType::VarIsAssigned(aux_arr.clone()), - array_manager, - const_malloc, - ); - SimpleExpr::Var(aux_arr) - } - Expression::Binary { - left, - operation, - right, - } => { - let left_var = simplify_expr(left, lines, counters, array_manager, const_malloc); - let right_var = simplify_expr(right, lines, counters, array_manager, const_malloc); - - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = - (&left_var, &right_var) - { - return SimpleExpr::Constant(ConstExpression::Binary { - left: Box::new(left_cst.clone()), - operation: *operation, - right: Box::new(right_cst.clone()), - }); - } - - let aux_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - lines.push(SimpleLine::Assignment { - var: aux_var.clone().into(), - operation: *operation, - arg0: left_var, - arg1: right_var, - }); - SimpleExpr::Var(aux_var) - } - Expression::Log2Ceil { value } => { - let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc) - .as_constant() - .unwrap(); - SimpleExpr::Constant(ConstExpression::Log2Ceil { - value: Box::new(const_value), - }) - } - } -} - -/// Returns (internal_vars, external_vars) -pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { - let mut internal_vars = BTreeSet::new(); - let mut external_vars = BTreeSet::new(); - - let on_new_expr = - |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - for var in vars_in_expression(expr) { - if !internal_vars.contains(&var) { - external_vars.insert(var); - } - } - }; - - let on_new_condition = - |condition: &Boolean, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; - on_new_expr(left, internal_vars, external_vars); - on_new_expr(right, internal_vars, external_vars); - }; - - for line in lines { - match line { - Line::Match { value, arms } => { - on_new_expr(value, &internal_vars, &mut external_vars); - for (_, statements) in arms { - let (stmt_internal, stmt_external) = find_variable_usage(statements); - internal_vars.extend(stmt_internal); - external_vars.extend( - stmt_external - .into_iter() - .filter(|v| !internal_vars.contains(v)), - ); - } - } - Line::Assignment { var, value } => { - on_new_expr(value, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - on_new_condition(condition, &internal_vars, &mut external_vars); - - let (then_internal, then_external) = find_variable_usage(then_branch); - let (else_internal, else_external) = find_variable_usage(else_branch); - - internal_vars.extend(then_internal.union(&else_internal).cloned()); - external_vars.extend( - then_external - .union(&else_external) - .filter(|v| !internal_vars.contains(*v)) - .cloned(), - ); - } - Line::FunctionCall { - args, return_data, .. - } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - internal_vars.extend(return_data.iter().cloned()); - } - Line::Assert(condition) => { - on_new_condition(condition, &internal_vars, &mut external_vars); - } - Line::FunctionRet { return_data } => { - for ret in return_data { - on_new_expr(ret, &internal_vars, &mut external_vars); - } - } - Line::MAlloc { var, size, .. } => { - on_new_expr(size, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - } - Line::Print { content, .. } => { - for var in content { - on_new_expr(var, &internal_vars, &mut external_vars); - } - } - Line::DecomposeBits { var, to_decompose } => { - for expr in to_decompose { - on_new_expr(expr, &internal_vars, &mut external_vars); - } - internal_vars.insert(var.clone()); - } - Line::CounterHint { var } => { - internal_vars.insert(var.clone()); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - let (body_internal, body_external) = find_variable_usage(body); - internal_vars.extend(body_internal); - internal_vars.insert(iterator.clone()); - external_vars.extend(body_external.difference(&internal_vars).cloned()); - on_new_expr(start, &internal_vars, &mut external_vars); - on_new_expr(end, &internal_vars, &mut external_vars); - } - Line::ArrayAssign { - array, - index, - value, - } => { - on_new_expr(&array.clone().into(), &internal_vars, &mut external_vars); - on_new_expr(index, &internal_vars, &mut external_vars); - on_new_expr(value, &internal_vars, &mut external_vars); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} - } - } - - (internal_vars, external_vars) -} - -fn inline_simple_expr( - simple_expr: &mut SimpleExpr, - args: &BTreeMap, - inlining_count: usize, -) { - if let SimpleExpr::Var(var) = simple_expr { - if let Some(replacement) = args.get(var) { - *simple_expr = replacement.clone(); - } else { - *var = format!("@inlined_var_{inlining_count}_{var}"); - } - } -} - -fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining_count: usize) { - match expr { - Expression::Value(value) => { - inline_simple_expr(value, args, inlining_count); - } - Expression::ArrayAccess { array, index } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - } - Expression::Binary { left, right, .. } => { - inline_expr(left, args, inlining_count); - inline_expr(right, args, inlining_count); - } - Expression::Log2Ceil { value } => { - inline_expr(value, args, inlining_count); - } - } -} - -pub fn inline_lines( - lines: &mut Vec, - args: &BTreeMap, - res: &[Var], - inlining_count: usize, -) { - let inline_condition = |condition: &mut Boolean| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; - inline_expr(left, args, inlining_count); - inline_expr(right, args, inlining_count); - }; - - let inline_internal_var = |var: &mut Var| { - assert!( - !args.contains_key(var), - "Variable {var} is both an argument and assigned in the inlined function" - ); - *var = format!("@inlined_var_{inlining_count}_{var}"); - }; - - let mut lines_to_replace = vec![]; - for (i, line) in lines.iter_mut().enumerate() { - match line { - Line::Match { value, arms } => { - inline_expr(value, args, inlining_count); - for (_, statements) in arms { - inline_lines(statements, args, res, inlining_count); - } - } - Line::Assignment { var, value } => { - inline_expr(value, args, inlining_count); - inline_internal_var(var); - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - inline_condition(condition); - - inline_lines(then_branch, args, res, inlining_count); - inline_lines(else_branch, args, res, inlining_count); - } - Line::FunctionCall { - args: func_args, - return_data, - .. - } => { - for arg in func_args { - inline_expr(arg, args, inlining_count); - } - for return_var in return_data { - inline_internal_var(return_var); - } - } - Line::Assert(condition) => { - inline_condition(condition); - } - Line::FunctionRet { return_data } => { - assert_eq!(return_data.len(), res.len()); - - for expr in return_data.iter_mut() { - inline_expr(expr, args, inlining_count); - } - lines_to_replace.push(( - i, - res.iter() - .zip(return_data) - .map(|(res_var, expr)| Line::Assignment { - var: res_var.clone(), - value: expr.clone(), - }) - .collect::>(), - )); - } - Line::MAlloc { var, size, .. } => { - inline_expr(size, args, inlining_count); - inline_internal_var(var); - } - Line::Precompile { - precompile: _, - args: precompile_args, - } => { - for arg in precompile_args { - inline_expr(arg, args, inlining_count); - } - } - Line::Print { content, .. } => { - for var in content { - inline_expr(var, args, inlining_count); - } - } - Line::DecomposeBits { var, to_decompose } => { - for expr in to_decompose { - inline_expr(expr, args, inlining_count); - } - inline_internal_var(var); - } - Line::CounterHint { var } => { - inline_internal_var(var); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - inline_lines(body, args, res, inlining_count); - inline_internal_var(iterator); - inline_expr(start, args, inlining_count); - inline_expr(end, args, inlining_count); - } - Line::ArrayAssign { - array, - index, - value, - } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - inline_expr(value, args, inlining_count); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} - } - } - for (i, new_lines) in lines_to_replace.into_iter().rev() { - lines.splice(i..=i, new_lines); - } -} - -fn vars_in_expression(expr: &Expression) -> BTreeSet { - let mut vars = BTreeSet::new(); - match expr { - Expression::Value(value) => { - if let SimpleExpr::Var(var) = value { - vars.insert(var.clone()); - } - } - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array) = array { - vars.insert(array.clone()); - } - vars.extend(vars_in_expression(index)); - } - Expression::Binary { left, right, .. } => { - vars.extend(vars_in_expression(left)); - vars.extend(vars_in_expression(right)); - } - Expression::Log2Ceil { value } => { - vars.extend(vars_in_expression(value)); - } - } - vars -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ArrayAccessType { - VarIsAssigned(Var), // var = array[index] - ArrayIsAssigned(Expression), // array[index] = expr -} - -fn handle_array_assignment( - counters: &mut Counters, - res: &mut Vec, - array: SimpleExpr, - index: &Expression, - access_type: ArrayAccessType, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, -) { - let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc); - - if let SimpleExpr::Constant(offset) = simplified_index.clone() - && let SimpleExpr::Var(array_var) = &array - && let Some(label) = const_malloc.map.get(array_var) - && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { - left, - operation, - right, - }) = &access_type - { - let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc); - let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc); - res.push(SimpleLine::Assignment { - var: VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - }, - operation: *operation, - arg0, - arg1, - }); - return; - } - - let value_simplified = match access_type { - ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), - ArrayAccessType::ArrayIsAssigned(expr) => { - simplify_expr(&expr, res, counters, array_manager, const_malloc) - } - }; - - // TODO opti: in some case we could use ConstMallocAccess - - let (index_var, shift) = match simplified_index { - SimpleExpr::Constant(c) => (array, c), - _ => { - // Create pointer variable: ptr = array + index - let ptr_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: ptr_var.clone().into(), - operation: HighLevelOperation::Add, - arg0: array, - arg1: simplified_index, - }); - (SimpleExpr::Var(ptr_var), ConstExpression::zero()) - } - }; - - res.push(SimpleLine::RawAccess { - res: value_simplified, - index: index_var, - shift, - }); -} - -fn create_recursive_function( - name: String, - args: Vec, - iterator: Var, - end: SimpleExpr, - mut body: Vec, - external_vars: &[Var], -) -> SimpleFunction { - // Add iterator increment - let next_iter = format!("@incremented_{iterator}"); - body.push(SimpleLine::Assignment { - var: next_iter.clone().into(), - operation: HighLevelOperation::Add, - arg0: iterator.clone().into(), - arg1: SimpleExpr::one(), - }); - - // Add recursive call - let mut recursive_args: Vec = vec![next_iter.into()]; - recursive_args.extend(external_vars.iter().map(|v| v.clone().into())); - - body.push(SimpleLine::FunctionCall { - function_name: name.clone(), - args: recursive_args, - return_data: vec![], - }); - body.push(SimpleLine::FunctionRet { - return_data: vec![], - }); - - let diff_var = format!("@diff_{iterator}"); - - let instructions = vec![ - SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: iterator.into(), - arg1: end, - }, - SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: body, - else_branch: vec![SimpleLine::FunctionRet { - return_data: vec![], - }], - }, - ]; - - SimpleFunction { - name, - arguments: args, - n_returned_vars: 0, - instructions, - } -} - -fn replace_vars_for_unroll_in_expr( - expr: &mut Expression, - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { - match expr { - Expression::Value(value_expr) => match value_expr { - SimpleExpr::Var(var) => { - if var == iterator { - *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); - } else if internal_vars.contains(var) { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } - } - SimpleExpr::Constant(_) | SimpleExpr::ConstMallocAccess { .. } => {} - }, - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } - } - - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Expression::Binary { left, right, .. } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Expression::Log2Ceil { value } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } -} - -fn replace_vars_for_unroll( - lines: &mut [Line], - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { - for line in lines { - match line { - Line::Match { value, arms } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - for (_, statements) in arms { - replace_vars_for_unroll( - statements, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Assignment { var, value } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::ArrayAssign { - // array[index] = value - array, - index, - value, - } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = - format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } - } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::Assert(Boolean::Equal { left, right } | Boolean::Different { left, right }) => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::IfCondition { - condition: Boolean::Equal { left, right } | Boolean::Different { left, right }, - then_branch, - else_branch, - } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - then_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - else_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::ForLoop { - iterator: other_iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - assert!(other_iterator != iterator); - *other_iterator = - format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); - replace_vars_for_unroll_in_expr( - start, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - end, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - body, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::FunctionCall { - function_name: _, - args, - return_data, - } => { - // Function calls are not unrolled, so we don't need to change them - for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - for ret in return_data { - *ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}"); - } - } - Line::FunctionRet { return_data } => { - for ret in return_data { - replace_vars_for_unroll_in_expr( - ret, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Print { line_info, content } => { - // Print statements are not unrolled, so we don't need to change them - *line_info += &format!(" (unrolled {unroll_index} {iterator_value})"); - for var in content { - replace_vars_for_unroll_in_expr( - var, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::MAlloc { - var, - size, - vectorized: _, - vectorized_len, - } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr( - size, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - vectorized_len, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::DecomposeBits { var, to_decompose } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - for expr in to_decompose { - replace_vars_for_unroll_in_expr( - expr, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::CounterHint { var } => { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} - } - } -} - -fn handle_inlined_functions(program: &mut Program) { - let inlined_functions = program - .functions - .iter() - .filter(|(_, func)| func.inlined) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); - - for func in inlined_functions.values() { - assert!( - !func.has_const_arguments(), - "Inlined functions with constant arguments are not supported yet" - ); - } - - // Process inline functions iteratively to handle dependencies - // Repeat until all inline function calls are resolved - let mut max_iterations = 10; - while max_iterations > 0 { - let mut any_changes = false; - - // Process non-inlined functions - for func in program.functions.values_mut() { - if !func.inlined { - let mut counter1 = Counter::new(); - let mut counter2 = Counter::new(); - let old_body = func.body.clone(); - - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); - - if func.body != old_body { - any_changes = true; - } - } - } - - // Process inlined functions that may call other inlined functions - // We need to update them so that when they get inlined later, they don't have unresolved calls - for func in program.functions.values_mut() { - if func.inlined { - let mut counter1 = Counter::new(); - let mut counter2 = Counter::new(); - let old_body = func.body.clone(); - - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); - - if func.body != old_body { - any_changes = true; - } - } - } - - if !any_changes { - break; - } - - max_iterations -= 1; - } - - assert!( - max_iterations > 0, - "Too many iterations processing inline functions" - ); - - // Remove all inlined functions from the program (they've been inlined) - for func_name in inlined_functions.keys() { - program.functions.remove(func_name); - } -} - -fn handle_inlined_functions_helper( - lines: &mut Vec, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, - total_inlined_counter: &mut Counter, -) { - for i in (0..lines.len()).rev() { - match &mut lines[i] { - Line::FunctionCall { - function_name, - args, - return_data, - } => { - if let Some(func) = inlined_functions.get(&*function_name) { - let mut inlined_lines = vec![]; - - let mut simplified_args = vec![]; - for arg in args { - if let Expression::Value(simple_expr) = arg { - simplified_args.push(simple_expr.clone()); - } else { - let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); - inlined_lines.push(Line::Assignment { - var: aux_var.clone(), - value: arg.clone(), - }); - simplified_args.push(SimpleExpr::Var(aux_var)); - } - } - assert_eq!(simplified_args.len(), func.arguments.len()); - let inlined_args = func - .arguments - .iter() - .zip(&simplified_args) - .map(|((var, _), expr)| (var.clone(), expr.clone())) - .collect::>(); - let mut func_body = func.body.clone(); - inline_lines( - &mut func_body, - &inlined_args, - return_data, - total_inlined_counter.next(), - ); - inlined_lines.extend(func_body); - - lines.remove(i); // remove the call to the inlined function - lines.splice(i..i, inlined_lines); - } - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - handle_inlined_functions_helper( - then_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - handle_inlined_functions_helper( - else_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - Line::ForLoop { - body, unroll: _, .. - } => { - handle_inlined_functions_helper( - body, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - Line::Match { arms, .. } => { - for (_, arm) in arms { - handle_inlined_functions_helper( - arm, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - } - _ => {} - } - } -} - -fn handle_const_arguments(program: &mut Program) { - let mut new_functions = BTreeMap::::new(); - let constant_functions = program - .functions - .iter() - .filter(|(_, func)| func.has_const_arguments()) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); - - // First pass: process non-const functions that call const functions - for func in program.functions.values_mut() { - if !func.has_const_arguments() { - handle_const_arguments_helper(&mut func.body, &constant_functions, &mut new_functions); - } - } - - // Process newly created const functions recursively until no more changes - let mut changed = true; - let mut const_depth = 0; - while changed { - changed = false; - const_depth += 1; - assert!(const_depth < 100, "Too many levels of constant arguments"); - let mut additional_functions = BTreeMap::new(); - - // Collect all function names to process - let function_names: Vec = new_functions.keys().cloned().collect(); - - for name in function_names { - if let Some(func) = new_functions.get_mut(&name) { - let initial_count = additional_functions.len(); - handle_const_arguments_helper( - &mut func.body, - &constant_functions, - &mut additional_functions, - ); - if additional_functions.len() > initial_count { - changed = true; - } - } - } - - // Add any newly discovered functions - for (name, func) in additional_functions { - if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) { - e.insert(func); - changed = true; - } - } - } - - for (name, func) in new_functions { - assert!(!program.functions.contains_key(&name),); - program.functions.insert(name, func); - } - for const_func in constant_functions.keys() { - program.functions.remove(const_func); - } -} - -fn handle_const_arguments_helper( - lines: &mut [Line], - constant_functions: &BTreeMap, - new_functions: &mut BTreeMap, -) { - for line in lines { - match line { - Line::FunctionCall { - function_name, - args, - return_data: _, - } => { - if let Some(func) = constant_functions.get(function_name) { - // If the function has constant arguments, we need to handle them - let mut const_evals = Vec::new(); - for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { - if *is_constant { - let const_eval = arg_expr.naive_eval().unwrap_or_else(|| { - panic!("Failed to evaluate constant argument: {arg_expr}") - }); - const_evals.push((arg_var.clone(), const_eval)); - } - } - let const_funct_name = format!( - "{function_name}_{}", - const_evals - .iter() - .map(|(arg_var, const_eval)| { format!("{arg_var}={const_eval}") }) - .collect::>() - .join("_") - ); - - *function_name = const_funct_name.clone(); // change the name of the function called - // ... and remove constant arguments - *args = args - .iter() - .zip(&func.arguments) - .filter(|(_, (_, is_constant))| !is_constant) - .filter(|(_, (_, is_const))| !is_const) - .map(|(arg_expr, _)| arg_expr.clone()) - .collect(); - - if new_functions.contains_key(&const_funct_name) { - continue; - } - - let mut new_body = func.body.clone(); - replace_vars_by_const_in_lines( - &mut new_body, - &const_evals.iter().cloned().collect(), - ); - new_functions.insert( - const_funct_name.clone(), - Function { - name: const_funct_name, - arguments: func - .arguments - .iter() - .filter(|(_, is_const)| !is_const) - .cloned() - .collect(), - inlined: false, - body: new_body, - n_returned_vars: func.n_returned_vars, - }, - ); - } - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - handle_const_arguments_helper(then_branch, constant_functions, new_functions); - handle_const_arguments_helper(else_branch, constant_functions, new_functions); - } - Line::ForLoop { - body, unroll: _, .. - } => { - // TODO we should unroll before const arguments handling - handle_const_arguments_helper(body, constant_functions, new_functions); - } - _ => {} - } - } -} - -fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) { - match expr { - Expression::Value(value) => match &value { - SimpleExpr::Var(var) => { - if let Some(const_value) = map.get(var) { - *value = SimpleExpr::scalar(const_value.to_usize()); - } - } - SimpleExpr::ConstMallocAccess { .. } => { - unreachable!() - } - SimpleExpr::Constant(_) => {} - }, - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); - } - replace_vars_by_const_in_expr(index, map); - } - Expression::Binary { left, right, .. } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - Expression::Log2Ceil { value } => { - replace_vars_by_const_in_expr(value, map); - } - } -} - -fn get_function_called(lines: &[Line], function_called: &mut Vec) { - for line in lines { - match line { - Line::Match { value: _, arms } => { - for (_, statements) in arms { - get_function_called(statements, function_called); - } - } - Line::FunctionCall { function_name, .. } => { - function_called.push(function_name.clone()); - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - get_function_called(then_branch, function_called); - get_function_called(else_branch, function_called); - } - Line::ForLoop { body, .. } => { - get_function_called(body, function_called); - } - Line::Assignment { .. } - | Line::ArrayAssign { .. } - | Line::Assert { .. } - | Line::FunctionRet { .. } - | Line::Precompile { .. } - | Line::Print { .. } - | Line::DecomposeBits { .. } - | Line::CounterHint { .. } - | Line::MAlloc { .. } - | Line::Panic - | Line::Break - | Line::LocationReport { .. } => {} - } - } -} - -fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { - for line in lines { - match line { - Line::Match { value, arms } => { - replace_vars_by_const_in_expr(value, map); - for (_, statements) in arms { - replace_vars_by_const_in_lines(statements, map); - } - } - Line::Assignment { var, value } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(value, map); - } - Line::ArrayAssign { - array, - index, - value, - } => { - if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); - } - replace_vars_by_const_in_expr(index, map); - replace_vars_by_const_in_expr(value, map); - } - Line::FunctionCall { - args, return_data, .. - } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - for ret in return_data { - assert!( - !map.contains_key(ret), - "Return variable {ret} is a constant" - ); - } - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - match condition { - Boolean::Equal { left, right } | Boolean::Different { left, right } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - } - replace_vars_by_const_in_lines(then_branch, map); - replace_vars_by_const_in_lines(else_branch, map); - } - Line::ForLoop { - body, start, end, .. - } => { - replace_vars_by_const_in_expr(start, map); - replace_vars_by_const_in_expr(end, map); - replace_vars_by_const_in_lines(body, map); - } - Line::Assert(condition) => match condition { - Boolean::Equal { left, right } | Boolean::Different { left, right } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - }, - Line::FunctionRet { return_data } => { - for ret in return_data { - replace_vars_by_const_in_expr(ret, map); - } - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - } - Line::Print { content, .. } => { - for var in content { - replace_vars_by_const_in_expr(var, map); - } - } - Line::DecomposeBits { var, to_decompose } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - for expr in to_decompose { - replace_vars_by_const_in_expr(expr, map); - } - } - Line::CounterHint { var } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - } - Line::MAlloc { var, size, .. } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(size, map); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} - } - } -} -impl Display for SimpleLine { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.to_string_with_indent(0)) - } -} - -impl Display for VarOrConstMallocAccess { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Var(var) => write!(f, "{var}"), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { - write!(f, "ConstMallocAccess({malloc_label}, {offset})") - } - } - } -} - -impl SimpleLine { - fn to_string_with_indent(&self, indent: usize) -> String { - let spaces = " ".repeat(indent); - let line_str = match self { - Self::Match { value, arms } => { - let arms_str = arms - .iter() - .enumerate() - .map(|(pattern, stmt)| { - format!( - "{} => {}", - pattern, - stmt.iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n") - ) - }) - .collect::>() - .join(", "); - - format!("match {value} {{\n{arms_str}\n{spaces}}}") - } - Self::Assignment { - var, - operation, - arg0, - arg1, - } => { - format!("{var} = {arg0} {operation} {arg1}") - } - Self::DecomposeBits { - var: result, - to_decompose, - label: _, - } => { - format!( - "{} = decompose_bits({})", - result, - to_decompose - .iter() - .map(|expr| format!("{expr}")) - .collect::>() - .join(", ") - ) - } - Self::CounterHint { var: result } => { - format!("{result} = counter_hint()") - } - Self::RawAccess { res, index, shift } => { - format!("memory[{index} + {shift}] = {res}") - } - Self::IfNotZero { - condition, - then_branch, - else_branch, - } => { - let then_str = then_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - let else_str = else_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - if else_branch.is_empty() { - format!("if {condition} != 0 {{\n{then_str}\n{spaces}}}") - } else { - format!( - "if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" - ) - } - } - Self::FunctionCall { - function_name, - args, - return_data, - } => { - let args_str = args - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - let return_data_str = return_data - .iter() - .map(|var| var.to_string()) - .collect::>() - .join(", "); - - if return_data.is_empty() { - format!("{function_name}({args_str})") - } else { - format!("{return_data_str} = {function_name}({args_str})") - } - } - Self::FunctionRet { return_data } => { - let return_data_str = return_data - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - format!("return {return_data_str}") - } - Self::Precompile { precompile, args } => { - format!( - "{}({})", - &precompile.name, - args.iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", ") - ) - } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(|c| format!("{c}")) - .collect::>() - .join(", "); - format!("print({content_str})") - } - Self::HintMAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - format!("{var} = malloc_vec({size}, {vectorized_len})") - } else { - format!("{var} = malloc({size})") - } - } - Self::ConstMalloc { - var, - size, - label: _, - } => { - format!("{var} = malloc({size})") - } - Self::Panic => "panic".to_string(), - Self::LocationReport { .. } => Default::default(), - }; - format!("{spaces}{line_str}") - } -} - -impl Display for SimpleFunction { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let args_str = self - .arguments - .iter() - .map(|arg| arg.to_string()) - .collect::>() - .join(", "); - - let instructions_str = self - .instructions - .iter() - .map(|line| line.to_string_with_indent(1)) - .collect::>() - .join("\n"); - - if self.instructions.is_empty() { - write!( - f, - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) - } else { - write!( - f, - "fn {}({}) -> {} {{\n{}\n}}", - self.name, args_str, self.n_returned_vars, instructions_str - ) - } - } -} - -impl Display for SimpleProgram { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut first = true; - for function in self.functions.values() { - if !first { - writeln!(f)?; - } - write!(f, "{function}")?; - first = false; - } - Ok(()) - } -} diff --git a/crates/lean_compiler/src/analysis/control_flow.rs b/crates/lean_compiler/src/analysis/control_flow.rs new file mode 100644 index 00000000..b7e2f450 --- /dev/null +++ b/crates/lean_compiler/src/analysis/control_flow.rs @@ -0,0 +1,226 @@ +//! Control flow analysis utilities. + +use super::visitors::{Visitor, VisitorResult}; +use crate::lang::Line; +use std::ops::ControlFlow; + +/// Analyzes control flow patterns in a function body +pub struct ControlFlowAnalyzer { + pub has_returns: bool, + pub return_count: usize, + pub has_nested_returns: bool, + pub has_non_exhaustive_conditions: bool, +} + +impl ControlFlowAnalyzer { + pub fn new() -> Self { + Self { + has_returns: false, + return_count: 0, + has_nested_returns: false, + has_non_exhaustive_conditions: false, + } + } + + pub fn analyze(lines: &[Line]) -> Self { + let mut analyzer = Self::new(); + for line in lines { + let _: VisitorResult<()> = analyzer.visit_line(line); + } + analyzer + } + + /// Check if this pattern needs SSA violation repair + pub fn needs_ssa_repair(&self) -> bool { + self.has_non_exhaustive_conditions && self.return_count > 1 + } + + fn analyze_line_depth(&mut self, line: &Line, depth: usize) { + match line { + Line::FunctionRet(_) => { + self.has_returns = true; + self.return_count += 1; + if depth > 0 { + self.has_nested_returns = true; + } + } + Line::IfCondition(if_cond) => { + let then_branch = &if_cond.then_branch; + let else_branch = &if_cond.else_branch; + // Non-exhaustive if else_branch is empty + if else_branch.is_empty() && !then_branch.is_empty() { + self.has_non_exhaustive_conditions = true; + } + + for line in then_branch { + self.analyze_line_depth(line, depth + 1); + } + for line in else_branch { + self.analyze_line_depth(line, depth + 1); + } + } + Line::ForLoop(for_loop) => { + for line in &for_loop.body { + self.analyze_line_depth(line, depth + 1); + } + } + Line::Match(match_stmt) => { + for (_, arm_lines) in &match_stmt.arms { + for line in arm_lines { + self.analyze_line_depth(line, depth + 1); + } + } + } + _ => {} + } + } +} + +impl Default for ControlFlowAnalyzer { + fn default() -> Self { + Self::new() + } +} + +impl Visitor for ControlFlowAnalyzer { + fn visit_line(&mut self, line: &Line) -> VisitorResult { + self.analyze_line_depth(line, 0); + ControlFlow::Continue(()) + } +} + +/// Collects all return statements in a function +pub struct ReturnCollector { + pub returns: Vec, +} + +impl ReturnCollector { + pub fn new() -> Self { + Self { + returns: Vec::new(), + } + } + + pub fn collect_from(lines: &[Line]) -> Vec { + let mut collector = Self::new(); + for line in lines { + let _: VisitorResult<()> = collector.visit_line(line); + } + collector.returns + } +} + +impl Default for ReturnCollector { + fn default() -> Self { + Self::new() + } +} + +impl Visitor for ReturnCollector { + fn visit_line(&mut self, line: &Line) -> VisitorResult { + if let Line::FunctionRet(_) = line { + self.returns.push(line.clone()); + } + // Continue walking to find nested returns + super::visitors::walk_line(self, line) + } +} + +/// Checks if a branch contains any return statements +pub fn has_return_in_branch(lines: &[Line]) -> bool { + !ReturnCollector::collect_from(lines).is_empty() +} + +/// Checks for specific SSA violation patterns +pub fn has_ssa_violation_pattern(lines: &[Line]) -> bool { + let analysis = ControlFlowAnalyzer::analyze(lines); + analysis.needs_ssa_repair() +} + +/// Finds all non-exhaustive if conditions that contain returns +pub fn find_non_exhaustive_if_with_returns(lines: &[Line]) -> Vec<&Line> { + let mut found = Vec::new(); + for line in lines { + if let Line::IfCondition(if_cond) = line + && if_cond.else_branch.is_empty() + && has_return_in_branch(&if_cond.then_branch) + { + found.push(line); + } + } + found +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{Boolean, ConstExpression, Expression, SimpleExpr}; + + #[test] + fn test_control_flow_analyzer_simple() { + let lines = vec![Line::FunctionRet( + crate::lang::ast::statement::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(42), + ))], + }, + )]; + + let analysis = ControlFlowAnalyzer::analyze(&lines); + assert!(analysis.has_returns); + assert_eq!(analysis.return_count, 1); + assert!(!analysis.has_nested_returns); + } + + #[test] + fn test_control_flow_analyzer_nested() { + let lines = vec![Line::IfCondition( + crate::lang::ast::statement::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + right: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }, + then_branch: vec![Line::FunctionRet( + crate::lang::ast::statement::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(100), + ))], + }, + )], + else_branch: vec![], + }, + )]; + + let analysis = ControlFlowAnalyzer::analyze(&lines); + assert!(analysis.has_returns); + assert!(analysis.has_nested_returns); + assert!(analysis.has_non_exhaustive_conditions); + } + + #[test] + fn test_ssa_violation_detection() { + let lines = vec![ + Line::IfCondition(crate::lang::ast::statement::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + right: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }, + then_branch: vec![Line::FunctionRet( + crate::lang::ast::statement::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(100), + ))], + }, + )], + else_branch: vec![], + }), + Line::FunctionRet(crate::lang::ast::statement::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(200), + ))], + }), + ]; + + assert!(has_ssa_violation_pattern(&lines)); + } +} diff --git a/crates/lean_compiler/src/analysis/mod.rs b/crates/lean_compiler/src/analysis/mod.rs new file mode 100644 index 00000000..8baf0548 --- /dev/null +++ b/crates/lean_compiler/src/analysis/mod.rs @@ -0,0 +1,5 @@ +//! Analysis infrastructure for the lean compiler. + +pub mod control_flow; +pub mod ssa; +pub mod visitors; diff --git a/crates/lean_compiler/src/analysis/ssa.rs b/crates/lean_compiler/src/analysis/ssa.rs new file mode 100644 index 00000000..c62458e2 --- /dev/null +++ b/crates/lean_compiler/src/analysis/ssa.rs @@ -0,0 +1,320 @@ +//! SSA (Static Single Assignment) analysis utilities. + +use super::visitors::{Visitor, VisitorResult}; +use crate::lang::{Boolean, Expression, Line, Var}; +use std::collections::{HashMap, HashSet}; + +/// Tracks variable assignments to detect SSA violations +pub struct SSAAnalyzer { + /// Variables that have been assigned to + pub assigned_vars: HashMap, + /// Variables with multiple assignments (SSA violations) + pub violations: HashSet, +} + +impl SSAAnalyzer { + pub fn new() -> Self { + Self { + assigned_vars: HashMap::new(), + violations: HashSet::new(), + } + } + + pub fn analyze(lines: &[Line]) -> Self { + let mut analyzer = Self::new(); + for line in lines { + let _: VisitorResult<()> = analyzer.visit_line(line); + } + analyzer + } + + pub fn has_violations(&self) -> bool { + !self.violations.is_empty() + } + + fn record_assignment(&mut self, var: &Var) { + let count = self.assigned_vars.entry(var.clone()).or_insert(0); + *count += 1; + + if *count > 1 { + self.violations.insert(var.clone()); + } + } +} + +impl Default for SSAAnalyzer { + fn default() -> Self { + Self::new() + } +} + +impl Visitor for SSAAnalyzer { + fn visit_line(&mut self, line: &Line) -> VisitorResult { + match line { + Line::Assignment(assignment) => { + self.record_assignment(&assignment.var); + } + Line::ArrayAssign { .. } => { + // Array assignments don't violate SSA for scalar variables + } + Line::MAlloc(malloc) => { + self.record_assignment(&malloc.var); + } + Line::DecomposeBits(decompose) => { + self.record_assignment(&decompose.var); + } + Line::CounterHint(hint) => { + self.record_assignment(&hint.var); + } + Line::FunctionCall(call) => { + for var in &call.return_data { + self.record_assignment(var); + } + } + _ => {} + } + + // Continue walking to analyze nested structures + super::visitors::walk_line(self, line) + } +} + +/// Context for tracking return assignments across different execution paths +#[derive(Debug, Clone)] +pub struct ReturnContext { + /// Conditions that must be true for this return to execute + pub conditions: Vec, + /// Variable assignments in this return context + pub assignments: Vec<(Var, Expression)>, +} + +impl ReturnContext { + pub fn new() -> Self { + Self { + conditions: Vec::new(), + assignments: Vec::new(), + } + } + + pub fn with_condition(mut self, condition: Boolean) -> Self { + self.conditions.push(condition); + self + } + + pub fn with_assignment(mut self, var: Var, expr: Expression) -> Self { + self.assignments.push((var, expr)); + self + } +} + +impl Default for ReturnContext { + fn default() -> Self { + Self::new() + } +} + +/// Collects all return contexts from a function body +pub struct ReturnContextCollector { + pub contexts: Vec, +} + +impl ReturnContextCollector { + pub fn new() -> Self { + Self { + contexts: Vec::new(), + } + } + + pub fn collect_from(lines: &[Line], return_vars: &[Var]) -> Vec { + let mut collector = Self::new(); + collector.collect_recursive(lines, return_vars, Vec::new()); + collector.contexts + } + + fn collect_recursive( + &mut self, + lines: &[Line], + return_vars: &[Var], + current_conditions: Vec, + ) { + for line in lines { + match line { + Line::Assignment(assignment) if return_vars.contains(&assignment.var) => { + let context = ReturnContext { + conditions: current_conditions.clone(), + assignments: vec![(assignment.var.clone(), assignment.value.clone())], + }; + self.contexts.push(context); + } + Line::IfCondition(if_cond) => { + let condition = &if_cond.condition; + let then_branch = &if_cond.then_branch; + let else_branch = &if_cond.else_branch; + // Collect from then branch with added condition + let mut then_conditions = current_conditions.clone(); + then_conditions.push(condition.clone()); + self.collect_recursive(then_branch, return_vars, then_conditions); + + // Collect from else branch with negated condition + if !else_branch.is_empty() { + let negated_condition = match condition { + Boolean::Equal { left, right } => Boolean::Different { + left: left.clone(), + right: right.clone(), + }, + Boolean::Different { left, right } => Boolean::Equal { + left: left.clone(), + right: right.clone(), + }, + }; + let mut else_conditions = current_conditions.clone(); + else_conditions.push(negated_condition); + self.collect_recursive(else_branch, return_vars, else_conditions); + } + } + Line::ForLoop(for_loop) => { + // For simplicity, treat loop body as unconditional context + self.collect_recursive(&for_loop.body, return_vars, current_conditions.clone()); + } + Line::Match(match_stmt) => { + for (_, arm_lines) in &match_stmt.arms { + self.collect_recursive(arm_lines, return_vars, current_conditions.clone()); + } + } + _ => {} + } + } + } +} + +impl Default for ReturnContextCollector { + fn default() -> Self { + Self::new() + } +} + +/// Detects potential SSA violations in return patterns +pub fn detect_return_ssa_violations(lines: &[Line], return_vars: &[Var]) -> bool { + let contexts = ReturnContextCollector::collect_from(lines, return_vars); + + // If we have multiple contexts that could assign to the same variable, + // we have a potential SSA violation + if contexts.len() <= 1 { + return false; + } + + // Check if any return variable is assigned in multiple contexts + for return_var in return_vars { + let contexts_with_var = contexts + .iter() + .filter(|ctx| ctx.assignments.iter().any(|(var, _)| var == return_var)) + .count(); + + if contexts_with_var > 1 { + return true; + } + } + + false +} + +/// Creates a simple SSA repair strategy by restructuring control flow +pub fn create_ssa_repair_strategy(contexts: &[ReturnContext]) -> Vec { + if contexts.is_empty() { + return Vec::new(); + } + + // For now, implement a simple strategy: use the first context + // In a complete implementation, we'd build proper if-else chains + let first_context = &contexts[0]; + + let mut result = Vec::new(); + for (var, expr) in &first_context.assignments { + result.push(Line::Assignment(crate::lang::ast::statement::Assignment { + var: var.clone(), + value: expr.clone(), + })); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{Assignment, ConstExpression, IfCondition, SimpleExpr}; + + #[test] + fn test_ssa_analyzer_no_violations() { + let lines = vec![ + Line::Assignment(Assignment { + var: "x".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }), + Line::Assignment(Assignment { + var: "y".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(2))), + }), + ]; + + let analyzer = SSAAnalyzer::analyze(&lines); + assert!(!analyzer.has_violations()); + } + + #[test] + fn test_ssa_analyzer_with_violations() { + let lines = vec![ + Line::Assignment(Assignment { + var: "x".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }), + Line::Assignment(Assignment { + var: "x".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(2))), + }), + ]; + + let analyzer = SSAAnalyzer::analyze(&lines); + assert!(analyzer.has_violations()); + assert!(analyzer.violations.contains("x")); + } + + #[test] + fn test_return_context_collection() { + let lines = vec![Line::Assignment(Assignment { + var: "result".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(42))), + })]; + + let return_vars = vec!["result".to_string()]; + let contexts = ReturnContextCollector::collect_from(&lines, &return_vars); + + assert_eq!(contexts.len(), 1); + assert_eq!(contexts[0].assignments.len(), 1); + assert_eq!(contexts[0].assignments[0].0, "result"); + } + + #[test] + fn test_detect_return_ssa_violations() { + let lines = vec![ + Line::IfCondition(IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + right: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }, + then_branch: vec![Line::Assignment(Assignment { + var: "result".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(100))), + })], + else_branch: vec![], + }), + Line::Assignment(Assignment { + var: "result".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(200))), + }), + ]; + + let return_vars = vec!["result".to_string()]; + assert!(detect_return_ssa_violations(&lines, &return_vars)); + } +} diff --git a/crates/lean_compiler/src/analysis/visitors.rs b/crates/lean_compiler/src/analysis/visitors.rs new file mode 100644 index 00000000..04b18a95 --- /dev/null +++ b/crates/lean_compiler/src/analysis/visitors.rs @@ -0,0 +1,392 @@ +//! Visitor traits for AST traversal. + +use crate::lang::{self, Boolean, Expression, Function, Line, Program, SimpleExpr}; +use std::ops::ControlFlow; + +/// Result type for visitor operations +pub type VisitorResult = ControlFlow; + +/// Immutable visitor trait for read-only AST traversal +pub trait Visitor { + fn visit_program(&mut self, program: &Program) -> VisitorResult { + walk_program(self, program) + } + + fn visit_function(&mut self, function: &Function) -> VisitorResult { + walk_function(self, function) + } + + fn visit_line(&mut self, line: &Line) -> VisitorResult { + walk_line(self, line) + } + + fn visit_expression(&mut self, expr: &Expression) -> VisitorResult { + walk_expression(self, expr) + } + + fn visit_simple_expr(&mut self, simple_expr: &SimpleExpr) -> VisitorResult { + walk_simple_expr(self, simple_expr) + } + + fn visit_boolean(&mut self, condition: &Boolean) -> VisitorResult { + walk_boolean(self, condition) + } +} + +/// Mutable visitor trait for AST transformations +pub trait MutVisitor { + fn visit_program(&mut self, program: &mut Program) -> VisitorResult { + walk_program_mut(self, program) + } + + fn visit_function(&mut self, function: &mut Function) -> VisitorResult { + walk_function_mut(self, function) + } + + fn visit_line(&mut self, line: &mut Line) -> VisitorResult { + walk_line_mut(self, line) + } + + fn visit_expression(&mut self, expr: &mut Expression) -> VisitorResult { + walk_expression_mut(self, expr) + } + + fn visit_simple_expr(&mut self, simple_expr: &mut SimpleExpr) -> VisitorResult { + walk_simple_expr_mut(self, simple_expr) + } + + fn visit_boolean(&mut self, condition: &mut Boolean) -> VisitorResult { + walk_boolean_mut(self, condition) + } +} + +// Immutable walking functions (now generic) +pub fn walk_program(visitor: &mut V, program: &Program) -> VisitorResult +where + V: Visitor + ?Sized, +{ + for function in program.functions.values() { + visitor.visit_function(function)?; + } + ControlFlow::Continue(()) +} + +pub fn walk_function(visitor: &mut V, function: &Function) -> VisitorResult +where + V: Visitor + ?Sized, +{ + for line in &function.body { + visitor.visit_line(line)?; + } + ControlFlow::Continue(()) +} + +pub fn walk_line(visitor: &mut V, line: &Line) -> VisitorResult +where + V: Visitor + ?Sized, +{ + match line { + Line::Match(match_stmt) => { + visitor.visit_expression(&match_stmt.value)?; + for (_, arm_lines) in &match_stmt.arms { + for arm_line in arm_lines { + visitor.visit_line(arm_line)?; + } + } + } + Line::Assignment(assignment) => { + visitor.visit_expression(&assignment.value)?; + } + Line::ArrayAssign(array_assign) => { + visitor.visit_simple_expr(&array_assign.array)?; + visitor.visit_expression(&array_assign.index)?; + visitor.visit_expression(&array_assign.value)?; + } + Line::Assert(assert) => { + visitor.visit_boolean(&assert.condition)?; + } + Line::IfCondition(if_cond) => { + visitor.visit_boolean(&if_cond.condition)?; + for line in &if_cond.then_branch { + visitor.visit_line(line)?; + } + for line in &if_cond.else_branch { + visitor.visit_line(line)?; + } + } + Line::ForLoop(for_loop) => { + visitor.visit_expression(&for_loop.start)?; + visitor.visit_expression(&for_loop.end)?; + for line in &for_loop.body { + visitor.visit_line(line)?; + } + } + Line::FunctionCall(call) => { + for arg in &call.args { + visitor.visit_expression(arg)?; + } + } + Line::FunctionRet(ret) => { + for expr in &ret.return_data { + visitor.visit_expression(expr)?; + } + } + Line::Precompile(precompile) => { + for arg in &precompile.args { + visitor.visit_expression(arg)?; + } + } + Line::Print(print) => { + for expr in &print.content { + visitor.visit_expression(expr)?; + } + } + Line::MAlloc(malloc) => { + visitor.visit_expression(&malloc.size)?; + visitor.visit_expression(&malloc.vectorized_len)?; + } + Line::DecomposeBits(decompose) => { + for expr in &decompose.to_decompose { + visitor.visit_expression(expr)?; + } + } + Line::Break(_) | Line::Panic(_) | Line::CounterHint(_) | Line::LocationReport(_) => {} + } + ControlFlow::Continue(()) +} + +pub fn walk_expression(visitor: &mut V, expr: &Expression) -> VisitorResult +where + V: Visitor + ?Sized, +{ + match expr { + Expression::Value(simple_expr) => visitor.visit_simple_expr(simple_expr)?, + Expression::ArrayAccess { array, index } => { + visitor.visit_simple_expr(array)?; + visitor.visit_expression(index)?; + } + Expression::Binary { left, right, .. } => { + visitor.visit_expression(left)?; + visitor.visit_expression(right)?; + } + Expression::Log2Ceil { value } => { + visitor.visit_expression(value)?; + } + } + ControlFlow::Continue(()) +} + +pub fn walk_simple_expr(_visitor: &mut V, _simple_expr: &SimpleExpr) -> VisitorResult +where + V: Visitor + ?Sized, +{ + // SimpleExpr is leaf-level, no further traversal needed + ControlFlow::Continue(()) +} + +pub fn walk_boolean(visitor: &mut V, condition: &Boolean) -> VisitorResult +where + V: Visitor + ?Sized, +{ + match condition { + Boolean::Equal { left, right } | Boolean::Different { left, right } => { + visitor.visit_expression(left)?; + visitor.visit_expression(right)?; + } + } + ControlFlow::Continue(()) +} + +// Mutable walking functions (now generic) +pub fn walk_program_mut(visitor: &mut V, program: &mut Program) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + for function in program.functions.values_mut() { + visitor.visit_function(function)?; + } + ControlFlow::Continue(()) +} + +pub fn walk_function_mut(visitor: &mut V, function: &mut Function) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + for line in &mut function.body { + visitor.visit_line(line)?; + } + ControlFlow::Continue(()) +} + +pub fn walk_line_mut(visitor: &mut V, line: &mut Line) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + match line { + Line::Match(lang::Match { value, arms }) => { + visitor.visit_expression(value)?; + for (_, arm_lines) in arms { + for arm_line in arm_lines { + visitor.visit_line(arm_line)?; + } + } + } + Line::Assignment(lang::Assignment { value, .. }) => { + visitor.visit_expression(value)?; + } + Line::ArrayAssign(lang::ArrayAssign { + array, + index, + value, + }) => { + visitor.visit_simple_expr(array)?; + visitor.visit_expression(index)?; + visitor.visit_expression(value)?; + } + Line::Assert(lang::Assert { condition }) => { + visitor.visit_boolean(condition)?; + } + Line::IfCondition(lang::IfCondition { + condition, + then_branch, + else_branch, + }) => { + visitor.visit_boolean(condition)?; + for line in then_branch { + visitor.visit_line(line)?; + } + for line in else_branch { + visitor.visit_line(line)?; + } + } + Line::ForLoop(lang::ForLoop { + start, end, body, .. + }) => { + visitor.visit_expression(start)?; + visitor.visit_expression(end)?; + for line in body { + visitor.visit_line(line)?; + } + } + Line::FunctionCall(lang::FunctionCall { args, .. }) => { + for arg in args { + visitor.visit_expression(arg)?; + } + } + Line::FunctionRet(lang::FunctionRet { return_data }) => { + for expr in return_data { + visitor.visit_expression(expr)?; + } + } + Line::Precompile(lang::PrecompileStmt { args, .. }) => { + for arg in args { + visitor.visit_expression(arg)?; + } + } + Line::Print(lang::Print { content, .. }) => { + for expr in content { + visitor.visit_expression(expr)?; + } + } + Line::MAlloc(lang::MAlloc { + size, + vectorized_len, + .. + }) => { + visitor.visit_expression(size)?; + visitor.visit_expression(vectorized_len)?; + } + Line::DecomposeBits(lang::DecomposeBits { to_decompose, .. }) => { + for expr in to_decompose { + visitor.visit_expression(expr)?; + } + } + Line::Break(_) | Line::Panic(_) | Line::CounterHint(_) | Line::LocationReport(_) => {} + } + ControlFlow::Continue(()) +} + +pub fn walk_expression_mut(visitor: &mut V, expr: &mut Expression) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + match expr { + Expression::Value(simple_expr) => visitor.visit_simple_expr(simple_expr)?, + Expression::ArrayAccess { array, index } => { + visitor.visit_simple_expr(array)?; + visitor.visit_expression(index)?; + } + Expression::Binary { left, right, .. } => { + visitor.visit_expression(left)?; + visitor.visit_expression(right)?; + } + Expression::Log2Ceil { value } => { + visitor.visit_expression(value)?; + } + } + ControlFlow::Continue(()) +} + +pub fn walk_simple_expr_mut( + _visitor: &mut V, + _simple_expr: &mut SimpleExpr, +) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + // SimpleExpr is leaf-level, no further traversal needed + ControlFlow::Continue(()) +} + +pub fn walk_boolean_mut(visitor: &mut V, condition: &mut Boolean) -> VisitorResult +where + V: MutVisitor + ?Sized, +{ + match condition { + Boolean::Equal { left, right } | Boolean::Different { left, right } => { + visitor.visit_expression(left)?; + visitor.visit_expression(right)?; + } + } + ControlFlow::Continue(()) +} + +/// Utility for collecting items during traversal +pub struct Collector { + pub items: Vec, +} + +impl Collector { + pub fn new() -> Self { + Self { items: Vec::new() } + } + + pub fn collect(mut self, item: T) -> Self { + self.items.push(item); + self + } +} + +impl Default for Collector { + fn default() -> Self { + Self::new() + } +} + +/// Utility for finding specific patterns +pub struct PatternFinder { + pub predicate: F, + pub found: Option, +} + +impl PatternFinder +where + F: Fn(&T) -> bool, +{ + pub fn new(predicate: F) -> Self { + Self { + predicate, + found: None, + } + } +} diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs deleted file mode 100644 index 5d47ca2e..00000000 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ /dev/null @@ -1,774 +0,0 @@ -use crate::{F, a_simplify_lang::*, ir::*, lang::*, precompiles::*}; -use lean_vm::*; -use p3_field::Field; -use std::{ - borrow::Borrow, - collections::{BTreeMap, BTreeSet}, -}; -use utils::ToUsize; - -#[derive(Default)] -struct Compiler { - bytecode: BTreeMap>, - match_blocks: Vec>>, // each match = many bytecode blocks, each bytecode block = many instructions - if_counter: usize, - call_counter: usize, - func_name: String, - var_positions: BTreeMap, // var -> memory offset from fp - args_count: usize, - stack_size: usize, - const_mallocs: BTreeMap, // const_malloc_label -> start = memory offset from fp -} - -impl Compiler { - fn get_offset(&self, var: &VarOrConstMallocAccess) -> ConstExpression { - match var { - VarOrConstMallocAccess::Var(var) => (*self - .var_positions - .get(var) - .unwrap_or_else(|| panic!("Variable {var} not in scope"))) - .into(), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => ConstExpression::Binary { - left: Box::new( - self.const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, - } - } -} - -impl SimpleExpr { - fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant { - match self { - Self::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp { - offset: compiler.get_offset(&var.clone().into()), - }, - Self::Constant(c) => IntermediaryMemOrFpOrConstant::Constant(c.clone()), - Self::ConstMallocAccess { - malloc_label, - offset, - } => IntermediaryMemOrFpOrConstant::MemoryAfterFp { - offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *malloc_label, - offset: offset.clone(), - }), - }, - } - } -} - -impl IntermediateValue { - fn from_simple_expr(expr: &SimpleExpr, compiler: &Compiler) -> Self { - match expr { - SimpleExpr::Var(var) => Self::MemoryAfterFp { - offset: compiler.get_offset(&var.clone().into()), - }, - SimpleExpr::Constant(c) => Self::Constant(c.clone()), - SimpleExpr::ConstMallocAccess { - malloc_label, - offset, - } => Self::MemoryAfterFp { - offset: ConstExpression::Binary { - left: Box::new( - compiler - .const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, - }, - } - } - - fn from_var_or_const_malloc_access( - var_or_const: &VarOrConstMallocAccess, - compiler: &Compiler, - ) -> Self { - Self::from_simple_expr(&var_or_const.clone().into(), compiler) - } -} - -pub fn compile_to_intermediate_bytecode( - simple_program: SimpleProgram, -) -> Result { - let mut compiler = Compiler::default(); - let mut memory_sizes = BTreeMap::new(); - - for function in simple_program.functions.values() { - let instructions = compile_function(function, &mut compiler)?; - compiler - .bytecode - .insert(Label::function(&function.name), instructions); - memory_sizes.insert(function.name.clone(), compiler.stack_size); - } - - Ok(IntermediateBytecode { - bytecode: compiler.bytecode, - match_blocks: compiler.match_blocks, - memory_size_per_function: memory_sizes, - }) -} - -fn compile_function( - function: &SimpleFunction, - compiler: &mut Compiler, -) -> Result, String> { - let mut internal_vars = find_internal_vars(&function.instructions); - - internal_vars.retain(|var| !function.arguments.contains(var)); - - // memory layout: pc, fp, args, return_vars, internal_vars - let mut stack_pos = 2; // Reserve space for pc and fp - let mut var_positions = BTreeMap::new(); - - for (i, var) in function.arguments.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += function.arguments.len(); - - stack_pos += function.n_returned_vars; - - for (i, var) in internal_vars.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += internal_vars.len(); - - compiler.func_name = function.name.clone(); - compiler.var_positions = var_positions; - compiler.stack_size = stack_pos; - compiler.args_count = function.arguments.len(); - - let mut declared_vars: BTreeSet = function.arguments.iter().cloned().collect(); - compile_lines(&function.instructions, compiler, None, &mut declared_vars) -} - -fn compile_lines( - lines: &[SimpleLine], - compiler: &mut Compiler, - final_jump: Option