From c60a2f7ae700bcbdc188104ac07fec2bcaf54aa2 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 20:21:29 +0200 Subject: [PATCH 01/39] compiler: reorg lang --- crates/lean_compiler/src/lang.rs | 670 ------------------ crates/lean_compiler/src/lang/ast/expr.rs | 196 +++++ crates/lean_compiler/src/lang/ast/mod.rs | 11 + crates/lean_compiler/src/lang/ast/program.rs | 86 +++ crates/lean_compiler/src/lang/ast/stmt.rs | 246 +++++++ crates/lean_compiler/src/lang/ast/types.rs | 27 + crates/lean_compiler/src/lang/mod.rs | 7 + .../lean_compiler/src/lang/values/constant.rs | 181 +++++ crates/lean_compiler/src/lang/values/mod.rs | 7 + .../lean_compiler/src/lang/values/variable.rs | 4 + 10 files changed, 765 insertions(+), 670 deletions(-) delete mode 100644 crates/lean_compiler/src/lang.rs create mode 100644 crates/lean_compiler/src/lang/ast/expr.rs create mode 100644 crates/lean_compiler/src/lang/ast/mod.rs create mode 100644 crates/lean_compiler/src/lang/ast/program.rs create mode 100644 crates/lean_compiler/src/lang/ast/stmt.rs create mode 100644 crates/lean_compiler/src/lang/ast/types.rs create mode 100644 crates/lean_compiler/src/lang/mod.rs create mode 100644 crates/lean_compiler/src/lang/values/constant.rs create mode 100644 crates/lean_compiler/src/lang/values/mod.rs create mode 100644 crates/lean_compiler/src/lang/values/variable.rs diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs deleted file mode 100644 index 1898c7bd..00000000 --- a/crates/lean_compiler/src/lang.rs +++ /dev/null @@ -1,670 +0,0 @@ -use lean_vm::*; -use p3_field::PrimeCharacteristicRing; -use p3_util::log2_ceil_usize; -use std::collections::BTreeMap; -use std::fmt::{Display, Formatter}; -use utils::ToUsize; - -use crate::{F, ir::HighLevelOperation, precompiles::Precompile}; - -#[derive(Debug, Clone)] -pub struct Program { - pub functions: BTreeMap, -} - -#[derive(Debug, Clone)] -pub struct Function { - pub name: String, - pub arguments: Vec<(Var, 
bool)>, // (name, is_const) - pub inlined: bool, - pub n_returned_vars: usize, - pub body: Vec, -} - -impl Function { - pub fn has_const_arguments(&self) -> bool { - self.arguments.iter().any(|(_, is_const)| *is_const) - } -} - -pub type Var = String; -pub type ConstMallocLabel = usize; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum SimpleExpr { - Var(Var), - Constant(ConstExpression), - ConstMallocAccess { - malloc_label: ConstMallocLabel, - offset: ConstExpression, - }, -} - -impl SimpleExpr { - pub fn zero() -> Self { - Self::scalar(0) - } - - pub fn one() -> Self { - Self::scalar(1) - } - - pub fn scalar(scalar: usize) -> Self { - Self::Constant(ConstantValue::Scalar(scalar).into()) - } - - pub const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } - - pub fn simplify_if_const(&self) -> Self { - if let Self::Constant(constant) = self { - return constant.try_naive_simplification().into(); - } - self.clone() - } -} - -impl From for SimpleExpr { - fn from(constant: ConstantValue) -> Self { - Self::Constant(constant.into()) - } -} - -impl From for SimpleExpr { - fn from(constant: ConstExpression) -> Self { - Self::Constant(constant) - } -} - -impl From for SimpleExpr { - fn from(var: Var) -> Self { - Self::Var(var) - } -} - -impl SimpleExpr { - pub fn as_constant(&self) -> Option { - match self { - Self::Var(_) => None, - Self::Constant(constant) => Some(constant.clone()), - Self::ConstMallocAccess { .. 
} => None, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Boolean { - Equal { left: Expression, right: Expression }, - Different { left: Expression, right: Expression }, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ConstantValue { - Scalar(usize), - PublicInputStart, - PointerToZeroVector, // In the memory of chunks of 8 field elements - PointerToOneVector, // In the memory of chunks of 8 field elements - FunctionSize { function_name: Label }, - Label(Label), - MatchBlockSize { match_index: usize }, - MatchFirstBlockStart { match_index: usize }, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ConstExpression { - Value(ConstantValue), - Binary { - left: Box, - operation: HighLevelOperation, - right: Box, - }, - Log2Ceil { - value: Box, - }, -} - -impl From for ConstExpression { - fn from(value: usize) -> Self { - Self::Value(ConstantValue::Scalar(value)) - } -} - -impl TryFrom for ConstExpression { - type Error = (); - - fn try_from(value: Expression) -> Result { - match value { - Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), - Expression::Value(_) => Err(()), - Expression::ArrayAccess { .. 
} => Err(()), - Expression::Binary { - left, - operation, - right, - } => { - let left_expr = Self::try_from(*left)?; - let right_expr = Self::try_from(*right)?; - Ok(Self::Binary { - left: Box::new(left_expr), - operation, - right: Box::new(right_expr), - }) - } - Expression::Log2Ceil { value } => { - let value_expr = Self::try_from(*value)?; - Ok(Self::Log2Ceil { - value: Box::new(value_expr), - }) - } - } - } -} - -impl ConstExpression { - pub const fn zero() -> Self { - Self::scalar(0) - } - - pub const fn one() -> Self { - Self::scalar(1) - } - - pub const fn label(label: Label) -> Self { - Self::Value(ConstantValue::Label(label)) - } - - pub const fn scalar(scalar: usize) -> Self { - Self::Value(ConstantValue::Scalar(scalar)) - } - - pub const fn function_size(function_name: Label) -> Self { - Self::Value(ConstantValue::FunctionSize { function_name }) - } - pub fn eval_with(&self, func: &EvalFn) -> Option - where - EvalFn: Fn(&ConstantValue) -> Option, - { - match self { - Self::Value(value) => func(value), - Self::Binary { - left, - operation, - right, - } => Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)), - Self::Log2Ceil { value } => { - let value = value.eval_with(func)?; - Some(F::from_usize(log2_ceil_usize(value.to_usize()))) - } - } - } - - pub fn naive_eval(&self) -> Option { - self.eval_with(&|value| match value { - ConstantValue::Scalar(scalar) => Some(F::from_usize(*scalar)), - _ => None, - }) - } - - pub fn try_naive_simplification(&self) -> Self { - if let Some(value) = self.naive_eval() { - Self::scalar(value.to_usize()) - } else { - self.clone() - } - } -} - -impl From for ConstExpression { - fn from(value: ConstantValue) -> Self { - Self::Value(value) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Expression { - Value(SimpleExpr), - ArrayAccess { - array: SimpleExpr, - index: Box, - }, - Binary { - left: Box, - operation: HighLevelOperation, - right: Box, - }, - Log2Ceil { - value: Box, - 
}, // only for const expressions -} - -impl From for Expression { - fn from(value: SimpleExpr) -> Self { - Self::Value(value) - } -} - -impl From for Expression { - fn from(var: Var) -> Self { - Self::Value(var.into()) - } -} - -impl Expression { - pub fn naive_eval(&self) -> Option { - self.eval_with( - &|value: &SimpleExpr| value.as_constant()?.naive_eval(), - &|_, _| None, - ) - } - - pub fn eval_with(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option - where - ValueFn: Fn(&SimpleExpr) -> Option, - ArrayFn: Fn(&SimpleExpr, F) -> Option, - { - match self { - Self::Value(value) => value_fn(value), - Self::ArrayAccess { array, index } => { - array_fn(array, index.eval_with(value_fn, array_fn)?) - } - Self::Binary { - left, - operation, - right, - } => Some(operation.eval( - left.eval_with(value_fn, array_fn)?, - right.eval_with(value_fn, array_fn)?, - )), - Self::Log2Ceil { value } => { - let value = value.eval_with(value_fn, array_fn)?; - Some(F::from_usize(log2_ceil_usize(value.to_usize()))) - } - } - } - - pub fn scalar(scalar: usize) -> Self { - SimpleExpr::scalar(scalar).into() - } - - pub fn zero() -> Self { - Self::scalar(0) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Line { - Match { - value: Expression, - arms: Vec<(usize, Vec)>, - }, - Assignment { - var: Var, - value: Expression, - }, - ArrayAssign { - // array[index] = value - array: SimpleExpr, - index: Expression, - value: Expression, - }, - Assert(Boolean), - IfCondition { - condition: Boolean, - then_branch: Vec, - else_branch: Vec, - }, - ForLoop { - iterator: Var, - start: Expression, - end: Expression, - body: Vec, - rev: bool, - unroll: bool, - }, - FunctionCall { - function_name: String, - args: Vec, - return_data: Vec, - }, - FunctionRet { - return_data: Vec, - }, - Precompile { - precompile: Precompile, - args: Vec, - }, - Break, - Panic, - // Hints: - Print { - line_info: String, - content: Vec, - }, - MAlloc { - var: Var, - size: Expression, - 
vectorized: bool, - vectorized_len: Expression, - }, - DecomposeBits { - var: Var, // a pointer to 31 * len(to_decompose) field elements, containing the bits of "to_decompose" - to_decompose: Vec, - }, - CounterHint { - var: Var, - }, - // noop, debug purpose only - LocationReport { - location: SourceLineNumber, - }, -} -impl Display for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Value(val) => write!(f, "{val}"), - Self::ArrayAccess { array, index } => { - write!(f, "{array}[{index}]") - } - Self::Binary { - left, - operation, - right, - } => { - write!(f, "({left} {operation} {right})") - } - Self::Log2Ceil { value } => { - write!(f, "log2_ceil({value})") - } - } - } -} - -impl Line { - fn to_string_with_indent(&self, indent: usize) -> String { - let spaces = " ".repeat(indent); - let line_str = match self { - Self::LocationReport { .. } => { - // print nothing - Default::default() - } - Self::Match { value, arms } => { - let arms_str = arms - .iter() - .map(|(const_expr, body)| { - let body_str = body - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - format!("{const_expr} => {{\n{body_str}\n{spaces}}}") - }) - .collect::>() - .join("\n"); - format!("match {value} {{\n{arms_str}\n{spaces}}}") - } - Self::Assignment { var, value } => { - format!("{var} = {value}") - } - Self::ArrayAssign { - array, - index, - value, - } => { - format!("{array}[{index}] = {value}") - } - Self::Assert(condition) => format!("assert {condition}"), - Self::IfCondition { - condition, - then_branch, - else_branch, - } => { - let then_str = then_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - let else_str = else_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - if else_branch.is_empty() { - format!("if {condition} {{\n{then_str}\n{spaces}}}") - } else { - format!( - "if 
{condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" - ) - } - } - Self::CounterHint { var } => { - format!("{var} = counter_hint({var})") - } - Self::ForLoop { - iterator, - start, - end, - body, - rev, - unroll, - } => { - let body_str = body - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - format!( - "for {} in {}{}..{} {}{{\n{}\n{}}}", - iterator, - start, - if *rev { "rev " } else { "" }, - end, - if *unroll { "unroll " } else { "" }, - body_str, - spaces - ) - } - Self::FunctionCall { - function_name, - args, - return_data, - } => { - let args_str = args - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - let return_data_str = return_data - .iter() - .map(|var| var.to_string()) - .collect::>() - .join(", "); - - if return_data.is_empty() { - format!("{function_name}({args_str})") - } else { - format!("{return_data_str} = {function_name}({args_str})") - } - } - Self::FunctionRet { return_data } => { - let return_data_str = return_data - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - format!("return {return_data_str}") - } - Self::Precompile { precompile, args } => { - format!( - "{}({})", - precompile.name, - args.iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", ") - ) - } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(|c| format!("{c}")) - .collect::>() - .join(", "); - format!("print({content_str})") - } - Self::MAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - format!("{var} = malloc_vec({size}, {vectorized_len})") - } else { - format!("{var} = malloc({size})") - } - } - Self::DecomposeBits { var, to_decompose } => { - format!( - "{} = decompose_bits({})", - var, - to_decompose - .iter() - .map(|expr| expr.to_string()) - .collect::>() - .join(", ") - ) - } - Self::Break => "break".to_string(), - Self::Panic => "panic".to_string(), - }; - 
format!("{spaces}{line_str}") - } -} - -impl Display for Boolean { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Equal { left, right } => { - write!(f, "{left} == {right}") - } - Self::Different { left, right } => { - write!(f, "{left} != {right}") - } - } - } -} - -impl Display for ConstantValue { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Scalar(scalar) => write!(f, "{scalar}"), - Self::PublicInputStart => write!(f, "@public_input_start"), - Self::PointerToZeroVector => write!(f, "@pointer_to_zero_vector"), - Self::PointerToOneVector => write!(f, "@pointer_to_one_vector"), - Self::FunctionSize { function_name } => { - write!(f, "@function_size_{function_name}") - } - Self::Label(label) => write!(f, "{label}"), - Self::MatchFirstBlockStart { match_index } => { - write!(f, "@match_first_block_start_{match_index}") - } - Self::MatchBlockSize { match_index } => { - write!(f, "@match_block_size_{match_index}") - } - } - } -} - -impl Display for SimpleExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Var(var) => write!(f, "{var}"), - Self::Constant(constant) => write!(f, "{constant}"), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { - write!(f, "malloc_access({malloc_label}, {offset})") - } - } - } -} - -impl Display for ConstExpression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Value(value) => write!(f, "{value}"), - Self::Binary { - left, - operation, - right, - } => { - write!(f, "({left} {operation} {right})") - } - Self::Log2Ceil { value } => { - write!(f, "log2_ceil({value})") - } - } - } -} - -impl Display for Line { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.to_string_with_indent(0)) - } -} - -impl Display for Program { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut first = true; - for function in self.functions.values() { 
- if !first { - writeln!(f)?; - } - write!(f, "{function}")?; - first = false; - } - Ok(()) - } -} - -impl Display for Function { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let args_str = self - .arguments - .iter() - .map(|arg| match arg { - (name, true) => format!("const {name}"), - (name, false) => name.to_string(), - }) - .collect::>() - .join(", "); - - let instructions_str = self - .body - .iter() - .map(|line| line.to_string_with_indent(1)) - .collect::>() - .join("\n"); - - if self.body.is_empty() { - write!( - f, - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) - } else { - write!( - f, - "fn {}({}) -> {} {{\n{}\n}}", - self.name, args_str, self.n_returned_vars, instructions_str - ) - } - } -} diff --git a/crates/lean_compiler/src/lang/ast/expr.rs b/crates/lean_compiler/src/lang/ast/expr.rs new file mode 100644 index 00000000..2bd75789 --- /dev/null +++ b/crates/lean_compiler/src/lang/ast/expr.rs @@ -0,0 +1,196 @@ +//! Expression types for the AST. + +use p3_field::PrimeCharacteristicRing; +use p3_util::log2_ceil_usize; +use std::fmt::{Display, Formatter}; + +use crate::{F, ir::HighLevelOperation}; +use crate::lang::values::{ConstExpression, ConstantValue, Var, ConstMallocLabel}; +use utils::ToUsize; + +/// Simple expression that can be a variable, constant, or memory access. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SimpleExpr { + /// Variable reference. + Var(Var), + /// Constant value. + Constant(ConstExpression), + /// Access to const malloc memory. + ConstMallocAccess { + malloc_label: ConstMallocLabel, + offset: ConstExpression, + }, +} + +impl SimpleExpr { + /// Creates a zero constant. + pub fn zero() -> Self { + Self::scalar(0) + } + + /// Creates a one constant. + pub fn one() -> Self { + Self::scalar(1) + } + + /// Creates a scalar constant. 
+ pub fn scalar(scalar: usize) -> Self { + Self::Constant(ConstantValue::Scalar(scalar).into()) + } + + /// Returns true if this expression is a constant. + pub const fn is_constant(&self) -> bool { + matches!(self, Self::Constant(_)) + } + + /// Simplifies the expression if it's a constant. + pub fn simplify_if_const(&self) -> Self { + if let Self::Constant(constant) = self { + return constant.try_naive_simplification().into(); + } + self.clone() + } + + /// Extracts the constant expression if this is a constant. + pub fn as_constant(&self) -> Option { + match self { + Self::Var(_) => None, + Self::Constant(constant) => Some(constant.clone()), + Self::ConstMallocAccess { .. } => None, + } + } +} + +impl From for SimpleExpr { + fn from(constant: ConstantValue) -> Self { + Self::Constant(constant.into()) + } +} + +impl From for SimpleExpr { + fn from(constant: ConstExpression) -> Self { + Self::Constant(constant) + } +} + +impl From for SimpleExpr { + fn from(var: Var) -> Self { + Self::Var(var) + } +} + +/// Complex expression supporting operations and array access. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Expression { + /// Simple value expression. + Value(SimpleExpr), + /// Array element access. + ArrayAccess { + array: SimpleExpr, + index: Box, + }, + /// Binary operation. + Binary { + left: Box, + operation: HighLevelOperation, + right: Box, + }, + /// Ceiling of log base 2. + Log2Ceil { + value: Box, + }, +} + +impl From for Expression { + fn from(value: SimpleExpr) -> Self { + Self::Value(value) + } +} + +impl From for Expression { + fn from(var: Var) -> Self { + Self::Value(var.into()) + } +} + +impl Expression { + /// Evaluates the expression if it contains only constants. + pub fn naive_eval(&self) -> Option { + self.eval_with( + &|value: &SimpleExpr| value.as_constant()?.naive_eval(), + &|_, _| None, + ) + } + + /// Evaluates the expression with custom value and array functions. 
+ pub fn eval_with(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option + where + ValueFn: Fn(&SimpleExpr) -> Option, + ArrayFn: Fn(&SimpleExpr, F) -> Option, + { + match self { + Self::Value(value) => value_fn(value), + Self::ArrayAccess { array, index } => { + array_fn(array, index.eval_with(value_fn, array_fn)?) + } + Self::Binary { + left, + operation, + right, + } => Some(operation.eval( + left.eval_with(value_fn, array_fn)?, + right.eval_with(value_fn, array_fn)?, + )), + Self::Log2Ceil { value } => { + let value = value.eval_with(value_fn, array_fn)?; + Some(F::from_usize(log2_ceil_usize(value.to_usize()))) + } + } + } + + /// Creates a scalar expression. + pub fn scalar(scalar: usize) -> Self { + SimpleExpr::scalar(scalar).into() + } + + /// Creates a zero expression. + pub fn zero() -> Self { + Self::scalar(0) + } +} + +impl Display for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Value(val) => write!(f, "{val}"), + Self::ArrayAccess { array, index } => { + write!(f, "{array}[{index}]") + } + Self::Binary { + left, + operation, + right, + } => { + write!(f, "({left} {operation} {right})") + } + Self::Log2Ceil { value } => { + write!(f, "log2_ceil({value})") + } + } + } +} + +impl Display for SimpleExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Var(var) => write!(f, "{var}"), + Self::Constant(constant) => write!(f, "{constant}"), + Self::ConstMallocAccess { + malloc_label, + offset, + } => { + write!(f, "malloc_access({malloc_label}, {offset})") + } + } + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/ast/mod.rs b/crates/lean_compiler/src/lang/ast/mod.rs new file mode 100644 index 00000000..d48ffe3f --- /dev/null +++ b/crates/lean_compiler/src/lang/ast/mod.rs @@ -0,0 +1,11 @@ +//! Abstract Syntax Tree definitions for Lean language constructs. 
+ +pub mod program; +pub mod expr; +pub mod stmt; +pub mod types; + +pub use program::*; +pub use expr::*; +pub use stmt::*; +pub use types::*; \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/ast/program.rs b/crates/lean_compiler/src/lang/ast/program.rs new file mode 100644 index 00000000..4ced35a6 --- /dev/null +++ b/crates/lean_compiler/src/lang/ast/program.rs @@ -0,0 +1,86 @@ +//! Program and function definitions. + +use std::collections::BTreeMap; +use std::fmt::{Display, Formatter}; + +use crate::lang::values::Var; + +use super::stmt::Line; + +/// A complete Lean program containing multiple functions. +#[derive(Debug, Clone)] +pub struct Program { + /// Collection of all functions in the program indexed by name. + pub functions: BTreeMap, +} + +/// A function definition with arguments, body, and metadata. +#[derive(Debug, Clone)] +pub struct Function { + /// Function name. + pub name: String, + /// Function arguments with const annotation. + pub arguments: Vec<(Var, bool)>, // (name, is_const) + /// Whether this function should be inlined during compilation. + pub inlined: bool, + /// Number of values returned by this function. + pub n_returned_vars: usize, + /// Function body as a sequence of statements. + pub body: Vec, +} + +impl Function { + /// Returns true if this function has any const arguments. 
+ pub fn has_const_arguments(&self) -> bool { + self.arguments.iter().any(|(_, is_const)| *is_const) + } +} + +impl Display for Program { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for function in self.functions.values() { + if !first { + writeln!(f)?; + } + write!(f, "{function}")?; + first = false; + } + Ok(()) + } +} + +impl Display for Function { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let args_str = self + .arguments + .iter() + .map(|arg| match arg { + (name, true) => format!("const {name}"), + (name, false) => name.to_string(), + }) + .collect::>() + .join(", "); + + let instructions_str = self + .body + .iter() + .map(|line| format!(" {line}")) + .collect::>() + .join("\n"); + + if self.body.is_empty() { + write!( + f, + "fn {}({}) -> {} {{}}", + self.name, args_str, self.n_returned_vars + ) + } else { + write!( + f, + "fn {}({}) -> {} {{\n{}\n}}", + self.name, args_str, self.n_returned_vars, instructions_str + ) + } + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/ast/stmt.rs b/crates/lean_compiler/src/lang/ast/stmt.rs new file mode 100644 index 00000000..8b30c84a --- /dev/null +++ b/crates/lean_compiler/src/lang/ast/stmt.rs @@ -0,0 +1,246 @@ +//! Statement types for the AST. + +use lean_vm::SourceLineNumber; +use std::fmt::{Display, Formatter}; + +use crate::lang::values::Var; +use crate::precompiles::Precompile; + +use super::{expr::{Expression, SimpleExpr}, types::Boolean}; + +/// A statement in the Lean language. 
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Line { + Match { + value: Expression, + arms: Vec<(usize, Vec)>, + }, + Assignment { + var: Var, + value: Expression, + }, + ArrayAssign { + array: SimpleExpr, + index: Expression, + value: Expression, + }, + Assert(Boolean), + IfCondition { + condition: Boolean, + then_branch: Vec, + else_branch: Vec, + }, + ForLoop { + iterator: Var, + start: Expression, + end: Expression, + body: Vec, + rev: bool, + unroll: bool, + }, + FunctionCall { + function_name: String, + args: Vec, + return_data: Vec, + }, + FunctionRet { + return_data: Vec, + }, + Precompile { + precompile: Precompile, + args: Vec, + }, + Break, + Panic, + Print { + line_info: String, + content: Vec, + }, + MAlloc { + var: Var, + size: Expression, + vectorized: bool, + vectorized_len: Expression, + }, + DecomposeBits { + var: Var, + to_decompose: Vec, + }, + CounterHint { + var: Var, + }, + LocationReport { + location: SourceLineNumber, + }, +} + +impl Line { + /// Converts the statement to a string with proper indentation. + fn to_string_with_indent(&self, indent: usize) -> String { + let spaces = " ".repeat(indent); + let line_str = match self { + Self::LocationReport { .. 
} => Default::default(), + Self::Match { value, arms } => { + let arms_str = arms + .iter() + .map(|(const_expr, body)| { + let body_str = body + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + format!("{const_expr} => {{\n{body_str}\n{spaces}}}") + }) + .collect::>() + .join("\n"); + format!("match {value} {{\n{arms_str}\n{spaces}}}") + } + Self::Assignment { var, value } => { + format!("{var} = {value}") + } + Self::ArrayAssign { + array, + index, + value, + } => { + format!("{array}[{index}] = {value}") + } + Self::Assert(condition) => format!("assert {condition}"), + Self::IfCondition { + condition, + then_branch, + else_branch, + } => { + let then_str = then_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + let else_str = else_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + if else_branch.is_empty() { + format!("if {condition} {{\n{then_str}\n{spaces}}}") + } else { + format!( + "if {condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" + ) + } + } + Self::CounterHint { var } => { + format!("{var} = counter_hint({var})") + } + Self::ForLoop { + iterator, + start, + end, + body, + rev, + unroll, + } => { + let body_str = body + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + format!( + "for {} in {}{}..{} {}{{\n{}\n{}}}", + iterator, + start, + if *rev { "rev " } else { "" }, + end, + if *unroll { "unroll " } else { "" }, + body_str, + spaces + ) + } + Self::FunctionCall { + function_name, + args, + return_data, + } => { + let args_str = args + .iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", "); + let return_data_str = return_data + .iter() + .map(|var| var.to_string()) + .collect::>() + .join(", "); + + if return_data.is_empty() { + format!("{function_name}({args_str})") + } else { + format!("{return_data_str} = 
{function_name}({args_str})") + } + } + Self::FunctionRet { return_data } => { + let return_data_str = return_data + .iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", "); + format!("return {return_data_str}") + } + Self::Precompile { precompile, args } => { + format!( + "{}({})", + precompile.name, + args.iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", ") + ) + } + Self::Print { + line_info: _, + content, + } => { + let content_str = content + .iter() + .map(|c| format!("{c}")) + .collect::>() + .join(", "); + format!("print({content_str})") + } + Self::MAlloc { + var, + size, + vectorized, + vectorized_len, + } => { + if *vectorized { + format!("{var} = malloc_vec({size}, {vectorized_len})") + } else { + format!("{var} = malloc({size})") + } + } + Self::DecomposeBits { var, to_decompose } => { + format!( + "{} = decompose_bits({})", + var, + to_decompose + .iter() + .map(|expr| expr.to_string()) + .collect::>() + .join(", ") + ) + } + Self::Break => "break".to_string(), + Self::Panic => "panic".to_string(), + }; + format!("{spaces}{line_str}") + } +} + +impl Display for Line { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_string_with_indent(0)) + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/ast/types.rs b/crates/lean_compiler/src/lang/ast/types.rs new file mode 100644 index 00000000..336643cf --- /dev/null +++ b/crates/lean_compiler/src/lang/ast/types.rs @@ -0,0 +1,27 @@ +//! Basic type definitions for the AST. + +use std::fmt::{Display, Formatter}; + +use super::expr::Expression; + +/// Boolean condition for assertions and control flow. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Boolean { + /// Equality comparison. + Equal { left: Expression, right: Expression }, + /// Inequality comparison. 
+ Different { left: Expression, right: Expression }, +} + +impl Display for Boolean { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Equal { left, right } => { + write!(f, "{left} == {right}") + } + Self::Different { left, right } => { + write!(f, "{left} != {right}") + } + } + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/mod.rs b/crates/lean_compiler/src/lang/mod.rs new file mode 100644 index 00000000..5164df60 --- /dev/null +++ b/crates/lean_compiler/src/lang/mod.rs @@ -0,0 +1,7 @@ +//! Language constructs and Abstract Syntax Tree for the Lean compiler. + +pub mod ast; +pub mod values; + +pub use ast::*; +pub use values::*; diff --git a/crates/lean_compiler/src/lang/values/constant.rs b/crates/lean_compiler/src/lang/values/constant.rs new file mode 100644 index 00000000..c9e94545 --- /dev/null +++ b/crates/lean_compiler/src/lang/values/constant.rs @@ -0,0 +1,181 @@ +use lean_vm::Label; +use p3_field::PrimeCharacteristicRing; +use p3_util::log2_ceil_usize; +use std::fmt::{Display, Formatter}; +use utils::ToUsize; + +use crate::{F, ir::HighLevelOperation}; + +/// Constant value types for compile-time computation. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ConstantValue { + Scalar(usize), + PublicInputStart, + PointerToZeroVector, + PointerToOneVector, + FunctionSize { function_name: Label }, + Label(Label), + MatchBlockSize { match_index: usize }, + MatchFirstBlockStart { match_index: usize }, +} + +/// Constant expression that can be evaluated at compile time. 
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ConstExpression { + Value(ConstantValue), + Binary { + left: Box, + operation: HighLevelOperation, + right: Box, + }, + Log2Ceil { + value: Box, + }, +} + +impl From for ConstExpression { + fn from(value: usize) -> Self { + Self::Value(ConstantValue::Scalar(value)) + } +} + +impl From for ConstExpression { + fn from(value: ConstantValue) -> Self { + Self::Value(value) + } +} + +impl ConstExpression { + /// Creates a zero constant. + pub const fn zero() -> Self { + Self::scalar(0) + } + + /// Creates a one constant. + pub const fn one() -> Self { + Self::scalar(1) + } + + /// Creates a label constant. + pub const fn label(label: Label) -> Self { + Self::Value(ConstantValue::Label(label)) + } + + /// Creates a scalar constant. + pub const fn scalar(scalar: usize) -> Self { + Self::Value(ConstantValue::Scalar(scalar)) + } + + /// Creates a function size constant. + pub const fn function_size(function_name: Label) -> Self { + Self::Value(ConstantValue::FunctionSize { function_name }) + } + + /// Evaluates the constant expression with a custom evaluation function. + pub fn eval_with(&self, func: &EvalFn) -> Option + where + EvalFn: Fn(&ConstantValue) -> Option, + { + match self { + Self::Value(value) => func(value), + Self::Binary { + left, + operation, + right, + } => Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)), + Self::Log2Ceil { value } => { + let value = value.eval_with(func)?; + Some(F::from_usize(log2_ceil_usize(value.to_usize()))) + } + } + } + + /// Evaluates the expression if it contains only scalar values. + pub fn naive_eval(&self) -> Option { + self.eval_with(&|value| match value { + ConstantValue::Scalar(scalar) => Some(F::from_usize(*scalar)), + _ => None, + }) + } + + /// Simplifies the expression by evaluating scalar subexpressions. 
+ pub fn try_naive_simplification(&self) -> Self { + if let Some(value) = self.naive_eval() { + Self::scalar(value.to_usize()) + } else { + self.clone() + } + } +} + +// For supporting conversion from Expression to ConstExpression +impl TryFrom for ConstExpression { + type Error = (); + + fn try_from(value: crate::lang::ast::Expression) -> Result { + use crate::lang::ast::{Expression, SimpleExpr}; + match value { + Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), + Expression::Value(_) => Err(()), + Expression::ArrayAccess { .. } => Err(()), + Expression::Binary { + left, + operation, + right, + } => { + let left_expr = Self::try_from(*left)?; + let right_expr = Self::try_from(*right)?; + Ok(Self::Binary { + left: Box::new(left_expr), + operation, + right: Box::new(right_expr), + }) + } + Expression::Log2Ceil { value } => { + let value_expr = Self::try_from(*value)?; + Ok(Self::Log2Ceil { + value: Box::new(value_expr), + }) + } + } + } +} + +impl Display for ConstantValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Scalar(scalar) => write!(f, "{scalar}"), + Self::PublicInputStart => write!(f, "@public_input_start"), + Self::PointerToZeroVector => write!(f, "@pointer_to_zero_vector"), + Self::PointerToOneVector => write!(f, "@pointer_to_one_vector"), + Self::FunctionSize { function_name } => { + write!(f, "@function_size_{function_name}") + } + Self::Label(label) => write!(f, "{label}"), + Self::MatchFirstBlockStart { match_index } => { + write!(f, "@match_first_block_start_{match_index}") + } + Self::MatchBlockSize { match_index } => { + write!(f, "@match_block_size_{match_index}") + } + } + } +} + +impl Display for ConstExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Value(value) => write!(f, "{value}"), + Self::Binary { + left, + operation, + right, + } => { + write!(f, "({left} {operation} {right})") + } + Self::Log2Ceil { value } => { + write!(f, 
"log2_ceil({value})") + } + } + } +} diff --git a/crates/lean_compiler/src/lang/values/mod.rs b/crates/lean_compiler/src/lang/values/mod.rs new file mode 100644 index 00000000..4d67617f --- /dev/null +++ b/crates/lean_compiler/src/lang/values/mod.rs @@ -0,0 +1,7 @@ +//! Value types and constants for the Lean compiler. + +pub mod constant; +pub mod variable; + +pub use constant::*; +pub use variable::*; \ No newline at end of file diff --git a/crates/lean_compiler/src/lang/values/variable.rs b/crates/lean_compiler/src/lang/values/variable.rs new file mode 100644 index 00000000..015797e9 --- /dev/null +++ b/crates/lean_compiler/src/lang/values/variable.rs @@ -0,0 +1,4 @@ +/// Variable identifier type. +pub type Var = String; +/// Constant malloc memory label. +pub type ConstMallocLabel = usize; From 379d72cac9e951f9571763291456e8ac0ba3edde Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 20:51:46 +0200 Subject: [PATCH 02/39] add tests for lang --- crates/lean_compiler/src/lang/ast/expr.rs | 145 ++++++++++++- crates/lean_compiler/src/lang/ast/program.rs | 123 ++++++++++- crates/lean_compiler/src/lang/ast/stmt.rs | 196 +++++++++++++++++- crates/lean_compiler/src/lang/ast/types.rs | 24 +++ .../lean_compiler/src/lang/values/constant.rs | 144 ++++++++++++- 5 files changed, 621 insertions(+), 11 deletions(-) diff --git a/crates/lean_compiler/src/lang/ast/expr.rs b/crates/lean_compiler/src/lang/ast/expr.rs index 2bd75789..dd46505d 100644 --- a/crates/lean_compiler/src/lang/ast/expr.rs +++ b/crates/lean_compiler/src/lang/ast/expr.rs @@ -4,8 +4,8 @@ use p3_field::PrimeCharacteristicRing; use p3_util::log2_ceil_usize; use std::fmt::{Display, Formatter}; +use crate::lang::values::{ConstExpression, ConstMallocLabel, ConstantValue, Var}; use crate::{F, ir::HighLevelOperation}; -use crate::lang::values::{ConstExpression, ConstantValue, Var, ConstMallocLabel}; use utils::ToUsize; /// Simple expression that can be a variable, constant, or memory access. 
@@ -96,9 +96,7 @@ pub enum Expression { right: Box, }, /// Ceiling of log base 2. - Log2Ceil { - value: Box, - }, + Log2Ceil { value: Box }, } impl From for Expression { @@ -193,4 +191,141 @@ impl Display for SimpleExpr { } } } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple_expr_constants() { + assert!(SimpleExpr::zero().is_constant()); + assert!(SimpleExpr::one().is_constant()); + assert!(SimpleExpr::scalar(42).is_constant()); + assert!(!SimpleExpr::Var("x".to_string()).is_constant()); + } + + #[test] + fn test_simple_expr_as_constant() { + let var = SimpleExpr::Var("x".to_string()); + let constant = SimpleExpr::scalar(5); + + assert_eq!(var.as_constant(), None); + assert_eq!(constant.as_constant(), Some(ConstExpression::scalar(5))); + } + + #[test] + fn test_expression_scalar_creation() { + let expr = Expression::scalar(10); + assert_eq!(expr.naive_eval().unwrap().to_usize(), 10); + } + + #[test] + fn test_simple_expr_display() { + assert_eq!(SimpleExpr::Var("x".to_string()).to_string(), "x"); + assert_eq!(SimpleExpr::scalar(42).to_string(), "42"); + + let malloc_access = SimpleExpr::ConstMallocAccess { + malloc_label: 5, + offset: ConstExpression::scalar(10), + }; + assert_eq!(malloc_access.to_string(), "malloc_access(5, 10)"); + } + + #[test] + fn test_expression_display() { + let var = Expression::Value(SimpleExpr::Var("x".to_string())); + assert_eq!(var.to_string(), "x"); + + let array_access = Expression::ArrayAccess { + array: SimpleExpr::Var("arr".to_string()), + index: Box::new(Expression::scalar(0)), + }; + assert_eq!(array_access.to_string(), "arr[0]"); + + let log2_ceil = Expression::Log2Ceil { + value: Box::new(Expression::scalar(8)), + }; + assert_eq!(log2_ceil.to_string(), "log2_ceil(8)"); + + let binary = Expression::Binary { + left: Box::new(Expression::scalar(5)), + operation: crate::ir::HighLevelOperation::Mul, + right: Box::new(Expression::scalar(2)), + }; + 
assert_eq!(binary.to_string(), "(5 * 2)"); + } + + #[test] + fn test_expression_eval_with() { + let value_fn = |expr: &SimpleExpr| match expr { + SimpleExpr::Var(name) if name == "x" => Some(F::from_usize(10)), + SimpleExpr::Constant(c) => c.naive_eval(), + _ => None, + }; + let array_fn = |array: &SimpleExpr, index: F| -> Option { + if matches!(array, SimpleExpr::Var(name) if name == "arr") { + Some(F::from_usize(index.to_usize() * 2)) + } else { + None + } + }; + + // Test Value variant + let var_expr = Expression::Value(SimpleExpr::Var("x".to_string())); + assert_eq!( + var_expr.eval_with(&value_fn, &array_fn).unwrap().to_usize(), + 10 + ); + + // Test ArrayAccess variant + let array_expr = Expression::ArrayAccess { + array: SimpleExpr::Var("arr".to_string()), + index: Box::new(Expression::scalar(5)), + }; + assert_eq!( + array_expr + .eval_with(&value_fn, &array_fn) + .unwrap() + .to_usize(), + 10 + ); + + // Test Binary variant + let binary_expr = Expression::Binary { + left: Box::new(Expression::scalar(3)), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(Expression::scalar(7)), + }; + assert_eq!( + binary_expr + .eval_with(&value_fn, &array_fn) + .unwrap() + .to_usize(), + 10 + ); + + // Test Log2Ceil variant + let log2_expr = Expression::Log2Ceil { + value: Box::new(Expression::scalar(8)), + }; + assert_eq!( + log2_expr + .eval_with(&value_fn, &array_fn) + .unwrap() + .to_usize(), + 3 + ); + } + + #[test] + fn test_simple_expr_simplify_if_const() { + let var = SimpleExpr::Var("x".to_string()); + let simplified_var = var.simplify_if_const(); + assert_eq!(simplified_var, var); + + let constant = SimpleExpr::scalar(42); + let simplified_constant = constant.simplify_if_const(); + assert_eq!(simplified_constant, SimpleExpr::scalar(42)); + } +} diff --git a/crates/lean_compiler/src/lang/ast/program.rs b/crates/lean_compiler/src/lang/ast/program.rs index 4ced35a6..e073d92d 100644 --- a/crates/lean_compiler/src/lang/ast/program.rs +++ 
b/crates/lean_compiler/src/lang/ast/program.rs @@ -83,4 +83,125 @@ impl Display for Function { ) } } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{Expression, SimpleExpr}; + + #[test] + fn test_function_has_const_arguments() { + let func_with_const = Function { + name: "test".to_string(), + arguments: vec![("x".to_string(), false), ("y".to_string(), true)], + inlined: false, + n_returned_vars: 1, + body: vec![], + }; + assert!(func_with_const.has_const_arguments()); + + let func_no_const = Function { + name: "test".to_string(), + arguments: vec![("x".to_string(), false), ("y".to_string(), false)], + inlined: false, + n_returned_vars: 1, + body: vec![], + }; + assert!(!func_no_const.has_const_arguments()); + } + + #[test] + fn test_function_display_empty_body() { + let func = Function { + name: "empty_fn".to_string(), + arguments: vec![("x".to_string(), false)], + inlined: false, + n_returned_vars: 0, + body: vec![], + }; + assert_eq!(func.to_string(), "fn empty_fn(x) -> 0 {}"); + } + + #[test] + fn test_function_display_with_const_args() { + let func = Function { + name: "const_fn".to_string(), + arguments: vec![("x".to_string(), true), ("y".to_string(), false)], + inlined: false, + n_returned_vars: 1, + body: vec![], + }; + assert_eq!(func.to_string(), "fn const_fn(const x, y) -> 1 {}"); + } + + #[test] + fn test_program_display() { + let mut program = Program { + functions: BTreeMap::new(), + }; + + let func = Function { + name: "test".to_string(), + arguments: vec![], + inlined: false, + n_returned_vars: 0, + body: vec![], + }; + + program.functions.insert("test".to_string(), func); + assert_eq!(program.to_string(), "fn test() -> 0 {}"); + } + + #[test] + fn test_program_multiple_functions_display() { + let mut program = Program { + functions: BTreeMap::new(), + }; + + let func1 = Function { + name: "func1".to_string(), + arguments: vec![], + inlined: false, + n_returned_vars: 0, + body: vec![], + }; + + let 
func2 = Function { + name: "func2".to_string(), + arguments: vec![], + inlined: false, + n_returned_vars: 1, + body: vec![], + }; + + program.functions.insert("func1".to_string(), func1); + program.functions.insert("func2".to_string(), func2); + + let result = program.to_string(); + assert_eq!(result, "fn func1() -> 0 {}\nfn func2() -> 1 {}"); + } + + #[test] + fn test_function_display_with_body() { + let func = Function { + name: "test_func".to_string(), + arguments: vec![("x".to_string(), false)], + inlined: false, + n_returned_vars: 1, + body: vec![ + Line::Assignment { + var: "result".to_string(), + value: Expression::scalar(42), + }, + Line::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Var("result".to_string()))], + }, + ], + }; + assert_eq!( + func.to_string(), + "fn test_func(x) -> 1 {\n result = 42\n return result\n}" + ); + } +} diff --git a/crates/lean_compiler/src/lang/ast/stmt.rs b/crates/lean_compiler/src/lang/ast/stmt.rs index 8b30c84a..3203560c 100644 --- a/crates/lean_compiler/src/lang/ast/stmt.rs +++ b/crates/lean_compiler/src/lang/ast/stmt.rs @@ -6,7 +6,10 @@ use std::fmt::{Display, Formatter}; use crate::lang::values::Var; use crate::precompiles::Precompile; -use super::{expr::{Expression, SimpleExpr}, types::Boolean}; +use super::{ + expr::{Expression, SimpleExpr}, + types::Boolean, +}; /// A statement in the Lean language. 
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -148,11 +151,11 @@ impl Line { .collect::>() .join("\n"); format!( - "for {} in {}{}..{} {}{{\n{}\n{}}}", + "for {} in {}..{} {}{}{{\n{}\n{}}}", iterator, start, - if *rev { "rev " } else { "" }, end, + if *rev { "rev " } else { "" }, if *unroll { "unroll " } else { "" }, body_str, spaces @@ -243,4 +246,189 @@ impl Display for Line { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.to_string_with_indent(0)) } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_line_assignment_display() { + let assignment = Line::Assignment { + var: "x".to_string(), + value: Expression::scalar(42), + }; + assert_eq!(assignment.to_string(), "x = 42"); + } + + #[test] + fn test_line_array_assign_display() { + let array_assign = Line::ArrayAssign { + array: SimpleExpr::Var("arr".to_string()), + index: Expression::scalar(0), + value: Expression::scalar(10), + }; + assert_eq!(array_assign.to_string(), "arr[0] = 10"); + } + + #[test] + fn test_line_break_panic_display() { + assert_eq!(Line::Break.to_string(), "break"); + assert_eq!(Line::Panic.to_string(), "panic"); + } + + #[test] + fn test_line_malloc_display() { + let malloc = Line::MAlloc { + var: "ptr".to_string(), + size: Expression::scalar(100), + vectorized: false, + vectorized_len: Expression::scalar(1), + }; + assert_eq!(malloc.to_string(), "ptr = malloc(100)"); + + let malloc_vec = Line::MAlloc { + var: "ptr".to_string(), + size: Expression::scalar(100), + vectorized: true, + vectorized_len: Expression::scalar(8), + }; + assert_eq!(malloc_vec.to_string(), "ptr = malloc_vec(100, 8)"); + } + + #[test] + fn test_line_function_call_display() { + let call = Line::FunctionCall { + function_name: "test_fn".to_string(), + args: vec![Expression::scalar(1), Expression::scalar(2)], + return_data: vec!["result".to_string()], + }; + assert_eq!(call.to_string(), "result = test_fn(1, 2)"); + + let 
call_no_return = Line::FunctionCall { + function_name: "void_fn".to_string(), + args: vec![Expression::scalar(42)], + return_data: vec![], + }; + assert_eq!(call_no_return.to_string(), "void_fn(42)"); + } + + #[test] + fn test_line_return_display() { + let ret = Line::FunctionRet { + return_data: vec![Expression::scalar(1), Expression::scalar(2)], + }; + assert_eq!(ret.to_string(), "return 1, 2"); + } + + #[test] + fn test_line_assert_display() { + let assert_stmt = Line::Assert(Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("x".to_string())), + right: Expression::scalar(10), + }); + assert_eq!(assert_stmt.to_string(), "assert x == 10"); + } + + #[test] + fn test_line_counter_hint_display() { + let hint = Line::CounterHint { + var: "counter".to_string(), + }; + assert_eq!(hint.to_string(), "counter = counter_hint(counter)"); + } + + #[test] + fn test_line_print_display() { + let print = Line::Print { + line_info: "debug".to_string(), + content: vec![ + Expression::scalar(42), + Expression::Value(SimpleExpr::Var("x".to_string())), + ], + }; + assert_eq!(print.to_string(), "print(42, x)"); + } + + #[test] + fn test_line_decompose_bits_display() { + let decompose = Line::DecomposeBits { + var: "bits".to_string(), + to_decompose: vec![ + Expression::scalar(255), + Expression::Value(SimpleExpr::Var("y".to_string())), + ], + }; + assert_eq!(decompose.to_string(), "bits = decompose_bits(255, y)"); + } + + #[test] + fn test_line_for_loop_display() { + let for_loop = Line::ForLoop { + iterator: "i".to_string(), + start: Expression::scalar(0), + end: Expression::scalar(10), + body: vec![Line::Break], + rev: false, + unroll: false, + }; + assert_eq!(for_loop.to_string(), "for i in 0..10 {\n break\n}"); + + let for_loop_rev_unroll = Line::ForLoop { + iterator: "i".to_string(), + start: Expression::scalar(0), + end: Expression::scalar(5), + body: vec![], + rev: true, + unroll: true, + }; + assert_eq!( + for_loop_rev_unroll.to_string(), + "for i in 0..5 rev unroll 
{\n\n}" + ); + } + + #[test] + fn test_line_if_condition_display() { + let if_simple = Line::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("x".to_string())), + right: Expression::scalar(0), + }, + then_branch: vec![Line::Panic], + else_branch: vec![], + }; + assert_eq!(if_simple.to_string(), "if x == 0 {\n panic\n}"); + + let if_else = Line::IfCondition { + condition: Boolean::Different { + left: Expression::scalar(1), + right: Expression::scalar(2), + }, + then_branch: vec![Line::Break], + else_branch: vec![Line::Panic], + }; + assert_eq!( + if_else.to_string(), + "if 1 != 2 {\n break\n} else {\n panic\n}" + ); + } + + #[test] + fn test_line_match_display() { + let match_stmt = Line::Match { + value: Expression::Value(SimpleExpr::Var("x".to_string())), + arms: vec![(0, vec![Line::Break]), (1, vec![Line::Panic])], + }; + assert_eq!( + match_stmt.to_string(), + "match x {\n0 => {\n break\n}\n1 => {\n panic\n}\n}" + ); + } + + #[test] + fn test_line_location_report_display() { + let location = Line::LocationReport { location: 42 }; + assert_eq!(location.to_string(), ""); + } +} diff --git a/crates/lean_compiler/src/lang/ast/types.rs b/crates/lean_compiler/src/lang/ast/types.rs index 336643cf..d7dff68e 100644 --- a/crates/lean_compiler/src/lang/ast/types.rs +++ b/crates/lean_compiler/src/lang/ast/types.rs @@ -24,4 +24,28 @@ impl Display for Boolean { } } } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::SimpleExpr; + + #[test] + fn test_boolean_equal_display() { + let equal = Boolean::Equal { + left: Expression::scalar(5), + right: Expression::scalar(10), + }; + assert_eq!(equal.to_string(), "5 == 10"); + } + + #[test] + fn test_boolean_different_display() { + let different = Boolean::Different { + left: Expression::Value(SimpleExpr::Var("x".to_string())), + right: Expression::scalar(0), + }; + assert_eq!(different.to_string(), "x != 0"); + } } \ No newline at end of file diff --git 
a/crates/lean_compiler/src/lang/values/constant.rs b/crates/lean_compiler/src/lang/values/constant.rs index c9e94545..74f64a07 100644 --- a/crates/lean_compiler/src/lang/values/constant.rs +++ b/crates/lean_compiler/src/lang/values/constant.rs @@ -108,7 +108,6 @@ impl ConstExpression { } } -// For supporting conversion from Expression to ConstExpression impl TryFrom for ConstExpression { type Error = (); @@ -179,3 +178,146 @@ impl Display for ConstExpression { } } } + +#[cfg(test)] +mod tests { + use super::*; + use lean_vm::Label; + + #[test] + fn test_const_expression_constructors() { + let zero = ConstExpression::zero(); + let one = ConstExpression::one(); + let scalar = ConstExpression::scalar(42); + + assert_eq!(zero.naive_eval().unwrap().to_usize(), 0); + assert_eq!(one.naive_eval().unwrap().to_usize(), 1); + assert_eq!(scalar.naive_eval().unwrap().to_usize(), 42); + } + + #[test] + fn test_const_expression_simplification() { + let expr = ConstExpression::scalar(10); + let simplified = expr.try_naive_simplification(); + assert_eq!(simplified.naive_eval().unwrap().to_usize(), 10); + } + + #[test] + fn test_constant_value_display() { + assert_eq!(ConstantValue::Scalar(42).to_string(), "42"); + assert_eq!( + ConstantValue::PublicInputStart.to_string(), + "@public_input_start" + ); + assert_eq!( + ConstantValue::PointerToZeroVector.to_string(), + "@pointer_to_zero_vector" + ); + assert_eq!( + ConstantValue::PointerToOneVector.to_string(), + "@pointer_to_one_vector" + ); + assert_eq!( + ConstantValue::MatchBlockSize { match_index: 5 }.to_string(), + "@match_block_size_5" + ); + assert_eq!( + ConstantValue::MatchFirstBlockStart { match_index: 3 }.to_string(), + "@match_first_block_start_3" + ); + + let label = Label::function("test"); + assert_eq!( + ConstantValue::FunctionSize { function_name: label.clone() }.to_string(), + format!("@function_size_{}", label) + ); + } + + #[test] + fn test_const_expression_display() { + let value = 
ConstExpression::Value(ConstantValue::Scalar(42)); + assert_eq!(value.to_string(), "42"); + + let log2_expr = ConstExpression::Log2Ceil { + value: Box::new(ConstExpression::scalar(8)), + }; + assert_eq!(log2_expr.to_string(), "log2_ceil(8)"); + + let binary_expr = ConstExpression::Binary { + left: Box::new(ConstExpression::scalar(5)), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(ConstExpression::scalar(3)), + }; + assert_eq!(binary_expr.to_string(), "(5 + 3)"); + } + + #[test] + fn test_const_expression_from_usize() { + let expr = ConstExpression::from(100); + assert_eq!(expr.naive_eval().unwrap().to_usize(), 100); + } + + #[test] + fn test_constant_value_label() { + let label = Label::function("test_func"); + let const_val = ConstantValue::Label(label.clone()); + assert_eq!(const_val.to_string(), label.to_string()); + } + + #[test] + fn test_const_expression_eval_with() { + let eval_fn = |value: &ConstantValue| match value { + ConstantValue::Scalar(n) => Some(F::from_usize(*n)), + ConstantValue::PublicInputStart => Some(F::from_usize(1000)), + _ => None, + }; + + // Test Value variant + let scalar = ConstExpression::Value(ConstantValue::Scalar(42)); + assert_eq!(scalar.eval_with(&eval_fn).unwrap().to_usize(), 42); + + // Test Binary variant + let binary = ConstExpression::Binary { + left: Box::new(ConstExpression::scalar(5)), + operation: crate::ir::HighLevelOperation::Mul, + right: Box::new(ConstExpression::scalar(3)), + }; + assert_eq!(binary.eval_with(&eval_fn).unwrap().to_usize(), 15); + + // Test Log2Ceil variant + let log2 = ConstExpression::Log2Ceil { + value: Box::new(ConstExpression::scalar(16)), + }; + assert_eq!(log2.eval_with(&eval_fn).unwrap().to_usize(), 4); + } + + #[test] + fn test_const_expression_function_size() { + let label = Label::function("test_fn"); + let func_size = ConstExpression::function_size(label.clone()); + assert_eq!(func_size, ConstExpression::Value(ConstantValue::FunctionSize { function_name: label })); + } 
+ + #[test] + fn test_const_expression_label() { + let label = Label::function("main"); + let label_expr = ConstExpression::label(label.clone()); + assert_eq!(label_expr, ConstExpression::Value(ConstantValue::Label(label))); + } + + #[test] + fn test_const_expression_try_from() { + use crate::lang::ast::{Expression, SimpleExpr}; + + // Test successful conversion + let const_expr = ConstExpression::scalar(10); + let expr = Expression::Value(SimpleExpr::Constant(const_expr.clone())); + let result = ConstExpression::try_from(expr); + assert_eq!(result, Ok(const_expr)); + + // Test failed conversion + let var_expr = Expression::Value(SimpleExpr::Var("x".to_string())); + let result = ConstExpression::try_from(var_expr); + assert_eq!(result, Err(())); + } +} From 345c2a7b18a3062277963b52ac1d984b11dec33a Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 20:54:10 +0200 Subject: [PATCH 03/39] one more test and fmt --- crates/lean_compiler/src/lang/ast/mod.rs | 6 +++--- crates/lean_compiler/src/lang/ast/stmt.rs | 9 +++++++++ crates/lean_compiler/src/lang/ast/types.rs | 2 +- .../lean_compiler/src/lang/values/constant.rs | 17 ++++++++++++++--- crates/lean_compiler/src/lang/values/mod.rs | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/crates/lean_compiler/src/lang/ast/mod.rs b/crates/lean_compiler/src/lang/ast/mod.rs index d48ffe3f..b9992d21 100644 --- a/crates/lean_compiler/src/lang/ast/mod.rs +++ b/crates/lean_compiler/src/lang/ast/mod.rs @@ -1,11 +1,11 @@ //! Abstract Syntax Tree definitions for Lean language constructs. 
-pub mod program; pub mod expr; +pub mod program; pub mod stmt; pub mod types; -pub use program::*; pub use expr::*; +pub use program::*; pub use stmt::*; -pub use types::*; \ No newline at end of file +pub use types::*; diff --git a/crates/lean_compiler/src/lang/ast/stmt.rs b/crates/lean_compiler/src/lang/ast/stmt.rs index 3203560c..b1cb3fa9 100644 --- a/crates/lean_compiler/src/lang/ast/stmt.rs +++ b/crates/lean_compiler/src/lang/ast/stmt.rs @@ -431,4 +431,13 @@ mod tests { let location = Line::LocationReport { location: 42 }; assert_eq!(location.to_string(), ""); } + + #[test] + fn test_line_precompile_display() { + let precompile_line = Line::Precompile { + precompile: crate::precompiles::POSEIDON_16, + args: vec![Expression::scalar(1), Expression::scalar(2)], + }; + assert_eq!(precompile_line.to_string(), "poseidon16(1, 2)"); + } } diff --git a/crates/lean_compiler/src/lang/ast/types.rs b/crates/lean_compiler/src/lang/ast/types.rs index d7dff68e..c795e9ac 100644 --- a/crates/lean_compiler/src/lang/ast/types.rs +++ b/crates/lean_compiler/src/lang/ast/types.rs @@ -48,4 +48,4 @@ mod tests { }; assert_eq!(different.to_string(), "x != 0"); } -} \ No newline at end of file +} diff --git a/crates/lean_compiler/src/lang/values/constant.rs b/crates/lean_compiler/src/lang/values/constant.rs index 74f64a07..0cdb11d3 100644 --- a/crates/lean_compiler/src/lang/values/constant.rs +++ b/crates/lean_compiler/src/lang/values/constant.rs @@ -228,7 +228,10 @@ mod tests { let label = Label::function("test"); assert_eq!( - ConstantValue::FunctionSize { function_name: label.clone() }.to_string(), + ConstantValue::FunctionSize { + function_name: label.clone() + } + .to_string(), format!("@function_size_{}", label) ); } @@ -295,14 +298,22 @@ mod tests { fn test_const_expression_function_size() { let label = Label::function("test_fn"); let func_size = ConstExpression::function_size(label.clone()); - assert_eq!(func_size, ConstExpression::Value(ConstantValue::FunctionSize { 
function_name: label })); + assert_eq!( + func_size, + ConstExpression::Value(ConstantValue::FunctionSize { + function_name: label + }) + ); } #[test] fn test_const_expression_label() { let label = Label::function("main"); let label_expr = ConstExpression::label(label.clone()); - assert_eq!(label_expr, ConstExpression::Value(ConstantValue::Label(label))); + assert_eq!( + label_expr, + ConstExpression::Value(ConstantValue::Label(label)) + ); } #[test] diff --git a/crates/lean_compiler/src/lang/values/mod.rs b/crates/lean_compiler/src/lang/values/mod.rs index 4d67617f..a607ae06 100644 --- a/crates/lean_compiler/src/lang/values/mod.rs +++ b/crates/lean_compiler/src/lang/values/mod.rs @@ -4,4 +4,4 @@ pub mod constant; pub mod variable; pub use constant::*; -pub use variable::*; \ No newline at end of file +pub use variable::*; From 86c574f9e1caa9165ed8748c9d4e0730fa7801b4 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 21:01:37 +0200 Subject: [PATCH 04/39] add tests for parser --- crates/lean_compiler/src/parser/error.rs | 60 ++++++++ crates/lean_compiler/src/parser/grammar.rs | 49 +++++++ crates/lean_compiler/src/parser/mod.rs | 56 ++++++++ .../src/parser/parsers/literal.rs | 128 ++++++++++++++++++ .../lean_compiler/src/parser/parsers/mod.rs | 118 ++++++++++++++++ 5 files changed, 411 insertions(+) diff --git a/crates/lean_compiler/src/parser/error.rs b/crates/lean_compiler/src/parser/error.rs index c4bc77c9..1b51f603 100644 --- a/crates/lean_compiler/src/parser/error.rs +++ b/crates/lean_compiler/src/parser/error.rs @@ -80,3 +80,63 @@ impl std::error::Error for ParseError {} /// Result type for parsing operations. 
pub type ParseResult = Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_semantic_error_new() { + let error = SemanticError::new("test message"); + assert_eq!(error.message, "test message"); + assert_eq!(error.context, None); + } + + #[test] + fn test_semantic_error_with_context() { + let error = SemanticError::with_context("test message", "test context"); + assert_eq!(error.message, "test message"); + assert_eq!(error.context, Some("test context".to_string())); + } + + #[test] + fn test_semantic_error_display() { + let error = SemanticError::new("test message"); + assert_eq!(error.to_string(), "test message"); + + let error_with_context = SemanticError::with_context("test message", "test context"); + assert_eq!( + error_with_context.to_string(), + "test message (in test context)" + ); + } + + #[test] + fn test_parse_error_from_string() { + let error: ParseError = "test message".to_string().into(); + if let ParseError::SemanticError(semantic_error) = error { + assert_eq!(semantic_error.message, "test message"); + assert_eq!(semantic_error.context, None); + } else { + panic!("Expected SemanticError"); + } + } + + #[test] + fn test_parse_error_from_semantic_error() { + let semantic_error = SemanticError::new("test"); + let parse_error: ParseError = semantic_error.into(); + if let ParseError::SemanticError(se) = parse_error { + assert_eq!(se.message, "test"); + } else { + panic!("Expected SemanticError variant"); + } + } + + #[test] + fn test_parse_error_display() { + let semantic_error = SemanticError::new("semantic test"); + let parse_error = ParseError::SemanticError(semantic_error); + assert_eq!(parse_error.to_string(), "Semantic error: semantic test"); + } +} diff --git a/crates/lean_compiler/src/parser/grammar.rs b/crates/lean_compiler/src/parser/grammar.rs index 90908c32..2e969dc6 100644 --- a/crates/lean_compiler/src/parser/grammar.rs +++ b/crates/lean_compiler/src/parser/grammar.rs @@ -34,3 +34,52 @@ pub fn parse_source(input: &str) 
-> Result, Box Result<(Program, BTreeMap), let mut ctx = ParseContext::new(); ProgramParser::parse(program_pair, &mut ctx) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_program_simple() { + let input = "fn main() {}"; + let result = parse_program(input); + assert!(result.is_ok()); + + if let Ok((program, locations)) = result { + assert!(program.functions.contains_key("main")); + assert!(!locations.is_empty()); + } + } + + #[test] + fn test_parse_program_with_comments() { + let input = r#" + // This is a comment + fn main() { + // Another comment + } + "#; + let result = parse_program(input); + assert!(result.is_ok()); + + if let Ok((program, _)) = result { + assert!(program.functions.contains_key("main")); + } + } + + #[test] + fn test_parse_program_invalid_syntax() { + let input = "invalid syntax $%@"; + let result = parse_program(input); + assert!(result.is_err()); + } + + #[test] + fn test_parse_program_multiple_functions() { + let input = r#" + fn first() {} + fn second() {} + "#; + let result = parse_program(input); + assert!(result.is_ok()); + + if let Ok((program, locations)) = result { + assert!(program.functions.contains_key("first")); + assert!(program.functions.contains_key("second")); + assert_eq!(locations.len(), 2); + } + } +} diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 7b54a8d2..210c38fc 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -153,3 +153,131 @@ impl Parse> for VarListParser { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{ConstExpression, ConstantValue}; + use crate::parser::grammar::{LangParser, Rule}; + use pest::Parser; + + #[test] + fn test_var_or_constant_parser_identifier() { + let mut ctx = ParseContext::new(); + let input = "test_var"; + let mut pairs = LangParser::parse(Rule::identifier, input).unwrap(); + let pair = 
pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Var(name) = result { + assert_eq!(name, "test_var"); + } else { + panic!("Expected variable"); + } + } + + #[test] + fn test_var_or_constant_parser_numeric_literal() { + let mut ctx = ParseContext::new(); + let input = "42"; + let mut pairs = LangParser::parse(Rule::constant_value, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(value))) = result { + assert_eq!(value, 42); + } else { + panic!("Expected scalar constant"); + } + } + + #[test] + fn test_var_or_constant_parser_defined_constant() { + let mut ctx = ParseContext::new(); + ctx.add_constant("MY_CONST".to_string(), 100).unwrap(); + + let input = "MY_CONST"; + let mut pairs = LangParser::parse(Rule::identifier, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(value))) = result { + assert_eq!(value, 100); + } else { + panic!("Expected scalar constant"); + } + } + + #[test] + fn test_var_or_constant_parser_public_input_start() { + let mut ctx = ParseContext::new(); + let input = "public_input_start"; + let mut pairs = LangParser::parse(Rule::constant_value, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Constant(ConstExpression::Value(ConstantValue::PublicInputStart)) = + result + { + // Success + } else { + panic!("Expected PublicInputStart constant"); + } + } + + #[test] + fn test_var_or_constant_parser_pointer_to_zero_vector() { + let mut ctx = ParseContext::new(); + let input = "pointer_to_zero_vector"; + let mut pairs = LangParser::parse(Rule::identifier, input).unwrap(); + let pair = 
pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Constant(ConstExpression::Value(ConstantValue::PointerToZeroVector)) = + result + { + // Success + } else { + panic!("Expected PointerToZeroVector constant"); + } + } + + #[test] + fn test_var_or_constant_parser_pointer_to_one_vector() { + let mut ctx = ParseContext::new(); + let input = "pointer_to_one_vector"; + let mut pairs = LangParser::parse(Rule::identifier, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx).unwrap(); + if let SimpleExpr::Constant(ConstExpression::Value(ConstantValue::PointerToOneVector)) = + result + { + // Success + } else { + panic!("Expected PointerToOneVector constant"); + } + } + + #[test] + fn test_const_expr_parser_numeric() { + let mut ctx = ParseContext::new(); + let input = "123"; + let mut pairs = LangParser::parse(Rule::pattern, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ConstExprParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, 123); + } + + #[test] + fn test_const_expr_parser_public_input_start_error() { + let mut ctx = ParseContext::new(); + let input = "public_input_start"; + let mut pairs = LangParser::parse(Rule::pattern, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ConstExprParser::parse(pair, &mut ctx); + assert!(result.is_err()); + } +} diff --git a/crates/lean_compiler/src/parser/parsers/mod.rs b/crates/lean_compiler/src/parser/parsers/mod.rs index 956d860b..6d5ca03f 100644 --- a/crates/lean_compiler/src/parser/parsers/mod.rs +++ b/crates/lean_compiler/src/parser/parsers/mod.rs @@ -86,3 +86,121 @@ pub fn next_inner_pair<'i>( .next() .ok_or_else(|| SemanticError::with_context("Unexpected end of input", context).into()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::grammar::Rule; + + #[test] + fn test_parse_context_new() { + let ctx = ParseContext::new(); + 
assert!(ctx.constants.is_empty()); + assert_eq!(ctx.trash_var_count, 0); + } + + #[test] + fn test_parse_context_default() { + let ctx = ParseContext::default(); + assert!(ctx.constants.is_empty()); + assert_eq!(ctx.trash_var_count, 0); + } + + #[test] + fn test_add_constant_success() { + let mut ctx = ParseContext::new(); + let result = ctx.add_constant("test".to_string(), 42); + assert!(result.is_ok()); + assert_eq!(ctx.get_constant("test"), Some(42)); + } + + #[test] + fn test_add_constant_duplicate() { + let mut ctx = ParseContext::new(); + ctx.add_constant("test".to_string(), 42).unwrap(); + let result = ctx.add_constant("test".to_string(), 24); + assert!(result.is_err()); + if let Err(error) = result { + assert!(error.message.contains("Multiply defined constant")); + } + } + + #[test] + fn test_get_constant_exists() { + let mut ctx = ParseContext::new(); + ctx.add_constant("test".to_string(), 42).unwrap(); + assert_eq!(ctx.get_constant("test"), Some(42)); + } + + #[test] + fn test_get_constant_not_exists() { + let ctx = ParseContext::new(); + assert_eq!(ctx.get_constant("missing"), None); + } + + #[test] + fn test_next_trash_var() { + let mut ctx = ParseContext::new(); + let first = ctx.next_trash_var(); + let second = ctx.next_trash_var(); + + assert_eq!(first, "@trash_1"); + assert_eq!(second, "@trash_2"); + assert_eq!(ctx.trash_var_count, 2); + } + + #[test] + fn test_expect_rule_success() { + use crate::parser::grammar::LangParser; + use pest::Parser; + + let input = "fn main() {}"; + let mut pairs = LangParser::parse(Rule::program, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = expect_rule(&pair, Rule::program); + assert!(result.is_ok()); + } + + #[test] + fn test_expect_rule_failure() { + use crate::parser::grammar::LangParser; + use pest::Parser; + + let input = "fn main() {}"; + let mut pairs = LangParser::parse(Rule::program, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = expect_rule(&pair, 
Rule::function); + assert!(result.is_err()); + } + + #[test] + fn test_next_inner_pair_success() { + use crate::parser::grammar::LangParser; + use pest::Parser; + + let input = "fn main() {}"; + let mut pairs = LangParser::parse(Rule::program, input).unwrap(); + let pair = pairs.next().unwrap(); + let mut inner = pair.into_inner(); + + let result = next_inner_pair(&mut inner, "test context"); + assert!(result.is_ok()); + } + + #[test] + fn test_next_inner_pair_failure() { + let mut empty_iter = std::iter::empty(); + let result = next_inner_pair(&mut empty_iter, "test context"); + assert!(result.is_err()); + if let Err(error) = result { + if let crate::parser::error::ParseError::SemanticError(se) = error { + assert_eq!(se.message, "Unexpected end of input"); + assert_eq!(se.context, Some("test context".to_string())); + } else { + panic!("Expected SemanticError"); + } + } + } +} From e6e0c402086e7f7188e13791d93379663cdbc3e2 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 21:07:51 +0200 Subject: [PATCH 05/39] small ctx fix --- crates/lean_compiler/src/parser/parsers/program.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/lean_compiler/src/parser/parsers/program.rs b/crates/lean_compiler/src/parser/parsers/program.rs index 5e23e15e..606c41ee 100644 --- a/crates/lean_compiler/src/parser/parsers/program.rs +++ b/crates/lean_compiler/src/parser/parsers/program.rs @@ -16,21 +16,20 @@ pub struct ProgramParser; impl Parse<(Program, BTreeMap)> for ProgramParser { fn parse( pair: ParsePair<'_>, - _ctx: &mut ParseContext, + ctx: &mut ParseContext, ) -> ParseResult<(Program, BTreeMap)> { - let mut ctx = ParseContext::new(); let mut functions = BTreeMap::new(); let mut function_locations = BTreeMap::new(); for item in pair.into_inner() { match item.as_rule() { Rule::constant_declaration => { - let (name, value) = ConstantDeclarationParser::parse(item, &mut ctx)?; + let (name, value) = ConstantDeclarationParser::parse(item, 
ctx)?; ctx.add_constant(name, value)?; } Rule::function => { let location = item.line_col().0; - let function = FunctionParser::parse(item, &mut ctx)?; + let function = FunctionParser::parse(item, ctx)?; let name = function.name.clone(); function_locations.insert(location, name.clone()); From f09074763575379a044d3077dff75a4a5621c273 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 21:15:34 +0200 Subject: [PATCH 06/39] more tests and small bug fixes --- crates/lean_compiler/src/parser/grammar.rs | 23 ++- .../src/parser/parsers/function.rs | 135 +++++++++++++++++- .../src/parser/parsers/literal.rs | 30 +++- .../src/parser/parsers/statement.rs | 121 +++++++++++++++- 4 files changed, 300 insertions(+), 9 deletions(-) diff --git a/crates/lean_compiler/src/parser/grammar.rs b/crates/lean_compiler/src/parser/grammar.rs index 2e969dc6..5829e12a 100644 --- a/crates/lean_compiler/src/parser/grammar.rs +++ b/crates/lean_compiler/src/parser/grammar.rs @@ -32,7 +32,14 @@ pub fn next_inner<'i>( /// Utility function to parse the main program structure. 
pub fn parse_source(input: &str) -> Result, Box>> { let mut pairs = LangParser::parse(Rule::program, input)?; - Ok(pairs.next().unwrap()) + pairs.next().ok_or_else(|| { + Box::new(pest::error::Error::new_from_pos( + pest::error::ErrorVariant::CustomError { + message: "No program found in input".to_string(), + }, + pest::Position::from_start(input), + )) + }) } #[cfg(test)] @@ -82,4 +89,18 @@ mod tests { let result = parse_source(input); assert!(result.is_err()); } + + #[test] + fn test_parse_source_empty_input() { + let input = ""; + let result = parse_source(input); + assert!(result.is_err()); + } + + #[test] + fn test_parse_source_whitespace_only() { + let input = " \n\t "; + let result = parse_source(input); + assert!(result.is_err()); + } } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 40e126cc..5881bd71 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -171,12 +171,18 @@ impl FunctionCallParser { }) } "malloc_vec" => { + if return_data.len() != 1 { + return Err( + SemanticError::new("malloc_vec must return exactly one value").into(), + ); + } + let vectorized_len = if args.len() == 1 { Expression::scalar(LOG_VECTOR_LEN) } else if args.len() == 2 { args[1].clone() } else { - return Err(SemanticError::new("Invalid malloc_vec call").into()); + return Err(SemanticError::new("malloc_vec takes 1 or 2 arguments").into()); }; Ok(Line::MAlloc { @@ -257,3 +263,130 @@ impl Parse> for TupleExpressionParser { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{Expression, Line}; + + #[test] + fn test_malloc_vec_no_return_data() { + let args = vec![Expression::scalar(100)]; + let return_data = vec![]; + + let result = FunctionCallParser::handle_builtin_function( + "malloc_vec".to_string(), + args, + return_data, + ); + + assert!(result.is_err()); + if let 
Err(crate::parser::error::ParseError::SemanticError(error)) = result { + assert!( + error + .message + .contains("malloc_vec must return exactly one value") + ); + } else { + panic!("Expected SemanticError"); + } + } + + #[test] + fn test_malloc_vec_too_many_return_values() { + let args = vec![Expression::scalar(100)]; + let return_data = vec!["ptr1".to_string(), "ptr2".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "malloc_vec".to_string(), + args, + return_data, + ); + + assert!(result.is_err()); + if let Err(crate::parser::error::ParseError::SemanticError(error)) = result { + assert!( + error + .message + .contains("malloc_vec must return exactly one value") + ); + } else { + panic!("Expected SemanticError"); + } + } + + #[test] + fn test_malloc_vec_valid_one_arg() { + let args = vec![Expression::scalar(100)]; + let return_data = vec!["ptr".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "malloc_vec".to_string(), + args, + return_data, + ) + .unwrap(); + + if let Line::MAlloc { + var, + size: _, + vectorized, + vectorized_len: _, + } = result + { + assert_eq!(var, "ptr"); + assert!(vectorized); + } else { + panic!("Expected MAlloc line"); + } + } + + #[test] + fn test_malloc_vec_valid_two_args() { + let args = vec![Expression::scalar(100), Expression::scalar(8)]; + let return_data = vec!["ptr".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "malloc_vec".to_string(), + args, + return_data, + ) + .unwrap(); + + if let Line::MAlloc { + var, + size: _, + vectorized, + vectorized_len: _, + } = result + { + assert_eq!(var, "ptr"); + assert!(vectorized); + } else { + panic!("Expected MAlloc line"); + } + } + + #[test] + fn test_malloc_vec_too_many_args() { + let args = vec![ + Expression::scalar(100), + Expression::scalar(8), + Expression::scalar(16), + ]; + let return_data = vec!["ptr".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + 
"malloc_vec".to_string(), + args, + return_data, + ); + + assert!(result.is_err()); + if let Err(crate::parser::error::ParseError::SemanticError(error)) = result { + assert!(error.message.contains("malloc_vec takes 1 or 2 arguments")); + } else { + panic!("Expected SemanticError"); + } + } +} diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 210c38fc..d5e488f7 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -55,7 +55,12 @@ impl Parse for VarOrConstantParser { match pair.as_rule() { Rule::var_or_constant => { - let inner = pair.into_inner().next().unwrap(); + let inner = pair.into_inner().next().ok_or_else(|| { + SemanticError::with_context( + "Expected var_or_constant inner content", + "variable or constant parsing", + ) + })?; Self::parse(inner, ctx) } Rule::identifier | Rule::constant_value => { @@ -106,7 +111,12 @@ pub struct ConstExprParser; impl Parse for ConstExprParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = pair.into_inner().next().unwrap(); + let inner = pair.into_inner().next().ok_or_else(|| { + SemanticError::with_context( + "Expected const_expr inner content", + "constant expression parsing", + ) + })?; match inner.as_rule() { Rule::constant_value => { @@ -280,4 +290,20 @@ mod tests { let result = ConstExprParser::parse(pair, &mut ctx); assert!(result.is_err()); } + + #[test] + fn test_var_or_constant_parser_invalid_rule() { + let mut ctx = ParseContext::new(); + let input = "42"; + let mut pairs = LangParser::parse(Rule::number, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = VarOrConstantParser::parse(pair, &mut ctx); + assert!(result.is_err()); + if let Err(crate::parser::error::ParseError::SemanticError(error)) = result { + assert!(error.message.contains("Expected identifier or constant")); + } else { + panic!("Expected 
SemanticError"); + } + } } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 0f2b90e8..a6b51f6a 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -133,15 +133,18 @@ impl Parse for ForStatementParser { .as_str() .to_string(); - // Check for optional reverse clause + // Check for optional reverse clause by collecting remaining items let mut rev = false; - if let Some(next_peek) = inner.clone().next() - && next_peek.as_rule() == Rule::rev_clause - { + let remaining_items: Vec<_> = inner.collect(); + let mut item_index = 0; + + if !remaining_items.is_empty() && remaining_items[0].as_rule() == Rule::rev_clause { rev = true; - inner.next(); // Consume the rev clause + item_index = 1; } + let mut inner = remaining_items.into_iter().skip(item_index); + let start = ExpressionParser::parse(next_inner_pair(&mut inner, "loop start")?, ctx)?; let end = ExpressionParser::parse(next_inner_pair(&mut inner, "loop end")?, ctx)?; @@ -295,3 +298,111 @@ impl Parse for AssertNotEqParser { Ok(Line::Assert(Boolean::Different { left, right })) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::Line; + use crate::parser::grammar::{LangParser, Rule}; + use pest::Parser; + + #[test] + fn test_for_loop_with_rev_clause() { + let mut ctx = ParseContext::new(); + let input = r#"for i in rev 0..10 { break; }"#; + let mut pairs = LangParser::parse(Rule::for_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::ForLoop { + iterator, + start: _, + end: _, + body: _, + rev, + unroll, + } = result + { + assert_eq!(iterator, "i"); + assert!(rev); + assert!(!unroll); + } else { + panic!("Expected ForLoop"); + } + } + + #[test] + fn test_for_loop_without_rev_clause() { + let mut ctx = ParseContext::new(); + let input = r#"for i in 0..10 { 
break; }"#; + let mut pairs = LangParser::parse(Rule::for_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::ForLoop { + iterator, + start: _, + end: _, + body: _, + rev, + unroll, + } = result + { + assert_eq!(iterator, "i"); + assert!(!rev); + assert!(!unroll); + } else { + panic!("Expected ForLoop"); + } + } + + #[test] + fn test_for_loop_with_unroll_clause() { + let mut ctx = ParseContext::new(); + let input = r#"for i in 0..10 unroll { break; }"#; + let mut pairs = LangParser::parse(Rule::for_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::ForLoop { + iterator, + start: _, + end: _, + body: _, + rev, + unroll, + } = result + { + assert_eq!(iterator, "i"); + assert!(!rev); + assert!(unroll); + } else { + panic!("Expected ForLoop"); + } + } + + #[test] + fn test_for_loop_with_rev_and_unroll() { + let mut ctx = ParseContext::new(); + let input = r#"for i in rev 0..10 unroll { break; }"#; + let mut pairs = LangParser::parse(Rule::for_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::ForLoop { + iterator, + start: _, + end: _, + body: _, + rev, + unroll, + } = result + { + assert_eq!(iterator, "i"); + assert!(rev); + assert!(unroll); + } else { + panic!("Expected ForLoop"); + } + } +} From 135d9ac2a710eeafd53bd8096cb656b16bfd5b86 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 21:20:23 +0200 Subject: [PATCH 07/39] more tests --- .../src/parser/parsers/expression.rs | 238 ++++++++++++++ .../src/parser/parsers/function.rs | 179 +++++++++++ .../src/parser/parsers/statement.rs | 290 +++++++++++++++++- 3 files changed, 695 insertions(+), 12 deletions(-) diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs 
b/crates/lean_compiler/src/parser/parsers/expression.rs index e3ce0d42..dc99c181 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -119,3 +119,241 @@ impl Parse for Log2CeilParser { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::HighLevelOperation; + use crate::lang::{Expression, SimpleExpr}; + use crate::parser::grammar::{LangParser, Rule}; + use pest::Parser; + + #[test] + fn test_expression_parser_primary() { + let mut ctx = ParseContext::new(); + let input = "42"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, Expression::scalar(42)); + } + + #[test] + fn test_expression_parser_add() { + let mut ctx = ParseContext::new(); + let input = "10 + 20"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, Expression::scalar(10)); + assert_eq!(operation, HighLevelOperation::Add); + assert_eq!(*right, Expression::scalar(20)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_expression_parser_sub() { + let mut ctx = ParseContext::new(); + let input = "30 - 5"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, Expression::scalar(30)); + assert_eq!(operation, HighLevelOperation::Sub); + assert_eq!(*right, Expression::scalar(5)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_expression_parser_mul() { + let mut ctx = 
ParseContext::new(); + let input = "6 * 7"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, Expression::scalar(6)); + assert_eq!(operation, HighLevelOperation::Mul); + assert_eq!(*right, Expression::scalar(7)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_expression_parser_mod() { + let mut ctx = ParseContext::new(); + let input = "15 % 4"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, Expression::scalar(15)); + assert_eq!(operation, HighLevelOperation::Mod); + assert_eq!(*right, Expression::scalar(4)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_expression_parser_div() { + let mut ctx = ParseContext::new(); + let input = "20 / 4"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, Expression::scalar(20)); + assert_eq!(operation, HighLevelOperation::Div); + assert_eq!(*right, Expression::scalar(4)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_expression_parser_exp() { + let mut ctx = ParseContext::new(); + let input = "2 ** 8"; + let mut pairs = LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + assert_eq!(*left, 
Expression::scalar(2)); + assert_eq!(operation, HighLevelOperation::Exp); + assert_eq!(*right, Expression::scalar(8)); + } else { + panic!("Expected Binary expression"); + } + } + + #[test] + fn test_primary_expression_parser_parentheses() { + let mut ctx = ParseContext::new(); + let input = "(42)"; + let mut pairs = LangParser::parse(Rule::primary, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = PrimaryExpressionParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, Expression::scalar(42)); + } + + #[test] + fn test_primary_expression_parser_variable() { + let mut ctx = ParseContext::new(); + let input = "x"; + let mut pairs = LangParser::parse(Rule::primary, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = PrimaryExpressionParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, Expression::Value(SimpleExpr::Var("x".to_string()))); + } + + #[test] + fn test_array_access_parser() { + let mut ctx = ParseContext::new(); + let input = "arr[10]"; + let mut pairs = LangParser::parse(Rule::array_access_expr, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ArrayAccessParser::parse(pair, &mut ctx).unwrap(); + if let Expression::ArrayAccess { array, index } = result { + assert_eq!(array, SimpleExpr::Var("arr".to_string())); + assert_eq!(*index, Expression::scalar(10)); + } else { + panic!("Expected ArrayAccess"); + } + } + + #[test] + fn test_log2_ceil_parser() { + let mut ctx = ParseContext::new(); + let input = "log2_ceil(16)"; + let mut pairs = LangParser::parse(Rule::log2_ceil_expr, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = Log2CeilParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Log2Ceil { value } = result { + assert_eq!(*value, Expression::scalar(16)); + } else { + panic!("Expected Log2Ceil"); + } + } + + #[test] + fn test_binary_expression_parser_chain() { + let mut ctx = ParseContext::new(); + let input = "1 + 2 + 3"; + let mut pairs = 
LangParser::parse(Rule::expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ExpressionParser::parse(pair, &mut ctx).unwrap(); + if let Expression::Binary { + left, + operation, + right, + } = result + { + if let Expression::Binary { + left: inner_left, + operation: inner_op, + right: inner_right, + } = *left + { + assert_eq!(*inner_left, Expression::scalar(1)); + assert_eq!(inner_op, HighLevelOperation::Add); + assert_eq!(*inner_right, Expression::scalar(2)); + } else { + panic!("Expected nested Binary expression"); + } + assert_eq!(operation, HighLevelOperation::Add); + assert_eq!(*right, Expression::scalar(3)); + } else { + panic!("Expected Binary expression"); + } + } +} diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 5881bd71..6a090cc5 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -268,6 +268,8 @@ impl Parse> for TupleExpressionParser { mod tests { use super::*; use crate::lang::{Expression, Line}; + use crate::parser::grammar::{LangParser, Rule}; + use pest::Parser; #[test] fn test_malloc_vec_no_return_data() { @@ -389,4 +391,181 @@ mod tests { panic!("Expected SemanticError"); } } + + #[test] + fn test_malloc_builtin() { + let args = vec![Expression::scalar(200)]; + let return_data = vec!["mem".to_string()]; + + let result = + FunctionCallParser::handle_builtin_function("malloc".to_string(), args, return_data) + .unwrap(); + + if let Line::MAlloc { + var, + size, + vectorized, + vectorized_len, + } = result + { + assert_eq!(var, "mem"); + assert_eq!(size, Expression::scalar(200)); + assert!(!vectorized); + assert_eq!(vectorized_len, Expression::zero()); + } else { + panic!("Expected MAlloc"); + } + } + + #[test] + fn test_print_builtin() { + let args = vec![ + Expression::scalar(42), + Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())), + ]; + let return_data = 
vec![]; + + let result = FunctionCallParser::handle_builtin_function( + "print".to_string(), + args.clone(), + return_data, + ) + .unwrap(); + + if let Line::Print { line_info, content } = result { + assert_eq!(line_info, "print"); + assert_eq!(content, args); + } else { + panic!("Expected Print"); + } + } + + #[test] + fn test_decompose_bits_builtin() { + let args = vec![Expression::scalar(255)]; + let return_data = vec!["bits".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "decompose_bits".to_string(), + args.clone(), + return_data, + ) + .unwrap(); + + if let Line::DecomposeBits { var, to_decompose } = result { + assert_eq!(var, "bits"); + assert_eq!(to_decompose, args); + } else { + panic!("Expected DecomposeBits"); + } + } + + #[test] + fn test_counter_hint_builtin() { + let args = vec![]; + let return_data = vec!["counter".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "counter_hint".to_string(), + args, + return_data, + ) + .unwrap(); + + if let Line::CounterHint { var } = result { + assert_eq!(var, "counter"); + } else { + panic!("Expected CounterHint"); + } + } + + #[test] + fn test_panic_builtin() { + let args = vec![]; + let return_data = vec![]; + + let result = + FunctionCallParser::handle_builtin_function("panic".to_string(), args, return_data) + .unwrap(); + + assert_eq!(result, Line::Panic); + } + + #[test] + fn test_regular_function_call() { + let args = vec![Expression::scalar(1), Expression::scalar(2)]; + let return_data = vec!["result".to_string()]; + + let result = FunctionCallParser::handle_builtin_function( + "my_function".to_string(), + args.clone(), + return_data.clone(), + ) + .unwrap(); + + if let Line::FunctionCall { + function_name, + args: call_args, + return_data: call_return, + } = result + { + assert_eq!(function_name, "my_function"); + assert_eq!(call_args, args); + assert_eq!(call_return, return_data); + } else { + panic!("Expected FunctionCall"); + } + } + + #[test] + 
fn test_tuple_expression_parser() { + let mut ctx = ParseContext::new(); + let input = "42, x, 100"; + let mut pairs = LangParser::parse(Rule::tuple_expression, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = TupleExpressionParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result[0], Expression::scalar(42)); + assert_eq!( + result[1], + Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())) + ); + assert_eq!(result[2], Expression::scalar(100)); + } + + #[test] + fn test_parameter_parser_regular() { + let mut ctx = ParseContext::new(); + let input = "param1"; + let mut pairs = LangParser::parse(Rule::parameter, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ParameterParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, ("param1".to_string(), false)); + } + + #[test] + fn test_parameter_parser_const() { + let mut ctx = ParseContext::new(); + let input = "const param2"; + let mut pairs = LangParser::parse(Rule::parameter, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ParameterParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, ("param2".to_string(), true)); + } + + #[test] + fn test_return_count_parser() { + let mut ctx = ParseContext::new(); + ctx.add_constant("RETURN_COUNT".to_string(), 3).unwrap(); + + let input = "-> 3"; + let mut pairs = LangParser::parse(Rule::return_count, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ReturnCountParser::parse(pair, &mut ctx).unwrap(); + assert_eq!(result, 3); + } } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index a6b51f6a..ce1e1955 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -316,14 +316,17 @@ mod tests { let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); if let Line::ForLoop { iterator, - start: _, - 
end: _, - body: _, + start, + end, + body, rev, unroll, } = result { assert_eq!(iterator, "i"); + assert_eq!(start, crate::lang::Expression::scalar(0)); + assert_eq!(end, crate::lang::Expression::scalar(10)); + assert_eq!(body.len(), 2); // LocationReport + Break assert!(rev); assert!(!unroll); } else { @@ -341,14 +344,17 @@ mod tests { let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); if let Line::ForLoop { iterator, - start: _, - end: _, - body: _, + start, + end, + body, rev, unroll, } = result { assert_eq!(iterator, "i"); + assert_eq!(start, crate::lang::Expression::scalar(0)); + assert_eq!(end, crate::lang::Expression::scalar(10)); + assert_eq!(body.len(), 2); // LocationReport + Break assert!(!rev); assert!(!unroll); } else { @@ -366,14 +372,17 @@ mod tests { let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); if let Line::ForLoop { iterator, - start: _, - end: _, - body: _, + start, + end, + body, rev, unroll, } = result { assert_eq!(iterator, "i"); + assert_eq!(start, crate::lang::Expression::scalar(0)); + assert_eq!(end, crate::lang::Expression::scalar(10)); + assert_eq!(body.len(), 2); // LocationReport + Break assert!(!rev); assert!(unroll); } else { @@ -391,18 +400,275 @@ mod tests { let result = ForStatementParser::parse(pair, &mut ctx).unwrap(); if let Line::ForLoop { iterator, - start: _, - end: _, - body: _, + start, + end, + body, rev, unroll, } = result { assert_eq!(iterator, "i"); + assert_eq!(start, crate::lang::Expression::scalar(0)); + assert_eq!(end, crate::lang::Expression::scalar(10)); + assert_eq!(body.len(), 2); // LocationReport + Break assert!(rev); assert!(unroll); } else { panic!("Expected ForLoop"); } } + + #[test] + fn test_statement_parser_break_statement() { + let mut ctx = ParseContext::new(); + let input = "break;"; + let mut pairs = LangParser::parse(Rule::statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = StatementParser::parse(pair, &mut ctx).unwrap(); + 
assert_eq!(result, Line::Break); + } + + #[test] + fn test_statement_parser_continue_statement() { + let mut ctx = ParseContext::new(); + let input = "continue;"; + let mut pairs = LangParser::parse(Rule::statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = StatementParser::parse(pair, &mut ctx); + assert!(result.is_err()); + if let Err(crate::parser::error::ParseError::SemanticError(error)) = result { + assert!( + error + .message + .contains("Continue statement not implemented yet") + ); + } else { + panic!("Expected SemanticError"); + } + } + + #[test] + fn test_assignment_parser() { + let mut ctx = ParseContext::new(); + let input = "x = 42;"; + let mut pairs = LangParser::parse(Rule::single_assignment, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = AssignmentParser::parse(pair, &mut ctx).unwrap(); + if let Line::Assignment { var, value } = result { + assert_eq!(var, "x"); + assert_eq!(value, crate::lang::Expression::scalar(42)); + } else { + panic!("Expected Assignment"); + } + } + + #[test] + fn test_array_assign_parser() { + let mut ctx = ParseContext::new(); + let input = "arr[5] = 100;"; + let mut pairs = LangParser::parse(Rule::array_assign, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ArrayAssignParser::parse(pair, &mut ctx).unwrap(); + if let Line::ArrayAssign { + array, + index, + value, + } = result + { + assert_eq!(array, crate::lang::SimpleExpr::Var("arr".to_string())); + assert_eq!(index, crate::lang::Expression::scalar(5)); + assert_eq!(value, crate::lang::Expression::scalar(100)); + } else { + panic!("Expected ArrayAssign"); + } + } + + #[test] + fn test_assert_eq_parser() { + let mut ctx = ParseContext::new(); + let input = "assert 10 == 20;"; + let mut pairs = LangParser::parse(Rule::assert_eq_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = AssertEqParser::parse(pair, &mut ctx).unwrap(); + if let Line::Assert(crate::lang::Boolean::Equal 
{ left, right }) = result { + assert_eq!(left, crate::lang::Expression::scalar(10)); + assert_eq!(right, crate::lang::Expression::scalar(20)); + } else { + panic!("Expected Assert with Equal condition"); + } + } + + #[test] + fn test_assert_not_eq_parser() { + let mut ctx = ParseContext::new(); + let input = "assert 10 != 20;"; + let mut pairs = LangParser::parse(Rule::assert_not_eq_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = AssertNotEqParser::parse(pair, &mut ctx).unwrap(); + if let Line::Assert(crate::lang::Boolean::Different { left, right }) = result { + assert_eq!(left, crate::lang::Expression::scalar(10)); + assert_eq!(right, crate::lang::Expression::scalar(20)); + } else { + panic!("Expected Assert with Different condition"); + } + } + + #[test] + fn test_condition_parser_equal() { + let mut ctx = ParseContext::new(); + let input = "x == 5"; + let mut pairs = LangParser::parse(Rule::condition, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ConditionParser::parse(pair, &mut ctx).unwrap(); + if let crate::lang::Boolean::Equal { left, right } = result { + assert_eq!( + left, + crate::lang::Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())) + ); + assert_eq!(right, crate::lang::Expression::scalar(5)); + } else { + panic!("Expected Equal condition"); + } + } + + #[test] + fn test_condition_parser_different() { + let mut ctx = ParseContext::new(); + let input = "y != 10"; + let mut pairs = LangParser::parse(Rule::condition, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ConditionParser::parse(pair, &mut ctx).unwrap(); + if let crate::lang::Boolean::Different { left, right } = result { + assert_eq!( + left, + crate::lang::Expression::Value(crate::lang::SimpleExpr::Var("y".to_string())) + ); + assert_eq!(right, crate::lang::Expression::scalar(10)); + } else { + panic!("Expected Different condition"); + } + } + + #[test] + fn test_return_statement_parser_empty() { + let 
mut ctx = ParseContext::new(); + let input = "return;"; + let mut pairs = LangParser::parse(Rule::return_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ReturnStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::FunctionRet { return_data } = result { + assert!(return_data.is_empty()); + } else { + panic!("Expected FunctionRet"); + } + } + + #[test] + fn test_return_statement_parser_with_values() { + let mut ctx = ParseContext::new(); + let input = "return 42, 100;"; + let mut pairs = LangParser::parse(Rule::return_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = ReturnStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::FunctionRet { return_data } = result { + assert_eq!(return_data.len(), 2); + assert_eq!(return_data[0], crate::lang::Expression::scalar(42)); + assert_eq!(return_data[1], crate::lang::Expression::scalar(100)); + } else { + panic!("Expected FunctionRet"); + } + } + + #[test] + fn test_match_statement_parser() { + let mut ctx = ParseContext::new(); + let input = r#"match x { 0 => { break; } 1 => { break; } }"#; + let mut pairs = LangParser::parse(Rule::match_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = MatchStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::Match { value, arms } = result { + assert_eq!( + value, + crate::lang::Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())) + ); + assert_eq!(arms.len(), 2); + assert_eq!(arms[0].0, 0); + assert_eq!(arms[1].0, 1); + assert_eq!(arms[0].1.len(), 2); // LocationReport + Break + assert_eq!(arms[1].1.len(), 2); // LocationReport + Break + } else { + panic!("Expected Match"); + } + } + + #[test] + fn test_if_statement_parser_no_else() { + let mut ctx = ParseContext::new(); + let input = r#"if x == 0 { break; }"#; + let mut pairs = LangParser::parse(Rule::if_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = 
IfStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::IfCondition { + condition, + then_branch, + else_branch, + } = result + { + if let crate::lang::Boolean::Equal { left, right } = condition { + assert_eq!( + left, + crate::lang::Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())) + ); + assert_eq!(right, crate::lang::Expression::scalar(0)); + } else { + panic!("Expected Equal condition"); + } + assert_eq!(then_branch.len(), 2); // LocationReport + Break + assert!(else_branch.is_empty()); + } else { + panic!("Expected IfCondition"); + } + } + + #[test] + fn test_if_statement_parser_with_else() { + let mut ctx = ParseContext::new(); + let input = r#"if x == 0 { break; } else { break; }"#; + let mut pairs = LangParser::parse(Rule::if_statement, input).unwrap(); + let pair = pairs.next().unwrap(); + + let result = IfStatementParser::parse(pair, &mut ctx).unwrap(); + if let Line::IfCondition { + condition, + then_branch, + else_branch, + } = result + { + if let crate::lang::Boolean::Equal { left, right } = condition { + assert_eq!( + left, + crate::lang::Expression::Value(crate::lang::SimpleExpr::Var("x".to_string())) + ); + assert_eq!(right, crate::lang::Expression::scalar(0)); + } else { + panic!("Expected Equal condition"); + } + assert_eq!(then_branch.len(), 2); // LocationReport + Break + assert_eq!(else_branch.len(), 2); // LocationReport + Break + } else { + panic!("Expected IfCondition"); + } + } } From 422229265047d5ced98ce7aa3aa716311b40e2d5 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 21:28:32 +0200 Subject: [PATCH 08/39] small fix --- crates/lean_compiler/src/parser/grammar.rs | 27 ++---- .../src/parser/parsers/statement.rs | 86 ++++++------------- 2 files changed, 31 insertions(+), 82 deletions(-) diff --git a/crates/lean_compiler/src/parser/grammar.rs b/crates/lean_compiler/src/parser/grammar.rs index 5829e12a..327db5b9 100644 --- a/crates/lean_compiler/src/parser/grammar.rs +++ 
b/crates/lean_compiler/src/parser/grammar.rs @@ -18,17 +18,6 @@ pub fn get_location(pair: &ParsePair<'_>) -> (usize, usize) { pair.line_col() } -/// Utility function to safely get the next inner element from a parser. -pub fn next_inner<'i>( - mut pairs: impl Iterator>, - expected: &str, -) -> Option> { - pairs.next().or_else(|| { - eprintln!("Warning: Expected {} but found nothing", expected); - None - }) -} - /// Utility function to parse the main program structure. pub fn parse_source(input: &str) -> Result, Box>> { let mut pairs = LangParser::parse(Rule::program, input)?; @@ -57,22 +46,18 @@ mod tests { } #[test] - fn test_next_inner_found() { + fn test_get_location_functionality() { let input = "fn main() {}"; if let Ok(pair) = parse_source(input) { let mut inner = pair.into_inner(); - let result = next_inner(&mut inner, "function"); - assert!(result.is_some()); + if let Some(func_pair) = inner.next() { + let (line, col) = get_location(&func_pair); + assert_eq!(line, 1); + assert_eq!(col, 1); + } } } - #[test] - fn test_next_inner_not_found() { - let empty_iter = std::iter::empty(); - let result = next_inner(empty_iter, "missing"); - assert!(result.is_none()); - } - #[test] fn test_parse_source_valid() { let input = "fn main() {}"; diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index ce1e1955..dc0e0074 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -10,6 +10,21 @@ use crate::{ }, }; +/// Add a statement with location tracking. +fn add_statement_with_location( + lines: &mut Vec, + pair: ParsePair<'_>, + ctx: &mut ParseContext, +) -> ParseResult<()> { + let location = pair.line_col().0; + let line = StatementParser::parse(pair, ctx)?; + + lines.push(Line::LocationReport { location }); + lines.push(line); + + Ok(()) +} + /// Parser for all statement types. 
pub struct StatementParser; @@ -86,12 +101,12 @@ impl Parse for IfStatementParser { for item in inner { match item.as_rule() { Rule::statement => { - Self::add_statement_with_location(&mut then_branch, item, ctx)?; + add_statement_with_location(&mut then_branch, item, ctx)?; } Rule::else_clause => { for else_item in item.into_inner() { if else_item.as_rule() == Rule::statement { - Self::add_statement_with_location(&mut else_branch, else_item, ctx)?; + add_statement_with_location(&mut else_branch, else_item, ctx)?; } } } @@ -107,22 +122,6 @@ impl Parse for IfStatementParser { } } -impl IfStatementParser { - fn add_statement_with_location( - lines: &mut Vec, - pair: ParsePair<'_>, - ctx: &mut ParseContext, - ) -> ParseResult<()> { - let location = pair.line_col().0; - let line = StatementParser::parse(pair, ctx)?; - - lines.push(Line::LocationReport { location }); - lines.push(line); - - Ok(()) - } -} - /// Parser for for-loop statements. pub struct ForStatementParser; @@ -133,18 +132,15 @@ impl Parse for ForStatementParser { .as_str() .to_string(); - // Check for optional reverse clause by collecting remaining items + // Check for optional reverse clause using efficient peek let mut rev = false; - let remaining_items: Vec<_> = inner.collect(); - let mut item_index = 0; - - if !remaining_items.is_empty() && remaining_items[0].as_rule() == Rule::rev_clause { - rev = true; - item_index = 1; + if let Some(peeked) = inner.peek() { + if peeked.as_rule() == Rule::rev_clause { + rev = true; + inner.next(); // Consume the rev clause + } } - let mut inner = remaining_items.into_iter().skip(item_index); - let start = ExpressionParser::parse(next_inner_pair(&mut inner, "loop start")?, ctx)?; let end = ExpressionParser::parse(next_inner_pair(&mut inner, "loop end")?, ctx)?; @@ -157,7 +153,7 @@ impl Parse for ForStatementParser { unroll = true; } Rule::statement => { - Self::add_statement_with_location(&mut body, item, ctx)?; + add_statement_with_location(&mut body, item, ctx)?; 
} _ => {} } @@ -174,22 +170,6 @@ impl Parse for ForStatementParser { } } -impl ForStatementParser { - fn add_statement_with_location( - lines: &mut Vec, - pair: ParsePair<'_>, - ctx: &mut ParseContext, - ) -> ParseResult<()> { - let location = pair.line_col().0; - let line = StatementParser::parse(pair, ctx)?; - - lines.push(Line::LocationReport { location }); - lines.push(line); - - Ok(()) - } -} - /// Parser for match statements with pattern matching. pub struct MatchStatementParser; @@ -209,7 +189,7 @@ impl Parse for MatchStatementParser { let mut statements = Vec::new(); for stmt in arm_inner { if stmt.as_rule() == Rule::statement { - Self::add_statement_with_location(&mut statements, stmt, ctx)?; + add_statement_with_location(&mut statements, stmt, ctx)?; } } @@ -221,22 +201,6 @@ impl Parse for MatchStatementParser { } } -impl MatchStatementParser { - fn add_statement_with_location( - lines: &mut Vec, - pair: ParsePair<'_>, - ctx: &mut ParseContext, - ) -> ParseResult<()> { - let location = pair.line_col().0; - let line = StatementParser::parse(pair, ctx)?; - - lines.push(Line::LocationReport { location }); - lines.push(line); - - Ok(()) - } -} - /// Parser for return statements. 
pub struct ReturnStatementParser; From 6fa2bf676be704cb8b2b0431454c23053b46c3dd Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 22:12:49 +0200 Subject: [PATCH 09/39] compiler: simplify submodule --- crates/lean_compiler/src/a_simplify_lang.rs | 2218 ----------------- .../src/b_compile_intermediate.rs | 2 +- crates/lean_compiler/src/lib.rs | 5 +- crates/lean_compiler/src/simplify/mod.rs | 53 + crates/lean_compiler/src/simplify/simplify.rs | 758 ++++++ .../src/simplify/transformations.rs | 492 ++++ crates/lean_compiler/src/simplify/types.rs | 861 +++++++ crates/lean_compiler/src/simplify/unroll.rs | 784 ++++++ .../lean_compiler/src/simplify/utilities.rs | 927 +++++++ 9 files changed, 3879 insertions(+), 2221 deletions(-) delete mode 100644 crates/lean_compiler/src/a_simplify_lang.rs create mode 100644 crates/lean_compiler/src/simplify/mod.rs create mode 100644 crates/lean_compiler/src/simplify/simplify.rs create mode 100644 crates/lean_compiler/src/simplify/transformations.rs create mode 100644 crates/lean_compiler/src/simplify/types.rs create mode 100644 crates/lean_compiler/src/simplify/unroll.rs create mode 100644 crates/lean_compiler/src/simplify/utilities.rs diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs deleted file mode 100644 index 4559ebb8..00000000 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ /dev/null @@ -1,2218 +0,0 @@ -use crate::{ - Counter, F, - ir::HighLevelOperation, - lang::{ - Boolean, ConstExpression, ConstMallocLabel, Expression, Function, Line, Program, - SimpleExpr, Var, - }, - precompiles::Precompile, -}; -use lean_vm::SourceLineNumber; -use std::{ - collections::{BTreeMap, BTreeSet}, - fmt::{Display, Formatter}, -}; -use utils::ToUsize; - -#[derive(Debug, Clone)] -pub struct SimpleProgram { - pub functions: BTreeMap, -} - -#[derive(Debug, Clone)] -pub struct SimpleFunction { - pub name: String, - pub arguments: Vec, - pub n_returned_vars: usize, - pub 
instructions: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum VarOrConstMallocAccess { - Var(Var), - ConstMallocAccess { - malloc_label: ConstMallocLabel, - offset: ConstExpression, - }, -} - -impl From for SimpleExpr { - fn from(var_or_const: VarOrConstMallocAccess) -> Self { - match var_or_const { - VarOrConstMallocAccess::Var(var) => Self::Var(var), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => Self::ConstMallocAccess { - malloc_label, - offset, - }, - } - } -} - -impl TryInto for SimpleExpr { - type Error = (); - - fn try_into(self) -> Result { - match self { - Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), - Self::ConstMallocAccess { - malloc_label, - offset, - } => Ok(VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - }), - _ => Err(()), - } - } -} - -impl From for VarOrConstMallocAccess { - fn from(var: Var) -> Self { - Self::Var(var) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum SimpleLine { - Match { - value: SimpleExpr, - arms: Vec>, // patterns = 0, 1, ... 
- }, - Assignment { - var: VarOrConstMallocAccess, - operation: HighLevelOperation, - arg0: SimpleExpr, - arg1: SimpleExpr, - }, - RawAccess { - res: SimpleExpr, - index: SimpleExpr, - shift: ConstExpression, - }, // res = memory[index + shift] - IfNotZero { - condition: SimpleExpr, - then_branch: Vec, - else_branch: Vec, - }, - FunctionCall { - function_name: String, - args: Vec, - return_data: Vec, - }, - FunctionRet { - return_data: Vec, - }, - Precompile { - precompile: Precompile, - args: Vec, - }, - Panic, - // Hints - DecomposeBits { - var: Var, // a pointer to 31 * len(to_decompose) field elements, containing the bits of "to_decompose" - to_decompose: Vec, - label: ConstMallocLabel, - }, - CounterHint { - var: Var, - }, - Print { - line_info: String, - content: Vec, - }, - HintMAlloc { - var: Var, - size: SimpleExpr, - vectorized: bool, - vectorized_len: SimpleExpr, - }, - ConstMalloc { - // always not vectorized - var: Var, - size: ConstExpression, - label: ConstMallocLabel, - }, - // noop, debug purpose only - LocationReport { - location: SourceLineNumber, - }, -} - -pub fn simplify_program(mut program: Program) -> SimpleProgram { - handle_inlined_functions(&mut program); - handle_const_arguments(&mut program); - let mut new_functions = BTreeMap::new(); - let mut counters = Counters::default(); - let mut const_malloc = ConstMalloc::default(); - for (name, func) in &program.functions { - let mut array_manager = ArrayManager::default(); - let simplified_instructions = simplify_lines( - &func.body, - &mut counters, - &mut new_functions, - false, - &mut array_manager, - &mut const_malloc, - ); - let arguments = func - .arguments - .iter() - .map(|(v, is_const)| { - assert!(!is_const,); - v.clone() - }) - .collect::>(); - new_functions.insert( - name.clone(), - SimpleFunction { - name: name.clone(), - arguments, - n_returned_vars: func.n_returned_vars, - instructions: simplified_instructions, - }, - ); - const_malloc.map.clear(); - } - SimpleProgram { - 
functions: new_functions, - } -} - -#[derive(Debug, Clone, Default)] -struct Counters { - aux_vars: usize, - loops: usize, - unrolls: usize, -} - -#[derive(Debug, Clone, Default)] -struct ArrayManager { - counter: usize, - aux_vars: BTreeMap<(SimpleExpr, Expression), Var>, // (array, index) -> aux_var - valid: BTreeSet, // currently valid aux vars -} - -#[derive(Debug, Clone, Default)] -pub struct ConstMalloc { - counter: usize, - map: BTreeMap, - forbidden_vars: BTreeSet, // vars shared between branches of an if/else -} - -impl ArrayManager { - fn get_aux_var(&mut self, array: &SimpleExpr, index: &Expression) -> Var { - if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { - return var.clone(); - } - let new_var = format!("@arr_aux_{}", self.counter); - self.counter += 1; - self.aux_vars - .insert((array.clone(), index.clone()), new_var.clone()); - new_var - } -} - -fn simplify_lines( - lines: &[Line], - counters: &mut Counters, - new_functions: &mut BTreeMap, - in_a_loop: bool, - array_manager: &mut ArrayManager, - const_malloc: &mut ConstMalloc, -) -> Vec { - let mut res = Vec::new(); - for line in lines { - match line { - Line::Match { value, arms } => { - let simple_value = - simplify_expr(value, &mut res, counters, array_manager, const_malloc); - let mut simple_arms = vec![]; - for (i, (pattern, statements)) in arms.iter().enumerate() { - assert_eq!( - *pattern, i, - "match patterns should be consecutive, starting from 0" - ); - simple_arms.push(simplify_lines( - statements, - counters, - new_functions, - in_a_loop, - array_manager, - const_malloc, - )); - } - res.push(SimpleLine::Match { - value: simple_value, - arms: simple_arms, - }); - } - Line::Assignment { var, value } => match value { - Expression::Value(value) => { - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: HighLevelOperation::Add, - arg0: value.clone(), - arg1: SimpleExpr::zero(), - }); - } - Expression::ArrayAccess { array, index } => { - 
handle_array_assignment( - counters, - &mut res, - array.clone(), - index, - ArrayAccessType::VarIsAssigned(var.clone()), - array_manager, - const_malloc, - ); - } - Expression::Binary { - left, - operation, - right, - } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: *operation, - arg0: left, - arg1: right, - }); - } - Expression::Log2Ceil { .. } => unreachable!(), - }, - Line::ArrayAssign { - array, - index, - value, - } => { - handle_array_assignment( - counters, - &mut res, - array.clone(), - index, - ArrayAccessType::ArrayIsAssigned(value.clone()), - array_manager, - const_malloc, - ); - } - Line::Assert(boolean) => match boolean { - Boolean::Different { left, right } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - let diff_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: left, - arg1: right, - }); - res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: vec![], - else_branch: vec![SimpleLine::Panic], - }); - } - Boolean::Equal { left, right } => { - let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - let (var, other) = if let Ok(left) = left.clone().try_into() { - (left, right) - } else if let Ok(right) = right.clone().try_into() { - (right, left) - } else { - unreachable!("Weird: {:?}, {:?}", left, right) - }; - res.push(SimpleLine::Assignment { - var, - operation: HighLevelOperation::Add, - arg0: other, - arg1: SimpleExpr::zero(), - }); - } - }, - 
Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - // Transform if a == b then X else Y into if a != b then Y else X - - let (left, right, then_branch, else_branch) = match condition { - Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched - Boolean::Different { left, right } => (left, right, then_branch, else_branch), - }; - - let left_simplified = - simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right_simplified = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - - let diff_var = format!("@diff_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: left_simplified, - arg1: right_simplified, - }); - - let forbidden_vars_before = const_malloc.forbidden_vars.clone(); - - let then_internal_vars = find_variable_usage(then_branch).0; - let else_internal_vars = find_variable_usage(else_branch).0; - let new_forbidden_vars = then_internal_vars - .intersection(&else_internal_vars) - .cloned() - .collect::>(); - - const_malloc.forbidden_vars.extend(new_forbidden_vars); - - let mut array_manager_then = array_manager.clone(); - let then_branch_simplified = simplify_lines( - then_branch, - counters, - new_functions, - in_a_loop, - &mut array_manager_then, - const_malloc, - ); - let mut array_manager_else = array_manager_then.clone(); - array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch - - let else_branch_simplified = simplify_lines( - else_branch, - counters, - new_functions, - in_a_loop, - &mut array_manager_else, - const_malloc, - ); - - const_malloc.forbidden_vars = forbidden_vars_before; - - *array_manager = array_manager_else.clone(); - // keep the intersection both branches - array_manager.valid = array_manager - .valid - .intersection(&array_manager_then.valid) - .cloned() - .collect(); - - 
res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: then_branch_simplified, - else_branch: else_branch_simplified, - }); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev, - unroll, - } => { - if *unroll { - let (internal_variables, _) = find_variable_usage(body); - let mut unrolled_lines = Vec::new(); - let start_evaluated = start.naive_eval().unwrap().to_usize(); - let end_evaluated = end.naive_eval().unwrap().to_usize(); - let unroll_index = counters.unrolls; - counters.unrolls += 1; - - let mut range = (start_evaluated..end_evaluated).collect::>(); - if *rev { - range.reverse(); - } - - for i in range { - let mut body_copy = body.clone(); - replace_vars_for_unroll( - &mut body_copy, - iterator, - unroll_index, - i, - &internal_variables, - ); - unrolled_lines.extend(simplify_lines( - &body_copy, - counters, - new_functions, - in_a_loop, - array_manager, - const_malloc, - )); - } - res.extend(unrolled_lines); - continue; - } - - if *rev { - unimplemented!("Reverse for non-unrolled loops are not implemented yet"); - } - - let mut loop_const_malloc = ConstMalloc { - counter: const_malloc.counter, - ..ConstMalloc::default() - }; - let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); - array_manager.valid.clear(); - let simplified_body = simplify_lines( - body, - counters, - new_functions, - true, - array_manager, - &mut loop_const_malloc, - ); - const_malloc.counter = loop_const_malloc.counter; - array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars - - let func_name = format!("@loop_{}", counters.loops); - counters.loops += 1; - - // Find variables used inside loop but defined outside - let (_, mut external_vars) = find_variable_usage(body); - - // Include variables in start/end - for expr in [start, end] { - for var in vars_in_expression(expr) { - external_vars.insert(var); - } - } - external_vars.remove(iterator); // Iterator is internal to loop - - let mut 
external_vars: Vec<_> = external_vars.into_iter().collect(); - - let start_simplified = - simplify_expr(start, &mut res, counters, array_manager, const_malloc); - let end_simplified = - simplify_expr(end, &mut res, counters, array_manager, const_malloc); - - for (simplified, original) in [ - (start_simplified.clone(), start.clone()), - (end_simplified.clone(), end.clone()), - ] { - if !matches!(original, Expression::Value(_)) { - // the simplified var is auxiliary - if let SimpleExpr::Var(var) = simplified { - external_vars.push(var); - } - } - } - - // Create function arguments: iterator + external variables - let mut func_args = vec![iterator.clone()]; - func_args.extend(external_vars.clone()); - - // Create recursive function body - let recursive_func = create_recursive_function( - func_name.clone(), - func_args, - iterator.clone(), - end_simplified, - simplified_body, - &external_vars, - ); - new_functions.insert(func_name.clone(), recursive_func); - - // Replace loop with initial function call - let mut call_args = vec![start_simplified]; - call_args.extend(external_vars.iter().map(|v| v.clone().into())); - - res.push(SimpleLine::FunctionCall { - function_name: func_name, - args: call_args, - return_data: vec![], - }); - } - Line::FunctionCall { - function_name, - args, - return_data, - } => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::FunctionCall { - function_name: function_name.clone(), - args: simplified_args, - return_data: return_data.clone(), - }); - } - Line::FunctionRet { return_data } => { - assert!( - !in_a_loop, - "Function return inside a loop is not currently supported" - ); - let simplified_return_data = return_data - .iter() - .map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::FunctionRet { - return_data: simplified_return_data, - }); - } - Line::Precompile { 
precompile, args } => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::Precompile { - precompile: precompile.clone(), - args: simplified_args, - }); - } - Line::Print { line_info, content } => { - let simplified_content = content - .iter() - .map(|var| simplify_expr(var, &mut res, counters, array_manager, const_malloc)) - .collect::>(); - res.push(SimpleLine::Print { - line_info: line_info.clone(), - content: simplified_content, - }); - } - Line::Break => { - assert!(in_a_loop, "Break statement outside of a loop"); - res.push(SimpleLine::FunctionRet { - return_data: vec![], - }); - } - Line::MAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - let simplified_size = - simplify_expr(size, &mut res, counters, array_manager, const_malloc); - let simplified_vectorized_len = simplify_expr( - vectorized_len, - &mut res, - counters, - array_manager, - const_malloc, - ); - if simplified_size.is_constant() - && !*vectorized - && const_malloc.forbidden_vars.contains(var) - { - println!( - "TODO: Optimization missed: Requires to align const malloc in if/else branches" - ); - } - match simplified_size { - SimpleExpr::Constant(const_size) - if !*vectorized && !const_malloc.forbidden_vars.contains(var) => - { - // TODO do this optimization even if we are in an if/else branch - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::ConstMalloc { - var: var.clone(), - size: const_size, - label, - }); - } - _ => { - res.push(SimpleLine::HintMAlloc { - var: var.clone(), - size: simplified_size, - vectorized: *vectorized, - vectorized_len: simplified_vectorized_len, - }); - } - } - } - Line::DecomposeBits { var, to_decompose } => { - assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); - let simplified_to_decompose = to_decompose - .iter() - .map(|expr| { - 
simplify_expr(expr, &mut res, counters, array_manager, const_malloc) - }) - .collect::>(); - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::DecomposeBits { - var: var.clone(), - to_decompose: simplified_to_decompose, - label, - }); - } - Line::CounterHint { var } => { - res.push(SimpleLine::CounterHint { var: var.clone() }); - } - Line::Panic => { - res.push(SimpleLine::Panic); - } - Line::LocationReport { location } => { - res.push(SimpleLine::LocationReport { - location: *location, - }); - } - } - } - - res -} - -fn simplify_expr( - expr: &Expression, - lines: &mut Vec, - counters: &mut Counters, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, -) -> SimpleExpr { - match expr { - Expression::Value(value) => value.simplify_if_const(), - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array - && let Some(label) = const_malloc.map.get(array_var) - && let Ok(mut offset) = ConstExpression::try_from(*index.clone()) - { - offset = offset.try_naive_simplification(); - return SimpleExpr::ConstMallocAccess { - malloc_label: *label, - offset, - }; - } - - let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index] - - if !array_manager.valid.insert(aux_arr.clone()) { - return SimpleExpr::Var(aux_arr); - } - - handle_array_assignment( - counters, - lines, - array.clone(), - index, - ArrayAccessType::VarIsAssigned(aux_arr.clone()), - array_manager, - const_malloc, - ); - SimpleExpr::Var(aux_arr) - } - Expression::Binary { - left, - operation, - right, - } => { - let left_var = simplify_expr(left, lines, counters, array_manager, const_malloc); - let right_var = simplify_expr(right, lines, counters, array_manager, const_malloc); - - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = - (&left_var, &right_var) - { - return SimpleExpr::Constant(ConstExpression::Binary { - left: 
Box::new(left_cst.clone()), - operation: *operation, - right: Box::new(right_cst.clone()), - }); - } - - let aux_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - lines.push(SimpleLine::Assignment { - var: aux_var.clone().into(), - operation: *operation, - arg0: left_var, - arg1: right_var, - }); - SimpleExpr::Var(aux_var) - } - Expression::Log2Ceil { value } => { - let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc) - .as_constant() - .unwrap(); - SimpleExpr::Constant(ConstExpression::Log2Ceil { - value: Box::new(const_value), - }) - } - } -} - -/// Returns (internal_vars, external_vars) -pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { - let mut internal_vars = BTreeSet::new(); - let mut external_vars = BTreeSet::new(); - - let on_new_expr = - |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - for var in vars_in_expression(expr) { - if !internal_vars.contains(&var) { - external_vars.insert(var); - } - } - }; - - let on_new_condition = - |condition: &Boolean, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; - on_new_expr(left, internal_vars, external_vars); - on_new_expr(right, internal_vars, external_vars); - }; - - for line in lines { - match line { - Line::Match { value, arms } => { - on_new_expr(value, &internal_vars, &mut external_vars); - for (_, statements) in arms { - let (stmt_internal, stmt_external) = find_variable_usage(statements); - internal_vars.extend(stmt_internal); - external_vars.extend( - stmt_external - .into_iter() - .filter(|v| !internal_vars.contains(v)), - ); - } - } - Line::Assignment { var, value } => { - on_new_expr(value, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - on_new_condition(condition, &internal_vars, &mut 
external_vars); - - let (then_internal, then_external) = find_variable_usage(then_branch); - let (else_internal, else_external) = find_variable_usage(else_branch); - - internal_vars.extend(then_internal.union(&else_internal).cloned()); - external_vars.extend( - then_external - .union(&else_external) - .filter(|v| !internal_vars.contains(*v)) - .cloned(), - ); - } - Line::FunctionCall { - args, return_data, .. - } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - internal_vars.extend(return_data.iter().cloned()); - } - Line::Assert(condition) => { - on_new_condition(condition, &internal_vars, &mut external_vars); - } - Line::FunctionRet { return_data } => { - for ret in return_data { - on_new_expr(ret, &internal_vars, &mut external_vars); - } - } - Line::MAlloc { var, size, .. } => { - on_new_expr(size, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - } - Line::Print { content, .. 
} => { - for var in content { - on_new_expr(var, &internal_vars, &mut external_vars); - } - } - Line::DecomposeBits { var, to_decompose } => { - for expr in to_decompose { - on_new_expr(expr, &internal_vars, &mut external_vars); - } - internal_vars.insert(var.clone()); - } - Line::CounterHint { var } => { - internal_vars.insert(var.clone()); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - let (body_internal, body_external) = find_variable_usage(body); - internal_vars.extend(body_internal); - internal_vars.insert(iterator.clone()); - external_vars.extend(body_external.difference(&internal_vars).cloned()); - on_new_expr(start, &internal_vars, &mut external_vars); - on_new_expr(end, &internal_vars, &mut external_vars); - } - Line::ArrayAssign { - array, - index, - value, - } => { - on_new_expr(&array.clone().into(), &internal_vars, &mut external_vars); - on_new_expr(index, &internal_vars, &mut external_vars); - on_new_expr(value, &internal_vars, &mut external_vars); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} - } - } - - (internal_vars, external_vars) -} - -fn inline_simple_expr( - simple_expr: &mut SimpleExpr, - args: &BTreeMap, - inlining_count: usize, -) { - if let SimpleExpr::Var(var) = simple_expr { - if let Some(replacement) = args.get(var) { - *simple_expr = replacement.clone(); - } else { - *var = format!("@inlined_var_{inlining_count}_{var}"); - } - } -} - -fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining_count: usize) { - match expr { - Expression::Value(value) => { - inline_simple_expr(value, args, inlining_count); - } - Expression::ArrayAccess { array, index } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - } - Expression::Binary { left, right, .. 
} => { - inline_expr(left, args, inlining_count); - inline_expr(right, args, inlining_count); - } - Expression::Log2Ceil { value } => { - inline_expr(value, args, inlining_count); - } - } -} - -pub fn inline_lines( - lines: &mut Vec, - args: &BTreeMap, - res: &[Var], - inlining_count: usize, -) { - let inline_condition = |condition: &mut Boolean| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; - inline_expr(left, args, inlining_count); - inline_expr(right, args, inlining_count); - }; - - let inline_internal_var = |var: &mut Var| { - assert!( - !args.contains_key(var), - "Variable {var} is both an argument and assigned in the inlined function" - ); - *var = format!("@inlined_var_{inlining_count}_{var}"); - }; - - let mut lines_to_replace = vec![]; - for (i, line) in lines.iter_mut().enumerate() { - match line { - Line::Match { value, arms } => { - inline_expr(value, args, inlining_count); - for (_, statements) in arms { - inline_lines(statements, args, res, inlining_count); - } - } - Line::Assignment { var, value } => { - inline_expr(value, args, inlining_count); - inline_internal_var(var); - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - inline_condition(condition); - - inline_lines(then_branch, args, res, inlining_count); - inline_lines(else_branch, args, res, inlining_count); - } - Line::FunctionCall { - args: func_args, - return_data, - .. 
- } => { - for arg in func_args { - inline_expr(arg, args, inlining_count); - } - for return_var in return_data { - inline_internal_var(return_var); - } - } - Line::Assert(condition) => { - inline_condition(condition); - } - Line::FunctionRet { return_data } => { - assert_eq!(return_data.len(), res.len()); - - for expr in return_data.iter_mut() { - inline_expr(expr, args, inlining_count); - } - lines_to_replace.push(( - i, - res.iter() - .zip(return_data) - .map(|(res_var, expr)| Line::Assignment { - var: res_var.clone(), - value: expr.clone(), - }) - .collect::>(), - )); - } - Line::MAlloc { var, size, .. } => { - inline_expr(size, args, inlining_count); - inline_internal_var(var); - } - Line::Precompile { - precompile: _, - args: precompile_args, - } => { - for arg in precompile_args { - inline_expr(arg, args, inlining_count); - } - } - Line::Print { content, .. } => { - for var in content { - inline_expr(var, args, inlining_count); - } - } - Line::DecomposeBits { var, to_decompose } => { - for expr in to_decompose { - inline_expr(expr, args, inlining_count); - } - inline_internal_var(var); - } - Line::CounterHint { var } => { - inline_internal_var(var); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - inline_lines(body, args, res, inlining_count); - inline_internal_var(iterator); - inline_expr(start, args, inlining_count); - inline_expr(end, args, inlining_count); - } - Line::ArrayAssign { - array, - index, - value, - } => { - inline_simple_expr(array, args, inlining_count); - inline_expr(index, args, inlining_count); - inline_expr(value, args, inlining_count); - } - Line::Panic | Line::Break | Line::LocationReport { .. 
} => {} - } - } - for (i, new_lines) in lines_to_replace.into_iter().rev() { - lines.splice(i..=i, new_lines); - } -} - -fn vars_in_expression(expr: &Expression) -> BTreeSet { - let mut vars = BTreeSet::new(); - match expr { - Expression::Value(value) => { - if let SimpleExpr::Var(var) = value { - vars.insert(var.clone()); - } - } - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array) = array { - vars.insert(array.clone()); - } - vars.extend(vars_in_expression(index)); - } - Expression::Binary { left, right, .. } => { - vars.extend(vars_in_expression(left)); - vars.extend(vars_in_expression(right)); - } - Expression::Log2Ceil { value } => { - vars.extend(vars_in_expression(value)); - } - } - vars -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ArrayAccessType { - VarIsAssigned(Var), // var = array[index] - ArrayIsAssigned(Expression), // array[index] = expr -} - -fn handle_array_assignment( - counters: &mut Counters, - res: &mut Vec, - array: SimpleExpr, - index: &Expression, - access_type: ArrayAccessType, - array_manager: &mut ArrayManager, - const_malloc: &ConstMalloc, -) { - let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc); - - if let SimpleExpr::Constant(offset) = simplified_index.clone() - && let SimpleExpr::Var(array_var) = &array - && let Some(label) = const_malloc.map.get(array_var) - && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { - left, - operation, - right, - }) = &access_type - { - let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc); - let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc); - res.push(SimpleLine::Assignment { - var: VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - }, - operation: *operation, - arg0, - arg1, - }); - return; - } - - let value_simplified = match access_type { - ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), - 
ArrayAccessType::ArrayIsAssigned(expr) => { - simplify_expr(&expr, res, counters, array_manager, const_malloc) - } - }; - - // TODO opti: in some case we could use ConstMallocAccess - - let (index_var, shift) = match simplified_index { - SimpleExpr::Constant(c) => (array, c), - _ => { - // Create pointer variable: ptr = array + index - let ptr_var = format!("@aux_var_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: ptr_var.clone().into(), - operation: HighLevelOperation::Add, - arg0: array, - arg1: simplified_index, - }); - (SimpleExpr::Var(ptr_var), ConstExpression::zero()) - } - }; - - res.push(SimpleLine::RawAccess { - res: value_simplified, - index: index_var, - shift, - }); -} - -fn create_recursive_function( - name: String, - args: Vec, - iterator: Var, - end: SimpleExpr, - mut body: Vec, - external_vars: &[Var], -) -> SimpleFunction { - // Add iterator increment - let next_iter = format!("@incremented_{iterator}"); - body.push(SimpleLine::Assignment { - var: next_iter.clone().into(), - operation: HighLevelOperation::Add, - arg0: iterator.clone().into(), - arg1: SimpleExpr::one(), - }); - - // Add recursive call - let mut recursive_args: Vec = vec![next_iter.into()]; - recursive_args.extend(external_vars.iter().map(|v| v.clone().into())); - - body.push(SimpleLine::FunctionCall { - function_name: name.clone(), - args: recursive_args, - return_data: vec![], - }); - body.push(SimpleLine::FunctionRet { - return_data: vec![], - }); - - let diff_var = format!("@diff_{iterator}"); - - let instructions = vec![ - SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: iterator.into(), - arg1: end, - }, - SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: body, - else_branch: vec![SimpleLine::FunctionRet { - return_data: vec![], - }], - }, - ]; - - SimpleFunction { - name, - arguments: args, - n_returned_vars: 0, - instructions, - } -} - -fn 
replace_vars_for_unroll_in_expr( - expr: &mut Expression, - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { - match expr { - Expression::Value(value_expr) => match value_expr { - SimpleExpr::Var(var) => { - if var == iterator { - *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); - } else if internal_vars.contains(var) { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } - } - SimpleExpr::Constant(_) | SimpleExpr::ConstMallocAccess { .. } => {} - }, - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } - } - - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Expression::Binary { left, right, .. } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Expression::Log2Ceil { value } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } -} - -fn replace_vars_for_unroll( - lines: &mut [Line], - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { - for line in lines { - match line { - Line::Match { value, arms } => { - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - for (_, statements) in arms { - replace_vars_for_unroll( - statements, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Assignment { var, value } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - 
replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::ArrayAssign { - // array[index] = value - array, - index, - value, - } => { - if let SimpleExpr::Var(array_var) = array { - assert!(array_var != iterator, "Weird"); - if internal_vars.contains(array_var) { - *array_var = - format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); - } - } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - value, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::Assert(Boolean::Equal { left, right } | Boolean::Different { left, right }) => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::IfCondition { - condition: Boolean::Equal { left, right } | Boolean::Different { left, right }, - then_branch, - else_branch, - } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - then_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll( - else_branch, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::ForLoop { - iterator: other_iterator, - start, - end, - body, - rev: _, - unroll: _, - } => { - assert!(other_iterator != iterator); - *other_iterator = - format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); - replace_vars_for_unroll_in_expr( - start, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - end, - iterator, - unroll_index, - iterator_value, - 
internal_vars, - ); - replace_vars_for_unroll( - body, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::FunctionCall { - function_name: _, - args, - return_data, - } => { - // Function calls are not unrolled, so we don't need to change them - for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - for ret in return_data { - *ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}"); - } - } - Line::FunctionRet { return_data } => { - for ret in return_data { - replace_vars_for_unroll_in_expr( - ret, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - replace_vars_for_unroll_in_expr( - arg, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::Print { line_info, content } => { - // Print statements are not unrolled, so we don't need to change them - *line_info += &format!(" (unrolled {unroll_index} {iterator_value})"); - for var in content { - replace_vars_for_unroll_in_expr( - var, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::MAlloc { - var, - size, - vectorized: _, - vectorized_len, - } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr( - size, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - vectorized_len, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Line::DecomposeBits { var, to_decompose } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - for expr in to_decompose { - replace_vars_for_unroll_in_expr( - expr, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - } - Line::CounterHint { var } => { - *var = 
format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} - } - } -} - -fn handle_inlined_functions(program: &mut Program) { - let inlined_functions = program - .functions - .iter() - .filter(|(_, func)| func.inlined) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); - - for func in inlined_functions.values() { - assert!( - !func.has_const_arguments(), - "Inlined functions with constant arguments are not supported yet" - ); - } - - // Process inline functions iteratively to handle dependencies - // Repeat until all inline function calls are resolved - let mut max_iterations = 10; - while max_iterations > 0 { - let mut any_changes = false; - - // Process non-inlined functions - for func in program.functions.values_mut() { - if !func.inlined { - let mut counter1 = Counter::new(); - let mut counter2 = Counter::new(); - let old_body = func.body.clone(); - - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); - - if func.body != old_body { - any_changes = true; - } - } - } - - // Process inlined functions that may call other inlined functions - // We need to update them so that when they get inlined later, they don't have unresolved calls - for func in program.functions.values_mut() { - if func.inlined { - let mut counter1 = Counter::new(); - let mut counter2 = Counter::new(); - let old_body = func.body.clone(); - - handle_inlined_functions_helper( - &mut func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); - - if func.body != old_body { - any_changes = true; - } - } - } - - if !any_changes { - break; - } - - max_iterations -= 1; - } - - assert!( - max_iterations > 0, - "Too many iterations processing inline functions" - ); - - // Remove all inlined functions from the program (they've been inlined) - for func_name in inlined_functions.keys() { - program.functions.remove(func_name); - } -} - -fn 
handle_inlined_functions_helper( - lines: &mut Vec, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, - total_inlined_counter: &mut Counter, -) { - for i in (0..lines.len()).rev() { - match &mut lines[i] { - Line::FunctionCall { - function_name, - args, - return_data, - } => { - if let Some(func) = inlined_functions.get(&*function_name) { - let mut inlined_lines = vec![]; - - let mut simplified_args = vec![]; - for arg in args { - if let Expression::Value(simple_expr) = arg { - simplified_args.push(simple_expr.clone()); - } else { - let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); - inlined_lines.push(Line::Assignment { - var: aux_var.clone(), - value: arg.clone(), - }); - simplified_args.push(SimpleExpr::Var(aux_var)); - } - } - assert_eq!(simplified_args.len(), func.arguments.len()); - let inlined_args = func - .arguments - .iter() - .zip(&simplified_args) - .map(|((var, _), expr)| (var.clone(), expr.clone())) - .collect::>(); - let mut func_body = func.body.clone(); - inline_lines( - &mut func_body, - &inlined_args, - return_data, - total_inlined_counter.next(), - ); - inlined_lines.extend(func_body); - - lines.remove(i); // remove the call to the inlined function - lines.splice(i..i, inlined_lines); - } - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - handle_inlined_functions_helper( - then_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - handle_inlined_functions_helper( - else_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - Line::ForLoop { - body, unroll: _, .. - } => { - handle_inlined_functions_helper( - body, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - Line::Match { arms, .. 
} => { - for (_, arm) in arms { - handle_inlined_functions_helper( - arm, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - } - _ => {} - } - } -} - -fn handle_const_arguments(program: &mut Program) { - let mut new_functions = BTreeMap::::new(); - let constant_functions = program - .functions - .iter() - .filter(|(_, func)| func.has_const_arguments()) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); - - // First pass: process non-const functions that call const functions - for func in program.functions.values_mut() { - if !func.has_const_arguments() { - handle_const_arguments_helper(&mut func.body, &constant_functions, &mut new_functions); - } - } - - // Process newly created const functions recursively until no more changes - let mut changed = true; - let mut const_depth = 0; - while changed { - changed = false; - const_depth += 1; - assert!(const_depth < 100, "Too many levels of constant arguments"); - let mut additional_functions = BTreeMap::new(); - - // Collect all function names to process - let function_names: Vec = new_functions.keys().cloned().collect(); - - for name in function_names { - if let Some(func) = new_functions.get_mut(&name) { - let initial_count = additional_functions.len(); - handle_const_arguments_helper( - &mut func.body, - &constant_functions, - &mut additional_functions, - ); - if additional_functions.len() > initial_count { - changed = true; - } - } - } - - // Add any newly discovered functions - for (name, func) in additional_functions { - if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) { - e.insert(func); - changed = true; - } - } - } - - for (name, func) in new_functions { - assert!(!program.functions.contains_key(&name),); - program.functions.insert(name, func); - } - for const_func in constant_functions.keys() { - program.functions.remove(const_func); - } -} - -fn handle_const_arguments_helper( - lines: &mut [Line], - constant_functions: &BTreeMap, - 
new_functions: &mut BTreeMap, -) { - for line in lines { - match line { - Line::FunctionCall { - function_name, - args, - return_data: _, - } => { - if let Some(func) = constant_functions.get(function_name) { - // If the function has constant arguments, we need to handle them - let mut const_evals = Vec::new(); - for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { - if *is_constant { - let const_eval = arg_expr.naive_eval().unwrap_or_else(|| { - panic!("Failed to evaluate constant argument: {arg_expr}") - }); - const_evals.push((arg_var.clone(), const_eval)); - } - } - let const_funct_name = format!( - "{function_name}_{}", - const_evals - .iter() - .map(|(arg_var, const_eval)| { format!("{arg_var}={const_eval}") }) - .collect::>() - .join("_") - ); - - *function_name = const_funct_name.clone(); // change the name of the function called - // ... and remove constant arguments - *args = args - .iter() - .zip(&func.arguments) - .filter(|(_, (_, is_constant))| !is_constant) - .filter(|(_, (_, is_const))| !is_const) - .map(|(arg_expr, _)| arg_expr.clone()) - .collect(); - - if new_functions.contains_key(&const_funct_name) { - continue; - } - - let mut new_body = func.body.clone(); - replace_vars_by_const_in_lines( - &mut new_body, - &const_evals.iter().cloned().collect(), - ); - new_functions.insert( - const_funct_name.clone(), - Function { - name: const_funct_name, - arguments: func - .arguments - .iter() - .filter(|(_, is_const)| !is_const) - .cloned() - .collect(), - inlined: false, - body: new_body, - n_returned_vars: func.n_returned_vars, - }, - ); - } - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - handle_const_arguments_helper(then_branch, constant_functions, new_functions); - handle_const_arguments_helper(else_branch, constant_functions, new_functions); - } - Line::ForLoop { - body, unroll: _, .. 
- } => { - // TODO we should unroll before const arguments handling - handle_const_arguments_helper(body, constant_functions, new_functions); - } - _ => {} - } - } -} - -fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) { - match expr { - Expression::Value(value) => match &value { - SimpleExpr::Var(var) => { - if let Some(const_value) = map.get(var) { - *value = SimpleExpr::scalar(const_value.to_usize()); - } - } - SimpleExpr::ConstMallocAccess { .. } => { - unreachable!() - } - SimpleExpr::Constant(_) => {} - }, - Expression::ArrayAccess { array, index } => { - if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); - } - replace_vars_by_const_in_expr(index, map); - } - Expression::Binary { left, right, .. } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - Expression::Log2Ceil { value } => { - replace_vars_by_const_in_expr(value, map); - } - } -} - -fn get_function_called(lines: &[Line], function_called: &mut Vec) { - for line in lines { - match line { - Line::Match { value: _, arms } => { - for (_, statements) in arms { - get_function_called(statements, function_called); - } - } - Line::FunctionCall { function_name, .. } => { - function_called.push(function_name.clone()); - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - get_function_called(then_branch, function_called); - get_function_called(else_branch, function_called); - } - Line::ForLoop { body, .. } => { - get_function_called(body, function_called); - } - Line::Assignment { .. } - | Line::ArrayAssign { .. } - | Line::Assert { .. } - | Line::FunctionRet { .. } - | Line::Precompile { .. } - | Line::Print { .. } - | Line::DecomposeBits { .. } - | Line::CounterHint { .. } - | Line::MAlloc { .. } - | Line::Panic - | Line::Break - | Line::LocationReport { .. 
} => {} - } - } -} - -fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { - for line in lines { - match line { - Line::Match { value, arms } => { - replace_vars_by_const_in_expr(value, map); - for (_, statements) in arms { - replace_vars_by_const_in_lines(statements, map); - } - } - Line::Assignment { var, value } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(value, map); - } - Line::ArrayAssign { - array, - index, - value, - } => { - if let SimpleExpr::Var(array_var) = array { - assert!( - !map.contains_key(array_var), - "Array {array_var} is a constant" - ); - } - replace_vars_by_const_in_expr(index, map); - replace_vars_by_const_in_expr(value, map); - } - Line::FunctionCall { - args, return_data, .. - } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - for ret in return_data { - assert!( - !map.contains_key(ret), - "Return variable {ret} is a constant" - ); - } - } - Line::IfCondition { - condition, - then_branch, - else_branch, - } => { - match condition { - Boolean::Equal { left, right } | Boolean::Different { left, right } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - } - replace_vars_by_const_in_lines(then_branch, map); - replace_vars_by_const_in_lines(else_branch, map); - } - Line::ForLoop { - body, start, end, .. 
- } => { - replace_vars_by_const_in_expr(start, map); - replace_vars_by_const_in_expr(end, map); - replace_vars_by_const_in_lines(body, map); - } - Line::Assert(condition) => match condition { - Boolean::Equal { left, right } | Boolean::Different { left, right } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } - }, - Line::FunctionRet { return_data } => { - for ret in return_data { - replace_vars_by_const_in_expr(ret, map); - } - } - Line::Precompile { - precompile: _, - args, - } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - } - Line::Print { content, .. } => { - for var in content { - replace_vars_by_const_in_expr(var, map); - } - } - Line::DecomposeBits { var, to_decompose } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - for expr in to_decompose { - replace_vars_by_const_in_expr(expr, map); - } - } - Line::CounterHint { var } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - } - Line::MAlloc { var, size, .. } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(size, map); - } - Line::Panic | Line::Break | Line::LocationReport { .. 
} => {} - } - } -} -impl Display for SimpleLine { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.to_string_with_indent(0)) - } -} - -impl Display for VarOrConstMallocAccess { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Var(var) => write!(f, "{var}"), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { - write!(f, "ConstMallocAccess({malloc_label}, {offset})") - } - } - } -} - -impl SimpleLine { - fn to_string_with_indent(&self, indent: usize) -> String { - let spaces = " ".repeat(indent); - let line_str = match self { - Self::Match { value, arms } => { - let arms_str = arms - .iter() - .enumerate() - .map(|(pattern, stmt)| { - format!( - "{} => {}", - pattern, - stmt.iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n") - ) - }) - .collect::>() - .join(", "); - - format!("match {value} {{\n{arms_str}\n{spaces}}}") - } - Self::Assignment { - var, - operation, - arg0, - arg1, - } => { - format!("{var} = {arg0} {operation} {arg1}") - } - Self::DecomposeBits { - var: result, - to_decompose, - label: _, - } => { - format!( - "{} = decompose_bits({})", - result, - to_decompose - .iter() - .map(|expr| format!("{expr}")) - .collect::>() - .join(", ") - ) - } - Self::CounterHint { var: result } => { - format!("{result} = counter_hint()") - } - Self::RawAccess { res, index, shift } => { - format!("memory[{index} + {shift}] = {res}") - } - Self::IfNotZero { - condition, - then_branch, - else_branch, - } => { - let then_str = then_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - let else_str = else_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - if else_branch.is_empty() { - format!("if {condition} != 0 {{\n{then_str}\n{spaces}}}") - } else { - format!( - "if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" - ) - 
} - } - Self::FunctionCall { - function_name, - args, - return_data, - } => { - let args_str = args - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - let return_data_str = return_data - .iter() - .map(|var| var.to_string()) - .collect::>() - .join(", "); - - if return_data.is_empty() { - format!("{function_name}({args_str})") - } else { - format!("{return_data_str} = {function_name}({args_str})") - } - } - Self::FunctionRet { return_data } => { - let return_data_str = return_data - .iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", "); - format!("return {return_data_str}") - } - Self::Precompile { precompile, args } => { - format!( - "{}({})", - &precompile.name, - args.iter() - .map(|arg| format!("{arg}")) - .collect::>() - .join(", ") - ) - } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(|c| format!("{c}")) - .collect::>() - .join(", "); - format!("print({content_str})") - } - Self::HintMAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - format!("{var} = malloc_vec({size}, {vectorized_len})") - } else { - format!("{var} = malloc({size})") - } - } - Self::ConstMalloc { - var, - size, - label: _, - } => { - format!("{var} = malloc({size})") - } - Self::Panic => "panic".to_string(), - Self::LocationReport { .. 
} => Default::default(), - }; - format!("{spaces}{line_str}") - } -} - -impl Display for SimpleFunction { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let args_str = self - .arguments - .iter() - .map(|arg| arg.to_string()) - .collect::>() - .join(", "); - - let instructions_str = self - .instructions - .iter() - .map(|line| line.to_string_with_indent(1)) - .collect::>() - .join("\n"); - - if self.instructions.is_empty() { - write!( - f, - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) - } else { - write!( - f, - "fn {}({}) -> {} {{\n{}\n}}", - self.name, args_str, self.n_returned_vars, instructions_str - ) - } - } -} - -impl Display for SimpleProgram { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut first = true; - for function in self.functions.values() { - if !first { - writeln!(f)?; - } - write!(f, "{function}")?; - first = false; - } - Ok(()) - } -} diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 5d47ca2e..ae9096e9 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -1,4 +1,4 @@ -use crate::{F, a_simplify_lang::*, ir::*, lang::*, precompiles::*}; +use crate::{F, ir::*, lang::*, precompiles::*, simplify::*}; use lean_vm::*; use p3_field::Field; use std::{ diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 27f04ee7..37cdfba3 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -3,17 +3,18 @@ use std::collections::BTreeMap; use lean_vm::*; use crate::{ - a_simplify_lang::simplify_program, b_compile_intermediate::compile_to_intermediate_bytecode, + b_compile_intermediate::compile_to_intermediate_bytecode, c_compile_final::compile_to_low_level_bytecode, parser::parse_program, + simplify::simplify_program, }; -mod a_simplify_lang; mod b_compile_intermediate; mod c_compile_final; pub mod ir; 
mod lang; mod parser; mod precompiles; +mod simplify; pub use precompiles::PRECOMPILES; pub fn compile_program(program: &str) -> (Bytecode, BTreeMap) { diff --git a/crates/lean_compiler/src/simplify/mod.rs b/crates/lean_compiler/src/simplify/mod.rs new file mode 100644 index 00000000..41672e1e --- /dev/null +++ b/crates/lean_compiler/src/simplify/mod.rs @@ -0,0 +1,53 @@ +pub mod simplify; +pub mod transformations; +pub mod types; +pub mod unroll; +pub mod utilities; + +pub use types::{ConstMalloc, SimpleFunction, SimpleLine, SimpleProgram, VarOrConstMallocAccess}; + +use crate::lang::Program; +use std::collections::BTreeMap; +use types::{ArrayManager, Counters}; + +/// Main entry point for program simplification. +pub fn simplify_program(mut program: Program) -> SimpleProgram { + transformations::handle_inlined_functions(&mut program); + transformations::handle_const_arguments(&mut program); + let mut new_functions = BTreeMap::new(); + let mut counters = Counters::default(); + let mut const_malloc = ConstMalloc::default(); + + for (name, func) in &program.functions { + let mut array_manager = ArrayManager::default(); + let simplified_instructions = simplify::simplify_lines( + &func.body, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + let arguments = func + .arguments + .iter() + .map(|(v, is_const)| { + assert!(!is_const,); + v.clone() + }) + .collect::>(); + new_functions.insert( + name.clone(), + SimpleFunction { + name: name.clone(), + arguments, + n_returned_vars: func.n_returned_vars, + instructions: simplified_instructions, + }, + ); + const_malloc.map.clear(); + } + SimpleProgram { + functions: new_functions, + } +} diff --git a/crates/lean_compiler/src/simplify/simplify.rs b/crates/lean_compiler/src/simplify/simplify.rs new file mode 100644 index 00000000..6b6ededa --- /dev/null +++ b/crates/lean_compiler/src/simplify/simplify.rs @@ -0,0 +1,758 @@ +use super::{ + types::{ + ArrayAccessType, ArrayManager, 
ConstMalloc, Counters, SimpleFunction, SimpleLine, + VarOrConstMallocAccess, + }, + utilities::find_variable_usage, +}; +use crate::{ + ir::HighLevelOperation, + lang::{Boolean, ConstExpression, Expression, Line, SimpleExpr, Var}, +}; +use std::collections::BTreeMap; +use utils::ToUsize; + +/// Simplify a sequence of lines into SimpleLine format. +pub fn simplify_lines( + lines: &[Line], + counters: &mut Counters, + new_functions: &mut BTreeMap, + in_a_loop: bool, + array_manager: &mut ArrayManager, + const_malloc: &mut ConstMalloc, +) -> Vec { + let mut res = Vec::new(); + for line in lines { + match line { + Line::Match { value, arms } => { + let simple_value = + simplify_expr(value, &mut res, counters, array_manager, const_malloc); + let mut simple_arms = vec![]; + for (i, (pattern, statements)) in arms.iter().enumerate() { + assert_eq!( + *pattern, i, + "match patterns should be consecutive, starting from 0" + ); + simple_arms.push(simplify_lines( + statements, + counters, + new_functions, + in_a_loop, + array_manager, + const_malloc, + )); + } + res.push(SimpleLine::Match { + value: simple_value, + arms: simple_arms, + }); + } + Line::Assignment { var, value } => match value { + Expression::Value(value) => { + res.push(SimpleLine::Assignment { + var: var.clone().into(), + operation: HighLevelOperation::Add, + arg0: value.clone(), + arg1: SimpleExpr::zero(), + }); + } + Expression::ArrayAccess { array, index } => { + handle_array_assignment( + counters, + &mut res, + array.clone(), + index, + ArrayAccessType::VarIsAssigned(var.clone()), + array_manager, + const_malloc, + ); + } + Expression::Binary { + left, + operation, + right, + } => { + let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); + let right = + simplify_expr(right, &mut res, counters, array_manager, const_malloc); + res.push(SimpleLine::Assignment { + var: var.clone().into(), + operation: *operation, + arg0: left, + arg1: right, + }); + } + Expression::Log2Ceil { .. 
} => unreachable!(), + }, + Line::ArrayAssign { + array, + index, + value, + } => { + handle_array_assignment( + counters, + &mut res, + array.clone(), + index, + ArrayAccessType::ArrayIsAssigned(value.clone()), + array_manager, + const_malloc, + ); + } + Line::Assert(boolean) => match boolean { + Boolean::Different { left, right } => { + let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); + let right = + simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let diff_var = format!("@aux_var_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: HighLevelOperation::Sub, + arg0: left, + arg1: right, + }); + res.push(SimpleLine::IfNotZero { + condition: diff_var.into(), + then_branch: vec![], + else_branch: vec![SimpleLine::Panic], + }); + } + Boolean::Equal { left, right } => { + let left = simplify_expr(left, &mut res, counters, array_manager, const_malloc); + let right = + simplify_expr(right, &mut res, counters, array_manager, const_malloc); + let (var, other) = if let Ok(left) = left.clone().try_into() { + (left, right) + } else if let Ok(right) = right.clone().try_into() { + (right, left) + } else { + unreachable!("Weird: {:?}, {:?}", left, right) + }; + res.push(SimpleLine::Assignment { + var, + operation: HighLevelOperation::Add, + arg0: other, + arg1: SimpleExpr::zero(), + }); + } + }, + Line::IfCondition { + condition, + then_branch, + else_branch, + } => { + handle_if_condition( + condition, + then_branch, + else_branch, + &mut res, + counters, + new_functions, + in_a_loop, + array_manager, + const_malloc, + ); + } + Line::ForLoop { + iterator, + start, + end, + body, + rev, + unroll, + } => { + handle_for_loop( + iterator, + start, + end, + body, + *rev, + *unroll, + &mut res, + counters, + new_functions, + in_a_loop, + array_manager, + const_malloc, + ); + } + Line::FunctionCall { + function_name, + args, + return_data, + } => { + 
let simplified_args = args + .iter() + .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) + .collect::>(); + res.push(SimpleLine::FunctionCall { + function_name: function_name.clone(), + args: simplified_args, + return_data: return_data.clone(), + }); + } + Line::FunctionRet { return_data } => { + assert!( + !in_a_loop, + "Function return inside a loop is not currently supported" + ); + let simplified_return_data = return_data + .iter() + .map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc)) + .collect::>(); + res.push(SimpleLine::FunctionRet { + return_data: simplified_return_data, + }); + } + Line::Precompile { precompile, args } => { + let simplified_args = args + .iter() + .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) + .collect::>(); + res.push(SimpleLine::Precompile { + precompile: precompile.clone(), + args: simplified_args, + }); + } + Line::Print { line_info, content } => { + let simplified_content = content + .iter() + .map(|var| simplify_expr(var, &mut res, counters, array_manager, const_malloc)) + .collect::>(); + res.push(SimpleLine::Print { + line_info: line_info.clone(), + content: simplified_content, + }); + } + Line::Break => { + assert!(in_a_loop, "Break statement outside of a loop"); + res.push(SimpleLine::FunctionRet { + return_data: vec![], + }); + } + Line::MAlloc { + var, + size, + vectorized, + vectorized_len, + } => { + handle_malloc( + var, + size, + *vectorized, + vectorized_len, + &mut res, + counters, + array_manager, + const_malloc, + ); + } + Line::DecomposeBits { var, to_decompose } => { + assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); + let simplified_to_decompose = to_decompose + .iter() + .map(|expr| { + simplify_expr(expr, &mut res, counters, array_manager, const_malloc) + }) + .collect::>(); + let label = const_malloc.counter; + const_malloc.counter += 1; + const_malloc.map.insert(var.clone(), label); + 
res.push(SimpleLine::DecomposeBits { + var: var.clone(), + to_decompose: simplified_to_decompose, + label, + }); + } + Line::CounterHint { var } => { + res.push(SimpleLine::CounterHint { var: var.clone() }); + } + Line::Panic => { + res.push(SimpleLine::Panic); + } + Line::LocationReport { location } => { + res.push(SimpleLine::LocationReport { + location: *location, + }); + } + } + } + + res +} + +/// Simplify an expression into SimpleExpr format. +pub fn simplify_expr( + expr: &Expression, + lines: &mut Vec, + counters: &mut Counters, + array_manager: &mut ArrayManager, + const_malloc: &ConstMalloc, +) -> SimpleExpr { + match expr { + Expression::Value(value) => value.simplify_if_const(), + Expression::ArrayAccess { array, index } => { + if let SimpleExpr::Var(array_var) = array + && let Some(label) = const_malloc.map.get(array_var) + && let Ok(mut offset) = ConstExpression::try_from(*index.clone()) + { + offset = offset.try_naive_simplification(); + return SimpleExpr::ConstMallocAccess { + malloc_label: *label, + offset, + }; + } + + let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index] + + if !array_manager.valid.insert(aux_arr.clone()) { + return SimpleExpr::Var(aux_arr); + } + + handle_array_assignment( + counters, + lines, + array.clone(), + index, + ArrayAccessType::VarIsAssigned(aux_arr.clone()), + array_manager, + const_malloc, + ); + SimpleExpr::Var(aux_arr) + } + Expression::Binary { + left, + operation, + right, + } => { + let left_var = simplify_expr(left, lines, counters, array_manager, const_malloc); + let right_var = simplify_expr(right, lines, counters, array_manager, const_malloc); + + if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = + (&left_var, &right_var) + { + return SimpleExpr::Constant(ConstExpression::Binary { + left: Box::new(left_cst.clone()), + operation: *operation, + right: Box::new(right_cst.clone()), + }); + } + + let aux_var = format!("@aux_var_{}", 
counters.aux_vars); + counters.aux_vars += 1; + lines.push(SimpleLine::Assignment { + var: aux_var.clone().into(), + operation: *operation, + arg0: left_var, + arg1: right_var, + }); + SimpleExpr::Var(aux_var) + } + Expression::Log2Ceil { value } => { + let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc) + .as_constant() + .unwrap(); + SimpleExpr::Constant(ConstExpression::Log2Ceil { + value: Box::new(const_value), + }) + } + } +} + +fn handle_if_condition( + condition: &Boolean, + then_branch: &[Line], + else_branch: &[Line], + res: &mut Vec, + counters: &mut Counters, + new_functions: &mut BTreeMap, + in_a_loop: bool, + array_manager: &mut ArrayManager, + const_malloc: &mut ConstMalloc, +) { + // Transform if a == b then X else Y into if a != b then Y else X + let (left, right, then_branch, else_branch) = match condition { + Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched + Boolean::Different { left, right } => (left, right, then_branch, else_branch), + }; + + let left_simplified = simplify_expr(left, res, counters, array_manager, const_malloc); + let right_simplified = simplify_expr(right, res, counters, array_manager, const_malloc); + + let diff_var = format!("@diff_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: HighLevelOperation::Sub, + arg0: left_simplified, + arg1: right_simplified, + }); + + let forbidden_vars_before = const_malloc.forbidden_vars.clone(); + + let then_internal_vars = find_variable_usage(then_branch).0; + let else_internal_vars = find_variable_usage(else_branch).0; + let new_forbidden_vars = then_internal_vars + .intersection(&else_internal_vars) + .cloned() + .collect::>(); + + const_malloc.forbidden_vars.extend(new_forbidden_vars); + + let mut array_manager_then = array_manager.clone(); + let then_branch_simplified = simplify_lines( + then_branch, + counters, + new_functions, + 
in_a_loop, + &mut array_manager_then, + const_malloc, + ); + let mut array_manager_else = array_manager_then.clone(); + array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch + + let else_branch_simplified = simplify_lines( + else_branch, + counters, + new_functions, + in_a_loop, + &mut array_manager_else, + const_malloc, + ); + + const_malloc.forbidden_vars = forbidden_vars_before; + + *array_manager = array_manager_else.clone(); + // keep the intersection both branches + array_manager.valid = array_manager + .valid + .intersection(&array_manager_then.valid) + .cloned() + .collect(); + + res.push(SimpleLine::IfNotZero { + condition: diff_var.into(), + then_branch: then_branch_simplified, + else_branch: else_branch_simplified, + }); +} + +fn handle_for_loop( + iterator: &Var, + start: &Expression, + end: &Expression, + body: &[Line], + rev: bool, + unroll: bool, + res: &mut Vec, + counters: &mut Counters, + new_functions: &mut BTreeMap, + in_a_loop: bool, + array_manager: &mut ArrayManager, + const_malloc: &mut ConstMalloc, +) { + if unroll { + handle_unrolled_loop( + iterator, + start, + end, + body, + rev, + res, + counters, + new_functions, + in_a_loop, + array_manager, + const_malloc, + ); + return; + } + + if rev { + unimplemented!("Reverse for non-unrolled loops are not implemented yet"); + } + + let mut loop_const_malloc = ConstMalloc { + counter: const_malloc.counter, + ..ConstMalloc::default() + }; + let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); + array_manager.valid.clear(); + let simplified_body = simplify_lines( + body, + counters, + new_functions, + true, + array_manager, + &mut loop_const_malloc, + ); + const_malloc.counter = loop_const_malloc.counter; + array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars + + let func_name = format!("@loop_{}", counters.loops); + counters.loops += 1; + + // Find variables used inside loop but 
defined outside + let (_, mut external_vars) = find_variable_usage(body); + + // Include variables in start/end + for expr in [start, end] { + for var in crate::simplify::utilities::vars_in_expression(expr) { + external_vars.insert(var); + } + } + external_vars.remove(iterator); // Iterator is internal to loop + + let mut external_vars: Vec<_> = external_vars.into_iter().collect(); + + let start_simplified = simplify_expr(start, res, counters, array_manager, const_malloc); + let end_simplified = simplify_expr(end, res, counters, array_manager, const_malloc); + + for (simplified, original) in [ + (start_simplified.clone(), start.clone()), + (end_simplified.clone(), end.clone()), + ] { + if !matches!(original, Expression::Value(_)) { + // the simplified var is auxiliary + if let SimpleExpr::Var(var) = simplified { + external_vars.push(var); + } + } + } + + // Create function arguments: iterator + external variables + let mut func_args = vec![iterator.clone()]; + func_args.extend(external_vars.clone()); + + // Create recursive function body + let recursive_func = create_recursive_function( + func_name.clone(), + func_args, + iterator.clone(), + end_simplified, + simplified_body, + &external_vars, + ); + new_functions.insert(func_name.clone(), recursive_func); + + // Replace loop with initial function call + let mut call_args = vec![start_simplified]; + call_args.extend(external_vars.iter().map(|v| v.clone().into())); + + res.push(SimpleLine::FunctionCall { + function_name: func_name, + args: call_args, + return_data: vec![], + }); +} + +fn handle_unrolled_loop( + iterator: &Var, + start: &Expression, + end: &Expression, + body: &[Line], + rev: bool, + res: &mut Vec, + counters: &mut Counters, + new_functions: &mut BTreeMap, + in_a_loop: bool, + array_manager: &mut ArrayManager, + const_malloc: &mut ConstMalloc, +) { + let (internal_variables, _) = find_variable_usage(body); + let mut unrolled_lines = Vec::new(); + let start_evaluated = 
start.naive_eval().unwrap().to_usize(); + let end_evaluated = end.naive_eval().unwrap().to_usize(); + let unroll_index = counters.unrolls; + counters.unrolls += 1; + + let mut range = (start_evaluated..end_evaluated).collect::>(); + if rev { + range.reverse(); + } + + for i in range { + let mut body_copy = body.to_vec(); + super::unroll::replace_vars_for_unroll( + &mut body_copy, + iterator, + unroll_index, + i, + &internal_variables, + ); + unrolled_lines.extend(simplify_lines( + &body_copy, + counters, + new_functions, + in_a_loop, + array_manager, + const_malloc, + )); + } + res.extend(unrolled_lines); +} + +fn handle_malloc( + var: &Var, + size: &Expression, + vectorized: bool, + vectorized_len: &Expression, + res: &mut Vec, + counters: &mut Counters, + array_manager: &mut ArrayManager, + const_malloc: &mut ConstMalloc, +) { + let simplified_size = simplify_expr(size, res, counters, array_manager, const_malloc); + let simplified_vectorized_len = + simplify_expr(vectorized_len, res, counters, array_manager, const_malloc); + if simplified_size.is_constant() && !vectorized && const_malloc.forbidden_vars.contains(var) { + println!("TODO: Optimization missed: Requires to align const malloc in if/else branches"); + } + match simplified_size { + SimpleExpr::Constant(const_size) + if !vectorized && !const_malloc.forbidden_vars.contains(var) => + { + // TODO do this optimization even if we are in an if/else branch + let label = const_malloc.counter; + const_malloc.counter += 1; + const_malloc.map.insert(var.clone(), label); + res.push(SimpleLine::ConstMalloc { + var: var.clone(), + size: const_size, + label, + }); + } + _ => { + res.push(SimpleLine::HintMAlloc { + var: var.clone(), + size: simplified_size, + vectorized: vectorized, + vectorized_len: simplified_vectorized_len, + }); + } + } +} + +/// Handle array access assignment operations. 
+pub fn handle_array_assignment( + counters: &mut Counters, + res: &mut Vec, + array: SimpleExpr, + index: &Expression, + access_type: ArrayAccessType, + array_manager: &mut ArrayManager, + const_malloc: &ConstMalloc, +) { + let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc); + + if let SimpleExpr::Constant(offset) = simplified_index.clone() + && let SimpleExpr::Var(array_var) = &array + && let Some(label) = const_malloc.map.get(array_var) + && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { + left, + operation, + right, + }) = &access_type + { + let arg0 = simplify_expr(left, res, counters, array_manager, const_malloc); + let arg1 = simplify_expr(right, res, counters, array_manager, const_malloc); + res.push(SimpleLine::Assignment { + var: VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *label, + offset, + }, + operation: *operation, + arg0, + arg1, + }); + return; + } + + let value_simplified = match access_type { + ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Var(var), + ArrayAccessType::ArrayIsAssigned(expr) => { + simplify_expr(&expr, res, counters, array_manager, const_malloc) + } + }; + + // TODO opti: in some case we could use ConstMallocAccess + + let (index_var, shift) = match simplified_index { + SimpleExpr::Constant(c) => (array, c), + _ => { + // Create pointer variable: ptr = array + index + let ptr_var = format!("@aux_var_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: ptr_var.clone().into(), + operation: HighLevelOperation::Add, + arg0: array, + arg1: simplified_index, + }); + (SimpleExpr::Var(ptr_var), ConstExpression::zero()) + } + }; + + res.push(SimpleLine::RawAccess { + res: value_simplified, + index: index_var, + shift, + }); +} + +fn create_recursive_function( + name: String, + args: Vec, + iterator: Var, + end: SimpleExpr, + mut body: Vec, + external_vars: &[Var], +) -> SimpleFunction { + // Add iterator increment + let next_iter = 
format!("@incremented_{iterator}"); + body.push(SimpleLine::Assignment { + var: next_iter.clone().into(), + operation: HighLevelOperation::Add, + arg0: iterator.clone().into(), + arg1: SimpleExpr::one(), + }); + + // Add recursive call + let mut recursive_args: Vec = vec![next_iter.into()]; + recursive_args.extend(external_vars.iter().map(|v| v.clone().into())); + + body.push(SimpleLine::FunctionCall { + function_name: name.clone(), + args: recursive_args, + return_data: vec![], + }); + body.push(SimpleLine::FunctionRet { + return_data: vec![], + }); + + let diff_var = format!("@diff_{iterator}"); + + let instructions = vec![ + SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: HighLevelOperation::Sub, + arg0: iterator.into(), + arg1: end, + }, + SimpleLine::IfNotZero { + condition: diff_var.into(), + then_branch: body, + else_branch: vec![SimpleLine::FunctionRet { + return_data: vec![], + }], + }, + ]; + + SimpleFunction { + name, + arguments: args, + n_returned_vars: 0, + instructions, + } +} diff --git a/crates/lean_compiler/src/simplify/transformations.rs b/crates/lean_compiler/src/simplify/transformations.rs new file mode 100644 index 00000000..9347f684 --- /dev/null +++ b/crates/lean_compiler/src/simplify/transformations.rs @@ -0,0 +1,492 @@ +use super::utilities::replace_vars_by_const_in_lines; +use crate::{ + Counter, + lang::{Boolean, Expression, Function, Line, Program, SimpleExpr, Var}, +}; +use std::collections::BTreeMap; + +/// Handle inlined functions by replacing calls with function body. 
+pub fn handle_inlined_functions(program: &mut Program) { + let inlined_functions = program + .functions + .iter() + .filter(|(_, func)| func.inlined) + .map(|(name, func)| (name.clone(), func.clone())) + .collect::>(); + + for func in inlined_functions.values() { + assert!( + !func.has_const_arguments(), + "Inlined functions with constant arguments are not supported yet" + ); + } + + // Process inline functions iteratively to handle dependencies + // Repeat until all inline function calls are resolved + let mut max_iterations = 10; + while max_iterations > 0 { + let mut any_changes = false; + + // Process non-inlined functions + for func in program.functions.values_mut() { + if !func.inlined { + let mut counter1 = Counter::new(); + let mut counter2 = Counter::new(); + let old_body = func.body.clone(); + + handle_inlined_functions_helper( + &mut func.body, + &inlined_functions, + &mut counter1, + &mut counter2, + ); + + if func.body != old_body { + any_changes = true; + } + } + } + + // Process inlined functions that may call other inlined functions + // We need to update them so that when they get inlined later, they don't have unresolved calls + for func in program.functions.values_mut() { + if func.inlined { + let mut counter1 = Counter::new(); + let mut counter2 = Counter::new(); + let old_body = func.body.clone(); + + handle_inlined_functions_helper( + &mut func.body, + &inlined_functions, + &mut counter1, + &mut counter2, + ); + + if func.body != old_body { + any_changes = true; + } + } + } + + if !any_changes { + break; + } + + max_iterations -= 1; + } + + assert!( + max_iterations > 0, + "Too many iterations processing inline functions" + ); + + // Remove all inlined functions from the program (they've been inlined) + for func_name in inlined_functions.keys() { + program.functions.remove(func_name); + } +} + +fn handle_inlined_functions_helper( + lines: &mut Vec, + inlined_functions: &BTreeMap, + inlined_var_counter: &mut Counter, + total_inlined_counter: 
&mut Counter, +) { + for i in (0..lines.len()).rev() { + match &mut lines[i] { + Line::FunctionCall { + function_name, + args, + return_data, + } => { + if let Some(func) = inlined_functions.get(&*function_name) { + let mut inlined_lines = vec![]; + + let mut simplified_args = vec![]; + for arg in args { + if let Expression::Value(simple_expr) = arg { + simplified_args.push(simple_expr.clone()); + } else { + let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); + inlined_lines.push(Line::Assignment { + var: aux_var.clone(), + value: arg.clone(), + }); + simplified_args.push(SimpleExpr::Var(aux_var)); + } + } + assert_eq!(simplified_args.len(), func.arguments.len()); + let inlined_args = func + .arguments + .iter() + .zip(&simplified_args) + .map(|((var, _), expr)| (var.clone(), expr.clone())) + .collect::>(); + let mut func_body = func.body.clone(); + inline_lines( + &mut func_body, + &inlined_args, + return_data, + total_inlined_counter.next(), + ); + inlined_lines.extend(func_body); + + lines.remove(i); // remove the call to the inlined function + lines.splice(i..i, inlined_lines); + } + } + Line::IfCondition { + then_branch, + else_branch, + .. + } => { + handle_inlined_functions_helper( + then_branch, + inlined_functions, + inlined_var_counter, + total_inlined_counter, + ); + handle_inlined_functions_helper( + else_branch, + inlined_functions, + inlined_var_counter, + total_inlined_counter, + ); + } + Line::ForLoop { + body, unroll: _, .. + } => { + handle_inlined_functions_helper( + body, + inlined_functions, + inlined_var_counter, + total_inlined_counter, + ); + } + Line::Match { arms, .. } => { + for (_, arm) in arms { + handle_inlined_functions_helper( + arm, + inlined_functions, + inlined_var_counter, + total_inlined_counter, + ); + } + } + _ => {} + } + } +} + +/// Handle functions with constant arguments by creating specialized versions. 
+pub fn handle_const_arguments(program: &mut Program) { + let mut new_functions = BTreeMap::::new(); + let constant_functions = program + .functions + .iter() + .filter(|(_, func)| func.has_const_arguments()) + .map(|(name, func)| (name.clone(), func.clone())) + .collect::>(); + + // First pass: process non-const functions that call const functions + for func in program.functions.values_mut() { + if !func.has_const_arguments() { + handle_const_arguments_helper(&mut func.body, &constant_functions, &mut new_functions); + } + } + + // Process newly created const functions recursively until no more changes + let mut changed = true; + let mut const_depth = 0; + while changed { + changed = false; + const_depth += 1; + assert!(const_depth < 100, "Too many levels of constant arguments"); + let mut additional_functions = BTreeMap::new(); + + // Collect all function names to process + let function_names: Vec = new_functions.keys().cloned().collect(); + + for name in function_names { + if let Some(func) = new_functions.get_mut(&name) { + let initial_count = additional_functions.len(); + handle_const_arguments_helper( + &mut func.body, + &constant_functions, + &mut additional_functions, + ); + if additional_functions.len() > initial_count { + changed = true; + } + } + } + + // Add any newly discovered functions + for (name, func) in additional_functions { + if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) { + e.insert(func); + changed = true; + } + } + } + + for (name, func) in new_functions { + assert!(!program.functions.contains_key(&name),); + program.functions.insert(name, func); + } + for const_func in constant_functions.keys() { + program.functions.remove(const_func); + } +} + +fn handle_const_arguments_helper( + lines: &mut [Line], + constant_functions: &BTreeMap, + new_functions: &mut BTreeMap, +) { + for line in lines { + match line { + Line::FunctionCall { + function_name, + args, + return_data: _, + } => { + if let Some(func) = 
constant_functions.get(function_name) { + // If the function has constant arguments, we need to handle them + let mut const_evals = Vec::new(); + for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { + if *is_constant { + let const_eval = arg_expr.naive_eval().unwrap_or_else(|| { + panic!("Failed to evaluate constant argument: {arg_expr}") + }); + const_evals.push((arg_var.clone(), const_eval)); + } + } + let const_funct_name = format!( + "{function_name}_{}", + const_evals + .iter() + .map(|(arg_var, const_eval)| { format!("{arg_var}={const_eval}") }) + .collect::>() + .join("_") + ); + + *function_name = const_funct_name.clone(); // change the name of the function called + // ... and remove constant arguments + *args = args + .iter() + .zip(&func.arguments) + .filter(|(_, (_, is_constant))| !is_constant) + .filter(|(_, (_, is_const))| !is_const) + .map(|(arg_expr, _)| arg_expr.clone()) + .collect(); + + if new_functions.contains_key(&const_funct_name) { + continue; + } + + let mut new_body = func.body.clone(); + replace_vars_by_const_in_lines( + &mut new_body, + &const_evals.iter().cloned().collect(), + ); + new_functions.insert( + const_funct_name.clone(), + Function { + name: const_funct_name, + arguments: func + .arguments + .iter() + .filter(|(_, is_const)| !is_const) + .cloned() + .collect(), + inlined: false, + body: new_body, + n_returned_vars: func.n_returned_vars, + }, + ); + } + } + Line::IfCondition { + then_branch, + else_branch, + .. + } => { + handle_const_arguments_helper(then_branch, constant_functions, new_functions); + handle_const_arguments_helper(else_branch, constant_functions, new_functions); + } + Line::ForLoop { + body, unroll: _, .. + } => { + // TODO we should unroll before const arguments handling + handle_const_arguments_helper(body, constant_functions, new_functions); + } + _ => {} + } + } +} + +/// Inline function bodies at call sites. 
+pub fn inline_lines( + lines: &mut Vec, + args: &BTreeMap, + res: &[Var], + inlining_count: usize, +) { + let inline_condition = |condition: &mut Boolean| { + let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; + inline_expr(left, args, inlining_count); + inline_expr(right, args, inlining_count); + }; + + let inline_internal_var = |var: &mut Var| { + assert!( + !args.contains_key(var), + "Variable {var} is both an argument and assigned in the inlined function" + ); + *var = format!("@inlined_var_{inlining_count}_{var}"); + }; + + let mut lines_to_replace = vec![]; + for (i, line) in lines.iter_mut().enumerate() { + match line { + Line::Match { value, arms } => { + inline_expr(value, args, inlining_count); + for (_, statements) in arms { + inline_lines(statements, args, res, inlining_count); + } + } + Line::Assignment { var, value } => { + inline_expr(value, args, inlining_count); + inline_internal_var(var); + } + Line::IfCondition { + condition, + then_branch, + else_branch, + } => { + inline_condition(condition); + + inline_lines(then_branch, args, res, inlining_count); + inline_lines(else_branch, args, res, inlining_count); + } + Line::FunctionCall { + args: func_args, + return_data, + .. + } => { + for arg in func_args { + inline_expr(arg, args, inlining_count); + } + for return_var in return_data { + inline_internal_var(return_var); + } + } + Line::Assert(condition) => { + inline_condition(condition); + } + Line::FunctionRet { return_data } => { + assert_eq!(return_data.len(), res.len()); + + for expr in return_data.iter_mut() { + inline_expr(expr, args, inlining_count); + } + lines_to_replace.push(( + i, + res.iter() + .zip(return_data) + .map(|(res_var, expr)| Line::Assignment { + var: res_var.clone(), + value: expr.clone(), + }) + .collect::>(), + )); + } + Line::MAlloc { var, size, .. 
} => { + inline_expr(size, args, inlining_count); + inline_internal_var(var); + } + Line::Precompile { + precompile: _, + args: precompile_args, + } => { + for arg in precompile_args { + inline_expr(arg, args, inlining_count); + } + } + Line::Print { content, .. } => { + for var in content { + inline_expr(var, args, inlining_count); + } + } + Line::DecomposeBits { var, to_decompose } => { + for expr in to_decompose { + inline_expr(expr, args, inlining_count); + } + inline_internal_var(var); + } + Line::CounterHint { var } => { + inline_internal_var(var); + } + Line::ForLoop { + iterator, + start, + end, + body, + rev: _, + unroll: _, + } => { + inline_lines(body, args, res, inlining_count); + inline_internal_var(iterator); + inline_expr(start, args, inlining_count); + inline_expr(end, args, inlining_count); + } + Line::ArrayAssign { + array, + index, + value, + } => { + inline_simple_expr(array, args, inlining_count); + inline_expr(index, args, inlining_count); + inline_expr(value, args, inlining_count); + } + Line::Panic | Line::Break | Line::LocationReport { .. } => {} + } + } + for (i, new_lines) in lines_to_replace.into_iter().rev() { + lines.splice(i..=i, new_lines); + } +} + +fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining_count: usize) { + match expr { + Expression::Value(value) => { + inline_simple_expr(value, args, inlining_count); + } + Expression::ArrayAccess { array, index } => { + inline_simple_expr(array, args, inlining_count); + inline_expr(index, args, inlining_count); + } + Expression::Binary { left, right, .. 
} => { + inline_expr(left, args, inlining_count); + inline_expr(right, args, inlining_count); + } + Expression::Log2Ceil { value } => { + inline_expr(value, args, inlining_count); + } + } +} + +fn inline_simple_expr( + simple_expr: &mut SimpleExpr, + args: &BTreeMap, + inlining_count: usize, +) { + if let SimpleExpr::Var(var) = simple_expr { + if let Some(replacement) = args.get(var) { + *simple_expr = replacement.clone(); + } else { + *var = format!("@inlined_var_{inlining_count}_{var}"); + } + } +} diff --git a/crates/lean_compiler/src/simplify/types.rs b/crates/lean_compiler/src/simplify/types.rs new file mode 100644 index 00000000..4f616f13 --- /dev/null +++ b/crates/lean_compiler/src/simplify/types.rs @@ -0,0 +1,861 @@ +use crate::{ + ir::HighLevelOperation, + lang::{ConstExpression, ConstMallocLabel, SimpleExpr, Var}, + precompiles::Precompile, +}; +use lean_vm::SourceLineNumber; +use std::{ + collections::{BTreeMap, BTreeSet}, + fmt::{Display, Formatter}, +}; + +/// Simplified program representation after language simplification. +#[derive(Debug, Clone)] +pub struct SimpleProgram { + pub functions: BTreeMap, +} + +/// Simplified function representation. +#[derive(Debug, Clone)] +pub struct SimpleFunction { + pub name: String, + pub arguments: Vec, + pub n_returned_vars: usize, + pub instructions: Vec, +} + +/// Variable or constant malloc access for assignments. 
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum VarOrConstMallocAccess { + Var(Var), + ConstMallocAccess { + malloc_label: ConstMallocLabel, + offset: ConstExpression, + }, +} + +impl From for SimpleExpr { + fn from(var_or_const: VarOrConstMallocAccess) -> Self { + match var_or_const { + VarOrConstMallocAccess::Var(var) => Self::Var(var), + VarOrConstMallocAccess::ConstMallocAccess { + malloc_label, + offset, + } => Self::ConstMallocAccess { + malloc_label, + offset, + }, + } + } +} + +impl TryInto for SimpleExpr { + type Error = (); + + fn try_into(self) -> Result { + match self { + Self::Var(var) => Ok(VarOrConstMallocAccess::Var(var)), + Self::ConstMallocAccess { + malloc_label, + offset, + } => Ok(VarOrConstMallocAccess::ConstMallocAccess { + malloc_label, + offset, + }), + _ => Err(()), + } + } +} + +impl From for VarOrConstMallocAccess { + fn from(var: Var) -> Self { + Self::Var(var) + } +} + +/// Simplified language instruction representation. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SimpleLine { + Match { + value: SimpleExpr, + arms: Vec>, // patterns = 0, 1, ... 
+ }, + Assignment { + var: VarOrConstMallocAccess, + operation: HighLevelOperation, + arg0: SimpleExpr, + arg1: SimpleExpr, + }, + RawAccess { + res: SimpleExpr, + index: SimpleExpr, + shift: ConstExpression, + }, // res = memory[index + shift] + IfNotZero { + condition: SimpleExpr, + then_branch: Vec, + else_branch: Vec, + }, + FunctionCall { + function_name: String, + args: Vec, + return_data: Vec, + }, + FunctionRet { + return_data: Vec, + }, + Precompile { + precompile: Precompile, + args: Vec, + }, + Panic, + // Hints + DecomposeBits { + var: Var, // a pointer to 31 * len(to_decompose) field elements, containing the bits of "to_decompose" + to_decompose: Vec, + label: ConstMallocLabel, + }, + CounterHint { + var: Var, + }, + Print { + line_info: String, + content: Vec, + }, + HintMAlloc { + var: Var, + size: SimpleExpr, + vectorized: bool, + vectorized_len: SimpleExpr, + }, + ConstMalloc { + // always not vectorized + var: Var, + size: ConstExpression, + label: ConstMallocLabel, + }, + // noop, debug purpose only + LocationReport { + location: SourceLineNumber, + }, +} + +/// Helper enum for array access operations. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ArrayAccessType { + VarIsAssigned(Var), // var = array[index] + ArrayIsAssigned(crate::lang::Expression), // array[index] = expr +} + +/// Internal state for various counters during simplification. +#[derive(Debug, Clone, Default)] +pub(crate) struct Counters { + pub(crate) aux_vars: usize, + pub(crate) loops: usize, + pub(crate) unrolls: usize, +} + +/// Array access management for optimization. 
+#[derive(Debug, Clone, Default)] +pub(crate) struct ArrayManager { + pub(crate) counter: usize, + pub(crate) aux_vars: BTreeMap<(SimpleExpr, crate::lang::Expression), Var>, // (array, index) -> aux_var + pub(crate) valid: BTreeSet, // currently valid aux vars +} + +impl ArrayManager { + pub(crate) fn get_aux_var( + &mut self, + array: &SimpleExpr, + index: &crate::lang::Expression, + ) -> Var { + if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { + return var.clone(); + } + let new_var = format!("@arr_aux_{}", self.counter); + self.counter += 1; + self.aux_vars + .insert((array.clone(), index.clone()), new_var.clone()); + new_var + } +} + +/// Constant malloc optimization state. +#[derive(Debug, Clone, Default)] +pub struct ConstMalloc { + pub(crate) counter: usize, + pub(crate) map: BTreeMap, + pub(crate) forbidden_vars: BTreeSet, // vars shared between branches of an if/else +} + +// Display implementations for types +impl Display for VarOrConstMallocAccess { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Var(var) => write!(f, "{var}"), + Self::ConstMallocAccess { + malloc_label, + offset, + } => { + write!(f, "ConstMallocAccess({malloc_label}, {offset})") + } + } + } +} + +impl Display for SimpleLine { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_string_with_indent(0)) + } +} + +impl SimpleLine { + fn to_string_with_indent(&self, indent: usize) -> String { + let spaces = " ".repeat(indent); + let line_str = match self { + Self::Match { value, arms } => { + let arms_str = arms + .iter() + .enumerate() + .map(|(pattern, stmt)| { + format!( + "{} => {}", + pattern, + stmt.iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n") + ) + }) + .collect::>() + .join(", "); + + format!("match {value} {{\n{arms_str}\n{spaces}}}") + } + Self::Assignment { + var, + operation, + arg0, + arg1, + } => { + format!("{var} = {arg0} {operation} 
{arg1}") + } + Self::DecomposeBits { + var: result, + to_decompose, + label: _, + } => { + format!( + "{} = decompose_bits({})", + result, + to_decompose + .iter() + .map(|expr| format!("{expr}")) + .collect::>() + .join(", ") + ) + } + Self::CounterHint { var: result } => { + format!("{result} = counter_hint()") + } + Self::RawAccess { res, index, shift } => { + format!("memory[{index} + {shift}] = {res}") + } + Self::IfNotZero { + condition, + then_branch, + else_branch, + } => { + let then_str = then_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + let else_str = else_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + if else_branch.is_empty() { + format!("if {condition} != 0 {{\n{then_str}\n{spaces}}}") + } else { + format!( + "if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" + ) + } + } + Self::FunctionCall { + function_name, + args, + return_data, + } => { + let args_str = args + .iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", "); + let return_data_str = return_data + .iter() + .map(|var| var.to_string()) + .collect::>() + .join(", "); + + if return_data.is_empty() { + format!("{function_name}({args_str})") + } else { + format!("{return_data_str} = {function_name}({args_str})") + } + } + Self::FunctionRet { return_data } => { + let return_data_str = return_data + .iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", "); + format!("return {return_data_str}") + } + Self::Precompile { precompile, args } => { + format!( + "{}({})", + &precompile.name, + args.iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", ") + ) + } + Self::Print { + line_info: _, + content, + } => { + let content_str = content + .iter() + .map(|c| format!("{c}")) + .collect::>() + .join(", "); + format!("print({content_str})") + } + Self::HintMAlloc { + var, + size, + vectorized, + vectorized_len, + } => { + 
if *vectorized { + format!("{var} = malloc_vec({size}, {vectorized_len})") + } else { + format!("{var} = malloc({size})") + } + } + Self::ConstMalloc { + var, + size, + label: _, + } => { + format!("{var} = malloc({size})") + } + Self::Panic => "panic".to_string(), + Self::LocationReport { .. } => Default::default(), + }; + format!("{spaces}{line_str}") + } +} + +impl Display for SimpleFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let args_str = self + .arguments + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", "); + + let instructions_str = self + .instructions + .iter() + .map(|line| line.to_string_with_indent(1)) + .collect::>() + .join("\n"); + + if self.instructions.is_empty() { + write!( + f, + "fn {}({}) -> {} {{}}", + self.name, args_str, self.n_returned_vars + ) + } else { + write!( + f, + "fn {}({}) -> {} {{\n{}\n}}", + self.name, args_str, self.n_returned_vars, instructions_str + ) + } + } +} + +impl Display for SimpleProgram { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for function in self.functions.values() { + if !first { + writeln!(f)?; + } + write!(f, "{function}")?; + first = false; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{ConstExpression, SimpleExpr}; + use crate::precompiles::PRECOMPILES; + + #[test] + fn test_var_or_const_malloc_access_from_var() { + let var = "test_var".to_string(); + let access = VarOrConstMallocAccess::from(var.clone()); + + assert_eq!(access, VarOrConstMallocAccess::Var(var)); + } + + #[test] + fn test_var_or_const_malloc_access_from_to_simple_expr() { + let var = "test_var".to_string(); + let access = VarOrConstMallocAccess::Var(var.clone()); + let simple_expr: SimpleExpr = access.into(); + + assert_eq!(simple_expr, SimpleExpr::Var(var)); + } + + #[test] + fn test_var_or_const_malloc_access_const_malloc_conversion() { + let label = 42; + let offset = ConstExpression::from(100); + let access = 
VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: label, + offset: offset.clone(), + }; + let simple_expr: SimpleExpr = access.into(); + + assert_eq!( + simple_expr, + SimpleExpr::ConstMallocAccess { + malloc_label: label, + offset, + } + ); + } + + #[test] + fn test_simple_expr_try_into_var_or_const_malloc_access_var() { + let var = "test_var".to_string(); + let simple_expr = SimpleExpr::Var(var.clone()); + let result: Result = simple_expr.try_into(); + + assert_eq!(result, Ok(VarOrConstMallocAccess::Var(var))); + } + + #[test] + fn test_simple_expr_try_into_var_or_const_malloc_access_const_malloc() { + let label = 42; + let offset = ConstExpression::from(100); + let simple_expr = SimpleExpr::ConstMallocAccess { + malloc_label: label, + offset: offset.clone(), + }; + let result: Result = simple_expr.try_into(); + + assert_eq!( + result, + Ok(VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: label, + offset, + }) + ); + } + + #[test] + fn test_simple_expr_try_into_var_or_const_malloc_access_failure() { + let simple_expr = SimpleExpr::Constant(ConstExpression::from(42)); + let result: Result = simple_expr.try_into(); + + assert_eq!(result, Err(())); + } + + #[test] + fn test_array_manager_get_aux_var_first_time() { + let mut manager = ArrayManager::default(); + let array = SimpleExpr::Var("test_array".to_string()); + let index = crate::lang::Expression::scalar(10); + + let var = manager.get_aux_var(&array, &index); + + assert_eq!(var, "@arr_aux_0"); + assert_eq!(manager.counter, 1); + assert!( + manager + .aux_vars + .contains_key(&(array.clone(), index.clone())) + ); + } + + #[test] + fn test_array_manager_get_aux_var_repeated_access() { + let mut manager = ArrayManager::default(); + let array = SimpleExpr::Var("test_array".to_string()); + let index = crate::lang::Expression::scalar(10); + + let var1 = manager.get_aux_var(&array, &index); + let var2 = manager.get_aux_var(&array, &index); + + assert_eq!(var1, var2); + assert_eq!(var1, 
"@arr_aux_0"); + assert_eq!(manager.counter, 1); // Should not increment for repeated access + } + + #[test] + fn test_array_manager_different_array_index_pairs() { + let mut manager = ArrayManager::default(); + let array1 = SimpleExpr::Var("array1".to_string()); + let array2 = SimpleExpr::Var("array2".to_string()); + let index1 = crate::lang::Expression::scalar(10); + let index2 = crate::lang::Expression::scalar(20); + + let var1 = manager.get_aux_var(&array1, &index1); + let var2 = manager.get_aux_var(&array2, &index2); + let var3 = manager.get_aux_var(&array1, &index2); + + assert_eq!(var1, "@arr_aux_0"); + assert_eq!(var2, "@arr_aux_1"); + assert_eq!(var3, "@arr_aux_2"); + assert_eq!(manager.counter, 3); + } + + #[test] + fn test_counters_default() { + let counters = Counters::default(); + + assert_eq!(counters.aux_vars, 0); + assert_eq!(counters.loops, 0); + assert_eq!(counters.unrolls, 0); + } + + #[test] + fn test_const_malloc_default() { + let const_malloc = ConstMalloc::default(); + + assert_eq!(const_malloc.counter, 0); + assert!(const_malloc.map.is_empty()); + assert!(const_malloc.forbidden_vars.is_empty()); + } + + #[test] + fn test_var_or_const_malloc_access_display_var() { + let access = VarOrConstMallocAccess::Var("test_var".to_string()); + + assert_eq!(format!("{}", access), "test_var"); + } + + #[test] + fn test_var_or_const_malloc_access_display_const_malloc() { + let access = VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: 42, + offset: ConstExpression::from(100), + }; + + assert_eq!(format!("{}", access), "ConstMallocAccess(42, 100)"); + } + + #[test] + fn test_simple_line_display_assignment() { + let line = SimpleLine::Assignment { + var: VarOrConstMallocAccess::Var("x".to_string()), + operation: crate::ir::HighLevelOperation::Add, + arg0: SimpleExpr::Var("a".to_string()), + arg1: SimpleExpr::Var("b".to_string()), + }; + + assert_eq!(format!("{}", line), "x = a + b"); + } + + #[test] + fn test_simple_line_display_panic() { + let 
line = SimpleLine::Panic; + + assert_eq!(format!("{}", line), "panic"); + } + + #[test] + fn test_simple_line_display_function_call_no_return() { + let line = SimpleLine::FunctionCall { + function_name: "test_func".to_string(), + args: vec![SimpleExpr::scalar(42), SimpleExpr::Var("x".to_string())], + return_data: vec![], + }; + + assert_eq!(format!("{}", line), "test_func(42, x)"); + } + + #[test] + fn test_simple_line_display_function_call_with_return() { + let line = SimpleLine::FunctionCall { + function_name: "test_func".to_string(), + args: vec![SimpleExpr::scalar(42)], + return_data: vec!["result".to_string()], + }; + + assert_eq!(format!("{}", line), "result = test_func(42)"); + } + + #[test] + fn test_simple_line_display_function_ret() { + let line = SimpleLine::FunctionRet { + return_data: vec![SimpleExpr::scalar(42), SimpleExpr::Var("x".to_string())], + }; + + assert_eq!(format!("{}", line), "return 42, x"); + } + + #[test] + fn test_simple_line_display_if_not_zero_no_else() { + let line = SimpleLine::IfNotZero { + condition: SimpleExpr::Var("x".to_string()), + then_branch: vec![SimpleLine::Panic], + else_branch: vec![], + }; + + let expected = "if x != 0 {\n panic\n}"; + assert_eq!(format!("{}", line), expected); + } + + #[test] + fn test_simple_line_display_if_not_zero_with_else() { + let line = SimpleLine::IfNotZero { + condition: SimpleExpr::Var("x".to_string()), + then_branch: vec![SimpleLine::Panic], + else_branch: vec![SimpleLine::FunctionRet { + return_data: vec![], + }], + }; + + let expected = "if x != 0 {\n panic\n} else {\n return \n}"; + assert_eq!(format!("{}", line), expected); + } + + #[test] + fn test_simple_line_display_counter_hint() { + let line = SimpleLine::CounterHint { + var: "counter".to_string(), + }; + + assert_eq!(format!("{}", line), "counter = counter_hint()"); + } + + #[test] + fn test_simple_line_display_decompose_bits() { + let line = SimpleLine::DecomposeBits { + var: "bits".to_string(), + to_decompose: 
vec![SimpleExpr::scalar(255), SimpleExpr::Var("x".to_string())], + label: 42, + }; + + assert_eq!(format!("{}", line), "bits = decompose_bits(255, x)"); + } + + #[test] + fn test_simple_line_display_raw_access() { + let line = SimpleLine::RawAccess { + res: SimpleExpr::Var("result".to_string()), + index: SimpleExpr::Var("ptr".to_string()), + shift: ConstExpression::from(10), + }; + + assert_eq!(format!("{}", line), "memory[ptr + 10] = result"); + } + + #[test] + fn test_simple_line_display_print() { + let line = SimpleLine::Print { + line_info: "debug".to_string(), + content: vec![SimpleExpr::scalar(42), SimpleExpr::Var("x".to_string())], + }; + + assert_eq!(format!("{}", line), "print(42, x)"); + } + + #[test] + fn test_simple_line_display_hint_malloc_non_vectorized() { + let line = SimpleLine::HintMAlloc { + var: "ptr".to_string(), + size: SimpleExpr::scalar(100), + vectorized: false, + vectorized_len: SimpleExpr::zero(), + }; + + assert_eq!(format!("{}", line), "ptr = malloc(100)"); + } + + #[test] + fn test_simple_line_display_hint_malloc_vectorized() { + let line = SimpleLine::HintMAlloc { + var: "ptr".to_string(), + size: SimpleExpr::scalar(100), + vectorized: true, + vectorized_len: SimpleExpr::scalar(8), + }; + + assert_eq!(format!("{}", line), "ptr = malloc_vec(100, 8)"); + } + + #[test] + fn test_simple_line_display_const_malloc() { + let line = SimpleLine::ConstMalloc { + var: "ptr".to_string(), + size: ConstExpression::from(100), + label: 42, + }; + + assert_eq!(format!("{}", line), "ptr = malloc(100)"); + } + + #[test] + fn test_simple_line_display_precompile() { + let precompile = PRECOMPILES[0].clone(); + let line = SimpleLine::Precompile { + precompile, + args: vec![SimpleExpr::scalar(42)], + }; + + assert_eq!(format!("{}", line), format!("{}(42)", PRECOMPILES[0].name)); + } + + #[test] + fn test_simple_line_display_location_report() { + let line = SimpleLine::LocationReport { location: 42 }; + + assert_eq!(format!("{}", line), ""); // 
LocationReport displays as empty + } + + #[test] + fn test_simple_line_display_match() { + let line = SimpleLine::Match { + value: SimpleExpr::Var("x".to_string()), + arms: vec![ + vec![SimpleLine::Panic], + vec![SimpleLine::FunctionRet { + return_data: vec![], + }], + ], + }; + + let expected = "match x {\n0 => panic, 1 => return \n}"; + assert_eq!(format!("{}", line), expected); + } + + #[test] + fn test_simple_function_display_empty() { + let function = SimpleFunction { + name: "test".to_string(), + arguments: vec!["x".to_string(), "y".to_string()], + n_returned_vars: 1, + instructions: vec![], + }; + + assert_eq!(format!("{}", function), "fn test(x, y) -> 1 {}"); + } + + #[test] + fn test_simple_function_display_with_body() { + let function = SimpleFunction { + name: "test".to_string(), + arguments: vec!["x".to_string()], + n_returned_vars: 1, + instructions: vec![SimpleLine::Panic], + }; + + assert_eq!(format!("{}", function), "fn test(x) -> 1 {\n panic\n}"); + } + + #[test] + fn test_simple_program_display_empty() { + let program = SimpleProgram { + functions: BTreeMap::new(), + }; + + assert_eq!(format!("{}", program), ""); + } + + #[test] + fn test_simple_program_display_single_function() { + let mut functions = BTreeMap::new(); + functions.insert( + "test".to_string(), + SimpleFunction { + name: "test".to_string(), + arguments: vec![], + n_returned_vars: 0, + instructions: vec![SimpleLine::Panic], + }, + ); + + let program = SimpleProgram { functions }; + + assert_eq!(format!("{}", program), "fn test() -> 0 {\n panic\n}"); + } + + #[test] + fn test_simple_program_display_multiple_functions() { + let mut functions = BTreeMap::new(); + functions.insert( + "func1".to_string(), + SimpleFunction { + name: "func1".to_string(), + arguments: vec![], + n_returned_vars: 0, + instructions: vec![], + }, + ); + functions.insert( + "func2".to_string(), + SimpleFunction { + name: "func2".to_string(), + arguments: vec![], + n_returned_vars: 0, + instructions: vec![], + }, 
+ ); + + let program = SimpleProgram { functions }; + + assert_eq!( + format!("{}", program), + "fn func1() -> 0 {}\nfn func2() -> 0 {}" + ); + } + + #[test] + fn test_array_access_type_var_assigned() { + let access_type = ArrayAccessType::VarIsAssigned("result".to_string()); + + match access_type { + ArrayAccessType::VarIsAssigned(var) => assert_eq!(var, "result"), + _ => panic!("Expected VarIsAssigned"), + } + } + + #[test] + fn test_array_access_type_array_assigned() { + let expr = crate::lang::Expression::scalar(42); + let access_type = ArrayAccessType::ArrayIsAssigned(expr.clone()); + + match access_type { + ArrayAccessType::ArrayIsAssigned(e) => assert_eq!(e, expr), + _ => panic!("Expected ArrayIsAssigned"), + } + } +} diff --git a/crates/lean_compiler/src/simplify/unroll.rs b/crates/lean_compiler/src/simplify/unroll.rs new file mode 100644 index 00000000..3a3c74e0 --- /dev/null +++ b/crates/lean_compiler/src/simplify/unroll.rs @@ -0,0 +1,784 @@ +use crate::lang::{Boolean, ConstExpression, Expression, Line, SimpleExpr, Var}; +use std::collections::BTreeSet; + +/// Replace variables for unrolling in an expression. +pub fn replace_vars_for_unroll_in_expr( + expr: &mut Expression, + iterator: &Var, + unroll_index: usize, + iterator_value: usize, + internal_vars: &BTreeSet, +) { + match expr { + Expression::Value(value_expr) => match value_expr { + SimpleExpr::Var(var) => { + if var == iterator { + *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); + } else if internal_vars.contains(var) { + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + } + } + SimpleExpr::Constant(_) | SimpleExpr::ConstMallocAccess { .. 
} => {} + }, + Expression::ArrayAccess { array, index } => { + if let SimpleExpr::Var(array_var) = array { + assert!(array_var != iterator, "Weird"); + if internal_vars.contains(array_var) { + *array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); + } + } + + replace_vars_for_unroll_in_expr( + index, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Expression::Binary { left, right, .. } => { + replace_vars_for_unroll_in_expr( + left, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + right, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Expression::Log2Ceil { value } => { + replace_vars_for_unroll_in_expr( + value, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } +} + +/// Replace variables for unrolling in a line sequence. +pub fn replace_vars_for_unroll( + lines: &mut [Line], + iterator: &Var, + unroll_index: usize, + iterator_value: usize, + internal_vars: &BTreeSet, +) { + for line in lines { + match line { + Line::Match { value, arms } => { + replace_vars_for_unroll_in_expr( + value, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + for (_, statements) in arms { + replace_vars_for_unroll( + statements, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + Line::Assignment { var, value } => { + assert!(var != iterator, "Weird"); + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + replace_vars_for_unroll_in_expr( + value, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::ArrayAssign { + // array[index] = value + array, + index, + value, + } => { + if let SimpleExpr::Var(array_var) = array { + assert!(array_var != iterator, "Weird"); + if internal_vars.contains(array_var) { + *array_var = + format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}"); + } + } + replace_vars_for_unroll_in_expr( + index, + iterator, 
+ unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + value, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::Assert(Boolean::Equal { left, right } | Boolean::Different { left, right }) => { + replace_vars_for_unroll_in_expr( + left, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + right, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::IfCondition { + condition: Boolean::Equal { left, right } | Boolean::Different { left, right }, + then_branch, + else_branch, + } => { + replace_vars_for_unroll_in_expr( + left, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + right, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll( + then_branch, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll( + else_branch, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::ForLoop { + iterator: other_iterator, + start, + end, + body, + rev: _, + unroll: _, + } => { + assert!(other_iterator != iterator); + *other_iterator = + format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); + replace_vars_for_unroll_in_expr( + start, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + end, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll( + body, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::FunctionCall { + function_name: _, + args, + return_data, + } => { + // Function calls are not unrolled, so we don't need to change them + for arg in args { + replace_vars_for_unroll_in_expr( + arg, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + for ret in return_data { + *ret = 
format!("@unrolled_{unroll_index}_{iterator_value}_{ret}"); + } + } + Line::FunctionRet { return_data } => { + for ret in return_data { + replace_vars_for_unroll_in_expr( + ret, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + Line::Precompile { + precompile: _, + args, + } => { + for arg in args { + replace_vars_for_unroll_in_expr( + arg, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + Line::Print { line_info, content } => { + // Print statements are not unrolled, so we don't need to change them + *line_info += &format!(" (unrolled {unroll_index} {iterator_value})"); + for var in content { + replace_vars_for_unroll_in_expr( + var, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + Line::MAlloc { + var, + size, + vectorized: _, + vectorized_len, + } => { + assert!(var != iterator, "Weird"); + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + replace_vars_for_unroll_in_expr( + size, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + vectorized_len, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Line::DecomposeBits { var, to_decompose } => { + assert!(var != iterator, "Weird"); + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + for expr in to_decompose { + replace_vars_for_unroll_in_expr( + expr, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + Line::CounterHint { var } => { + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + } + Line::Break | Line::Panic | Line::LocationReport { .. 
} => {} + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_replace_vars_for_unroll_in_expr_value_iterator_replacement() { + let mut expr = Expression::Value(SimpleExpr::Var("i".to_string())); + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 0, 5, &internal_vars); + + assert_eq!(expr, Expression::scalar(5)); + } + + #[test] + fn test_replace_vars_for_unroll_in_expr_value_internal_var() { + let mut expr = Expression::Value(SimpleExpr::Var("x".to_string())); + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("x".to_string()); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 1, 3, &internal_vars); + + assert_eq!( + expr, + Expression::Value(SimpleExpr::Var("@unrolled_1_3_x".to_string())) + ); + } + + #[test] + fn test_replace_vars_for_unroll_in_expr_value_external_var() { + let mut expr = Expression::Value(SimpleExpr::Var("y".to_string())); + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 0, 2, &internal_vars); + + // External variables should not be modified + assert_eq!(expr, Expression::Value(SimpleExpr::Var("y".to_string()))); + } + + #[test] + fn test_replace_vars_for_unroll_in_expr_array_access() { + let mut expr = Expression::ArrayAccess { + array: SimpleExpr::Var("arr".to_string()), + index: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))), + }; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("arr".to_string()); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 2, 7, &internal_vars); + + assert_eq!( + expr, + Expression::ArrayAccess { + array: SimpleExpr::Var("@unrolled_2_7_arr".to_string()), + index: Box::new(Expression::scalar(7)), + } + ); + } + + #[test] + fn test_replace_vars_for_unroll_in_expr_binary() { + let mut expr = Expression::Binary 
{ + left: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))), + }; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("x".to_string()); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 0, 10, &internal_vars); + + assert_eq!( + expr, + Expression::Binary { + left: Box::new(Expression::scalar(10)), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var( + "@unrolled_0_10_x".to_string() + ))), + } + ); + } + + #[test] + fn test_replace_vars_for_unroll_in_expr_log2_ceil() { + let mut expr = Expression::Log2Ceil { + value: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))), + }; + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll_in_expr(&mut expr, &iterator, 3, 16, &internal_vars); + + assert_eq!( + expr, + Expression::Log2Ceil { + value: Box::new(Expression::scalar(16)), + } + ); + } + + #[test] + fn test_replace_vars_for_unroll_assignment() { + let mut lines = vec![Line::Assignment { + var: "sum".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("sum".to_string()))), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))), + }, + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("sum".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 1, 5, &internal_vars); + + assert_eq!( + lines, + vec![Line::Assignment { + var: "@unrolled_1_5_sum".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var( + "@unrolled_1_5_sum".to_string() + ))), + operation: crate::ir::HighLevelOperation::Add, + right: Box::new(Expression::scalar(5)), + }, + }] + ); + } + + #[test] + fn 
test_replace_vars_for_unroll_function_call() { + let mut lines = vec![Line::FunctionCall { + function_name: "test_func".to_string(), + args: vec![Expression::Value(SimpleExpr::Var("i".to_string()))], + return_data: vec!["result".to_string()], + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("result".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 2, 8, &internal_vars); + + assert_eq!( + lines, + vec![Line::FunctionCall { + function_name: "test_func".to_string(), + args: vec![Expression::scalar(8)], + return_data: vec!["@unrolled_2_8_result".to_string()], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_if_condition() { + let mut lines = vec![Line::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("i".to_string())), + right: Expression::scalar(5), + }, + then_branch: vec![Line::Assignment { + var: "x".to_string(), + value: Expression::scalar(1), + }], + else_branch: vec![], + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("x".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 0, 3, &internal_vars); + + assert_eq!( + lines, + vec![Line::IfCondition { + condition: Boolean::Equal { + left: Expression::scalar(3), + right: Expression::scalar(5), + }, + then_branch: vec![Line::Assignment { + var: "@unrolled_0_3_x".to_string(), + value: Expression::scalar(1), + }], + else_branch: vec![], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_for_loop() { + let mut lines = vec![Line::ForLoop { + iterator: "j".to_string(), + start: Expression::Value(SimpleExpr::Var("i".to_string())), + end: Expression::scalar(10), + body: vec![Line::Assignment { + var: "total".to_string(), + value: Expression::Value(SimpleExpr::Var("j".to_string())), + }], + rev: false, + unroll: false, + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + 
internal_vars.insert("j".to_string()); + internal_vars.insert("total".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 1, 7, &internal_vars); + + assert_eq!( + lines, + vec![Line::ForLoop { + iterator: "@unrolled_1_7_j".to_string(), + start: Expression::scalar(7), + end: Expression::scalar(10), + body: vec![Line::Assignment { + var: "@unrolled_1_7_total".to_string(), + value: Expression::Value(SimpleExpr::Var("@unrolled_1_7_j".to_string())), + }], + rev: false, + unroll: false, + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_match() { + let mut lines = vec![Line::Match { + value: Expression::Value(SimpleExpr::Var("i".to_string())), + arms: vec![ + ( + 0, + vec![Line::Assignment { + var: "a".to_string(), + value: Expression::scalar(1), + }], + ), + ( + 1, + vec![Line::Assignment { + var: "b".to_string(), + value: Expression::scalar(2), + }], + ), + ], + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("a".to_string()); + internal_vars.insert("b".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 3, 4, &internal_vars); + + assert_eq!( + lines, + vec![Line::Match { + value: Expression::scalar(4), + arms: vec![ + ( + 0, + vec![Line::Assignment { + var: "@unrolled_3_4_a".to_string(), + value: Expression::scalar(1), + }] + ), + ( + 1, + vec![Line::Assignment { + var: "@unrolled_3_4_b".to_string(), + value: Expression::scalar(2), + }] + ), + ], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_assert() { + let mut lines = vec![Line::Assert(Boolean::Different { + left: Expression::Value(SimpleExpr::Var("i".to_string())), + right: Expression::Value(SimpleExpr::Var("x".to_string())), + })]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("x".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 0, 6, &internal_vars); + + assert_eq!( + lines, + vec![Line::Assert(Boolean::Different { + left: 
Expression::scalar(6), + right: Expression::Value(SimpleExpr::Var("@unrolled_0_6_x".to_string())), + })] + ); + } + + #[test] + fn test_replace_vars_for_unroll_malloc() { + let mut lines = vec![Line::MAlloc { + var: "ptr".to_string(), + size: Expression::Value(SimpleExpr::Var("i".to_string())), + vectorized: false, + vectorized_len: Expression::scalar(1), + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("ptr".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 1, 64, &internal_vars); + + assert_eq!( + lines, + vec![Line::MAlloc { + var: "@unrolled_1_64_ptr".to_string(), + size: Expression::scalar(64), + vectorized: false, + vectorized_len: Expression::scalar(1), + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_decompose_bits() { + let mut lines = vec![Line::DecomposeBits { + var: "bits".to_string(), + to_decompose: vec![Expression::Value(SimpleExpr::Var("i".to_string()))], + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("bits".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 2, 255, &internal_vars); + + assert_eq!( + lines, + vec![Line::DecomposeBits { + var: "@unrolled_2_255_bits".to_string(), + to_decompose: vec![Expression::scalar(255)], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_counter_hint() { + let mut lines = vec![Line::CounterHint { + var: "counter".to_string(), + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("counter".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 0, 1, &internal_vars); + + assert_eq!( + lines, + vec![Line::CounterHint { + var: "@unrolled_0_1_counter".to_string(), + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_print() { + let mut lines = vec![Line::Print { + line_info: "debug".to_string(), + content: vec![Expression::Value(SimpleExpr::Var("i".to_string()))], + }]; + let iterator = 
"i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll(&mut lines, &iterator, 5, 42, &internal_vars); + + assert_eq!( + lines, + vec![Line::Print { + line_info: "debug (unrolled 5 42)".to_string(), + content: vec![Expression::scalar(42)], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_array_assign() { + let mut lines = vec![Line::ArrayAssign { + array: SimpleExpr::Var("arr".to_string()), + index: Expression::Value(SimpleExpr::Var("i".to_string())), + value: Expression::Value(SimpleExpr::Var("val".to_string())), + }]; + let iterator = "i".to_string(); + let mut internal_vars = BTreeSet::new(); + internal_vars.insert("arr".to_string()); + internal_vars.insert("val".to_string()); + + replace_vars_for_unroll(&mut lines, &iterator, 1, 12, &internal_vars); + + assert_eq!( + lines, + vec![Line::ArrayAssign { + array: SimpleExpr::Var("@unrolled_1_12_arr".to_string()), + index: Expression::scalar(12), + value: Expression::Value(SimpleExpr::Var("@unrolled_1_12_val".to_string())), + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_function_ret() { + let mut lines = vec![Line::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Var("i".to_string()))], + }]; + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll(&mut lines, &iterator, 0, 100, &internal_vars); + + assert_eq!( + lines, + vec![Line::FunctionRet { + return_data: vec![Expression::scalar(100)], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_precompile() { + let mut lines = vec![Line::Precompile { + precompile: crate::precompiles::PRECOMPILES[0].clone(), + args: vec![Expression::Value(SimpleExpr::Var("i".to_string()))], + }]; + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + replace_vars_for_unroll(&mut lines, &iterator, 3, 25, &internal_vars); + + assert_eq!( + lines, + vec![Line::Precompile { + precompile: crate::precompiles::PRECOMPILES[0].clone(), + args: 
vec![Expression::scalar(25)], + }] + ); + } + + #[test] + fn test_replace_vars_for_unroll_no_op_lines() { + let mut lines = vec![ + Line::Break, + Line::Panic, + Line::LocationReport { location: 42 }, + ]; + let iterator = "i".to_string(); + let internal_vars = BTreeSet::new(); + + let expected = lines.clone(); + replace_vars_for_unroll(&mut lines, &iterator, 0, 1, &internal_vars); + + assert_eq!(lines, expected); // Should remain unchanged + } +} diff --git a/crates/lean_compiler/src/simplify/utilities.rs b/crates/lean_compiler/src/simplify/utilities.rs new file mode 100644 index 00000000..4a93c789 --- /dev/null +++ b/crates/lean_compiler/src/simplify/utilities.rs @@ -0,0 +1,927 @@ +use crate::{ + F, + lang::{Boolean, Expression, Line, SimpleExpr, Var}, +}; +use std::collections::{BTreeMap, BTreeSet};
+use utils::ToUsize;
+
+/// Returns (internal_vars, external_vars) for a sequence of lines.
+///
+/// * `internal_vars` - every variable the sequence itself introduces: assignment
+///   targets, function-call return variables, `MAlloc` / `DecomposeBits` /
+///   `CounterHint` targets, loop iterators, plus everything reported as internal
+///   by nested `Match` arms, `if`/`else` branches and loop bodies.
+/// * `external_vars` - every variable read by an expression that was not yet
+///   recorded as internal when that expression was visited.
+///
+/// Lines are visited in order. Within one sequence, a variable that is read
+/// before the line defining it therefore lands in both sets. Externals coming
+/// from a nested block are filtered against the internals known once that
+/// block has been merged.
+pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
+    let mut internal_vars = BTreeSet::new();
+    let mut external_vars = BTreeSet::new();
+
+    // Records every variable of `expr` that is not (yet) internal as external.
+    let on_new_expr =
+        |expr: &Expression, internal_vars: &BTreeSet<Var>, external_vars: &mut BTreeSet<Var>| {
+            for var in vars_in_expression(expr) {
+                if !internal_vars.contains(&var) {
+                    external_vars.insert(var);
+                }
+            }
+        };
+
+    // Same as `on_new_expr`, applied to both operands of a condition.
+    let on_new_condition =
+        |condition: &Boolean, internal_vars: &BTreeSet<Var>, external_vars: &mut BTreeSet<Var>| {
+            let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition;
+            on_new_expr(left, internal_vars, external_vars);
+            on_new_expr(right, internal_vars, external_vars);
+        };
+
+    for line in lines {
+        match line {
+            Line::Match { value, arms } => {
+                on_new_expr(value, &internal_vars, &mut external_vars);
+                // Arms are merged one after the other: internals of an earlier arm
+                // hide same-named externals of a later arm.
+                for (_, statements) in arms {
+                    let (stmt_internal, stmt_external) = find_variable_usage(statements);
+                    internal_vars.extend(stmt_internal);
+                    external_vars.extend(
+                        stmt_external
+                            .into_iter()
+                            .filter(|v| !internal_vars.contains(v)),
+                    );
+                }
+            }
+            Line::Assignment { var, value } => {
+                // The right-hand side is inspected before the target becomes internal.
+                on_new_expr(value, &internal_vars, &mut external_vars);
+                internal_vars.insert(var.clone());
+            }
+            Line::IfCondition {
+                condition,
+                then_branch,
+                else_branch,
+            } => {
+                on_new_condition(condition, &internal_vars, &mut external_vars);
+
+                let (then_internal, then_external) = find_variable_usage(then_branch);
+                let (else_internal, else_external) = find_variable_usage(else_branch);
+
+                // Both branches are merged before the externals are filtered.
+                internal_vars.extend(then_internal.union(&else_internal).cloned());
+                external_vars.extend(
+                    then_external
+                        .union(&else_external)
+                        .filter(|v| !internal_vars.contains(*v))
+                        .cloned(),
+                );
+            }
+            Line::FunctionCall {
+                args, return_data, ..
+            } => {
+                for arg in args {
+                    on_new_expr(arg, &internal_vars, &mut external_vars);
+                }
+                internal_vars.extend(return_data.iter().cloned());
+            }
+            Line::Assert(condition) => {
+                on_new_condition(condition, &internal_vars, &mut external_vars);
+            }
+            Line::FunctionRet { return_data } => {
+                for ret in return_data {
+                    on_new_expr(ret, &internal_vars, &mut external_vars);
+                }
+            }
+            // NOTE(review): `..` skips `vectorized_len`, so variables read there are
+            // not reported - confirm this is intended.
+            Line::MAlloc { var, size, .. } => {
+                on_new_expr(size, &internal_vars, &mut external_vars);
+                internal_vars.insert(var.clone());
+            }
+            Line::Precompile {
+                precompile: _,
+                args,
+            } => {
+                for arg in args {
+                    on_new_expr(arg, &internal_vars, &mut external_vars);
+                }
+            }
+            Line::Print { content, .. } => {
+                for var in content {
+                    on_new_expr(var, &internal_vars, &mut external_vars);
+                }
+            }
+            Line::DecomposeBits { var, to_decompose } => {
+                for expr in to_decompose {
+                    on_new_expr(expr, &internal_vars, &mut external_vars);
+                }
+                internal_vars.insert(var.clone());
+            }
+            Line::CounterHint { var } => {
+                internal_vars.insert(var.clone());
+            }
+            Line::ForLoop {
+                iterator,
+                start,
+                end,
+                body,
+                rev: _,
+                unroll: _,
+            } => {
+                let (body_internal, body_external) = find_variable_usage(body);
+                internal_vars.extend(body_internal);
+                internal_vars.insert(iterator.clone());
+                external_vars.extend(body_external.difference(&internal_vars).cloned());
+                // The bounds are visited after the iterator and the body variables
+                // were marked internal, so a bound naming one of them is not
+                // reported as external.
+                on_new_expr(start, &internal_vars, &mut external_vars);
+                on_new_expr(end, &internal_vars, &mut external_vars);
+            }
+            Line::ArrayAssign {
+                array,
+                index,
+                value,
+            } => {
+                // The array is only read here; it is never recorded as internal.
+                on_new_expr(&array.clone().into(), &internal_vars, &mut external_vars);
+                on_new_expr(index, &internal_vars, &mut external_vars);
+                on_new_expr(value, &internal_vars, &mut external_vars);
+            }
+            Line::Panic | Line::Break | Line::LocationReport { .. } => {}
+        }
+    }
+
+    (internal_vars, external_vars)
+}
+
+/// Extract all variables referenced in an expression.
+///
+/// Recurses through array indices, both operands of binary expressions and the
+/// operand of `Log2Ceil`. The array of an `ArrayAccess` is included when it is
+/// a plain variable; constants contribute nothing.
+pub fn vars_in_expression(expr: &Expression) -> BTreeSet<Var> {
+    let mut vars = BTreeSet::new();
+    match expr {
+        Expression::Value(value) => {
+            if let SimpleExpr::Var(var) = value {
+                vars.insert(var.clone());
+            }
+        }
+        Expression::ArrayAccess { array, index } => {
+            if let SimpleExpr::Var(array) = array {
+                vars.insert(array.clone());
+            }
+            vars.extend(vars_in_expression(index));
+        }
+        Expression::Binary { left, right, .. } => {
+            vars.extend(vars_in_expression(left));
+            vars.extend(vars_in_expression(right));
+        }
+        Expression::Log2Ceil { value } => {
+            vars.extend(vars_in_expression(value));
+        }
+    }
+    vars
+}
+
+/// Replace variables with constants in an expression.
+///
+/// Every `SimpleExpr::Var` whose name is a key of `map` is overwritten in place
+/// with `SimpleExpr::scalar(value.to_usize())`; variables that are not in `map`
+/// and constants are left untouched. The walk recurses into array indices,
+/// binary operands and `Log2Ceil` operands.
+///
+/// # Panics
+///
+/// * if the array of an `ArrayAccess` is a variable listed in `map`;
+/// * if a `SimpleExpr::ConstMallocAccess` value is encountered (`unreachable!`).
+pub fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>) {
+    match expr {
+        Expression::Value(value) => match &value {
+            SimpleExpr::Var(var) => {
+                if let Some(const_value) = map.get(var) {
+                    *value = SimpleExpr::scalar(const_value.to_usize());
+                }
+            }
+            SimpleExpr::ConstMallocAccess { .. } => {
+                unreachable!()
+            }
+            SimpleExpr::Constant(_) => {}
+        },
+        Expression::ArrayAccess { array, index } => {
+            // Only the index is rewritten; indexing into a constant is rejected.
+            if let SimpleExpr::Var(array_var) = array {
+                assert!(
+                    !map.contains_key(array_var),
+                    "Array {array_var} is a constant"
+                );
+            }
+            replace_vars_by_const_in_expr(index, map);
+        }
+        Expression::Binary { left, right, .. } => {
+            replace_vars_by_const_in_expr(left, map);
+            replace_vars_by_const_in_expr(right, map);
+        }
+        Expression::Log2Ceil { value } => {
+            replace_vars_by_const_in_expr(value, map);
+        }
+    }
+}
+
+/// Replace variables with constants in a line sequence.
+///
+/// Applies [`replace_vars_by_const_in_expr`] to every expression the lines read
+/// and recurses into `Match` arms, `if`/`else` branches and loop bodies.
+///
+/// # Panics
+///
+/// Panics when a variable listed in `map` is written to: as an assignment
+/// target, an assigned array, a function-call return variable, or a
+/// `DecomposeBits` / `CounterHint` / `MAlloc` target. It also panics in the
+/// cases documented on [`replace_vars_by_const_in_expr`].
+pub fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
+    for line in lines {
+        match line {
+            Line::Match { value, arms } => {
+                replace_vars_by_const_in_expr(value, map);
+                for (_, statements) in arms {
+                    replace_vars_by_const_in_lines(statements, map);
+                }
+            }
+            Line::Assignment { var, value } => {
+                assert!(!map.contains_key(var), "Variable {var} is a constant");
+                replace_vars_by_const_in_expr(value, map);
+            }
+            Line::ArrayAssign {
+                array,
+                index,
+                value,
+            } => {
+                if let SimpleExpr::Var(array_var) = array {
+                    assert!(
+                        !map.contains_key(array_var),
+                        "Array {array_var} is a constant"
+                    );
+                }
+                replace_vars_by_const_in_expr(index, map);
+                replace_vars_by_const_in_expr(value, map);
+            }
+            Line::FunctionCall {
+                args, return_data, ..
+            } => {
+                for arg in args {
+                    replace_vars_by_const_in_expr(arg, map);
+                }
+                for ret in return_data {
+                    assert!(
+                        !map.contains_key(ret),
+                        "Return variable {ret} is a constant"
+                    );
+                }
+            }
+            Line::IfCondition {
+                condition,
+                then_branch,
+                else_branch,
+            } => {
+                match condition {
+                    Boolean::Equal { left, right } | Boolean::Different { left, right } => {
+                        replace_vars_by_const_in_expr(left, map);
+                        replace_vars_by_const_in_expr(right, map);
+                    }
+                }
+                replace_vars_by_const_in_lines(then_branch, map);
+                replace_vars_by_const_in_lines(else_branch, map);
+            }
+            Line::ForLoop {
+                body, start, end, ..
+            } => {
+                replace_vars_by_const_in_expr(start, map);
+                replace_vars_by_const_in_expr(end, map);
+                replace_vars_by_const_in_lines(body, map);
+            }
+            Line::Assert(condition) => match condition {
+                Boolean::Equal { left, right } | Boolean::Different { left, right } => {
+                    replace_vars_by_const_in_expr(left, map);
+                    replace_vars_by_const_in_expr(right, map);
+                }
+            },
+            Line::FunctionRet { return_data } => {
+                for ret in return_data {
+                    replace_vars_by_const_in_expr(ret, map);
+                }
+            }
+            Line::Precompile {
+                precompile: _,
+                args,
+            } => {
+                for arg in args {
+                    replace_vars_by_const_in_expr(arg, map);
+                }
+            }
+            Line::Print { content, .. } => {
+                for var in content {
+                    replace_vars_by_const_in_expr(var, map);
+                }
+            }
+            Line::DecomposeBits { var, to_decompose } => {
+                assert!(!map.contains_key(var), "Variable {var} is a constant");
+                for expr in to_decompose {
+                    replace_vars_by_const_in_expr(expr, map);
+                }
+            }
+            Line::CounterHint { var } => {
+                assert!(!map.contains_key(var), "Variable {var} is a constant");
+            }
+            // NOTE(review): `..` skips `vectorized_len`, so a constant referenced
+            // there is not substituted - confirm this is intended.
+            Line::MAlloc { var, size, .. } => {
+                assert!(!map.contains_key(var), "Variable {var} is a constant");
+                replace_vars_by_const_in_expr(size, map);
+            }
+            Line::Panic | Line::Break | Line::LocationReport { .. } => {}
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::lang::{Boolean, Expression, Line, SimpleExpr};
+    use p3_field::PrimeCharacteristicRing;
+    use std::collections::{BTreeMap, BTreeSet};
+
+    #[test]
+    fn test_find_variable_usage_empty() {
+        let lines: Vec<Line> = vec![];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert!(external.is_empty());
+    }
+
+    #[test]
+    fn test_find_variable_usage_assignment() {
+        let lines = vec![Line::Assignment {
+            var: "x".to_string(),
+            value: Expression::scalar(42),
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["x".to_string()]));
+        assert!(external.is_empty());
+    }
+
+    #[test]
+    fn test_find_variable_usage_assignment_with_external_var() {
+        let lines = vec![Line::Assignment {
+            var: "x".to_string(),
+            value: Expression::Value(SimpleExpr::Var("y".to_string())),
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["x".to_string()]));
+        assert_eq!(external, BTreeSet::from(["y".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_function_call() {
+        let lines = vec![Line::FunctionCall {
+            function_name: "test".to_string(),
+            args: vec![Expression::Value(SimpleExpr::Var("input".to_string()))],
+            return_data: vec!["output".to_string()],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["output".to_string()]));
+        assert_eq!(external, BTreeSet::from(["input".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_if_condition() {
+        let lines = vec![Line::IfCondition {
+            condition: Boolean::Equal {
+                left: Expression::Value(SimpleExpr::Var("a".to_string())),
+                right: Expression::scalar(10),
+            },
+            then_branch: vec![Line::Assignment {
+                var: "b".to_string(),
+                value: Expression::scalar(1),
+            }],
+            else_branch: vec![Line::Assignment {
+                var: "c".to_string(),
+                value: Expression::scalar(2),
+            }],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["b".to_string(), "c".to_string()]));
+        assert_eq!(external, BTreeSet::from(["a".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_for_loop() {
+        let lines = vec![Line::ForLoop {
+            iterator: "i".to_string(),
+            start: Expression::scalar(0),
+            end: Expression::Value(SimpleExpr::Var("n".to_string())),
+            body: vec![Line::Assignment {
+                var: "sum".to_string(),
+                value: Expression::Binary {
+                    left: Box::new(Expression::Value(SimpleExpr::Var("sum".to_string()))),
+                    operation: crate::ir::HighLevelOperation::Add,
+                    right: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))),
+                },
+            }],
+            rev: false,
+            unroll: false,
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(
+            internal,
+            BTreeSet::from(["i".to_string(), "sum".to_string()])
+        );
+        assert_eq!(external, BTreeSet::from(["n".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_match() {
+        let lines = vec![Line::Match {
+            value: Expression::Value(SimpleExpr::Var("x".to_string())),
+            arms: vec![
+                (
+                    0,
+                    vec![Line::Assignment {
+                        var: "a".to_string(),
+                        value: Expression::scalar(1),
+                    }],
+                ),
+                (
+                    1,
+                    vec![Line::Assignment {
+                        var: "b".to_string(),
+                        value: Expression::scalar(2),
+                    }],
+                ),
+            ],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["a".to_string(), "b".to_string()]));
+        assert_eq!(external, BTreeSet::from(["x".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_malloc() {
+        let lines = vec![Line::MAlloc {
+            var: "ptr".to_string(),
+            size: Expression::Value(SimpleExpr::Var("size".to_string())),
+            vectorized: false,
+            vectorized_len: Expression::zero(),
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["ptr".to_string()]));
+        assert_eq!(external, BTreeSet::from(["size".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_decompose_bits() {
+        let lines = vec![Line::DecomposeBits {
+            var: "bits".to_string(),
+            to_decompose: vec![Expression::Value(SimpleExpr::Var("value".to_string()))],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["bits".to_string()]));
+        assert_eq!(external, BTreeSet::from(["value".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_counter_hint() {
+        let lines = vec![Line::CounterHint {
+            var: "counter".to_string(),
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert_eq!(internal, BTreeSet::from(["counter".to_string()]));
+        assert!(external.is_empty());
+    }
+
+    #[test]
+    fn test_find_variable_usage_assert() {
+        let lines = vec![Line::Assert(Boolean::Different {
+            left: Expression::Value(SimpleExpr::Var("x".to_string())),
+            right: Expression::Value(SimpleExpr::Var("y".to_string())),
+        })];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert_eq!(external, BTreeSet::from(["x".to_string(), "y".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_function_ret() {
+        let lines = vec![Line::FunctionRet {
+            return_data: vec![Expression::Value(SimpleExpr::Var("result".to_string()))],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert_eq!(external, BTreeSet::from(["result".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_precompile() {
+        let lines = vec![Line::Precompile {
+            precompile: crate::precompiles::PRECOMPILES[0].clone(),
+            args: vec![Expression::Value(SimpleExpr::Var("input".to_string()))],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert_eq!(external, BTreeSet::from(["input".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_print() {
+        let lines = vec![Line::Print {
+            line_info: "debug".to_string(),
+            content: vec![Expression::Value(SimpleExpr::Var("debug_var".to_string()))],
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert_eq!(external, BTreeSet::from(["debug_var".to_string()]));
+    }
+
+    #[test]
+    fn test_find_variable_usage_array_assign() {
+        let lines = vec![Line::ArrayAssign {
+            array: SimpleExpr::Var("arr".to_string()),
+            index: Expression::Value(SimpleExpr::Var("idx".to_string())),
+            value: Expression::Value(SimpleExpr::Var("val".to_string())),
+        }];
+        let (internal, external) = find_variable_usage(&lines);
+
+        assert!(internal.is_empty());
+        assert_eq!(
+            external,
+            BTreeSet::from(["arr".to_string(), "idx".to_string(), "val".to_string()])
+        );
+    }
+
+    #[test]
+    fn test_vars_in_expression_value() {
+        let expr = Expression::Value(SimpleExpr::Var("x".to_string()));
+        let vars = vars_in_expression(&expr);
+
+        assert_eq!(vars, BTreeSet::from(["x".to_string()]));
+    }
+
+    #[test]
+    fn test_vars_in_expression_constant() {
+        let expr = Expression::scalar(42);
+        let vars = vars_in_expression(&expr);
+
+        assert!(vars.is_empty());
+    }
+
+    #[test]
+    fn test_vars_in_expression_binary() {
+        let expr = Expression::Binary {
+            left: Box::new(Expression::Value(SimpleExpr::Var("a".to_string()))),
+            operation: crate::ir::HighLevelOperation::Add,
+            right: Box::new(Expression::Value(SimpleExpr::Var("b".to_string()))),
+        };
+        let vars = vars_in_expression(&expr);
+
+        assert_eq!(vars, BTreeSet::from(["a".to_string(), "b".to_string()]));
+    }
+
+    #[test]
+    fn test_vars_in_expression_array_access() {
+        let expr = Expression::ArrayAccess {
+            array: SimpleExpr::Var("arr".to_string()),
+            index: Box::new(Expression::Value(SimpleExpr::Var("idx".to_string()))),
+        };
+        let vars = vars_in_expression(&expr);
+
+        assert_eq!(vars, BTreeSet::from(["arr".to_string(), "idx".to_string()]));
+    }
+
+    #[test]
+    fn test_vars_in_expression_log2_ceil() {
+        let expr = Expression::Log2Ceil {
+            value: Box::new(Expression::Value(SimpleExpr::Var("n".to_string()))),
+        };
+        let vars = vars_in_expression(&expr);
+
+        assert_eq!(vars, BTreeSet::from(["n".to_string()]));
+    }
+
+    #[test]
+    fn test_vars_in_expression_nested() {
+        let expr = Expression::Binary {
+            left: Box::new(Expression::ArrayAccess {
+                array: SimpleExpr::Var("arr".to_string()),
+                index: Box::new(Expression::Value(SimpleExpr::Var("i".to_string()))),
+            }),
+            operation: crate::ir::HighLevelOperation::Mul,
+            right: Box::new(Expression::Binary {
+                left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))),
+                operation: crate::ir::HighLevelOperation::Add,
+                right: Box::new(Expression::Value(SimpleExpr::Var("y".to_string()))),
+            }),
+        };
+        let vars = vars_in_expression(&expr);
+
+        assert_eq!(
+            vars,
+            BTreeSet::from([
+                "arr".to_string(),
+                "i".to_string(),
+                "x".to_string(),
+                "y".to_string()
+            ])
+        );
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_expr_var_replacement() {
+        let mut expr = Expression::Value(SimpleExpr::Var("x".to_string()));
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(42));
+
+        replace_vars_by_const_in_expr(&mut expr, &map);
+
+        assert_eq!(expr, Expression::scalar(42));
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_expr_no_replacement() {
+        let mut expr = Expression::Value(SimpleExpr::Var("y".to_string()));
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(42));
+
+        replace_vars_by_const_in_expr(&mut expr, &map);
+
+        assert_eq!(expr, Expression::Value(SimpleExpr::Var("y".to_string())));
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_expr_constant_unchanged() {
+        let mut expr = Expression::scalar(100);
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(42));
+
+        replace_vars_by_const_in_expr(&mut expr, &map);
+
+        assert_eq!(expr, Expression::scalar(100));
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_expr_binary() {
+        let mut expr = Expression::Binary {
+            left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))),
+            operation: crate::ir::HighLevelOperation::Add,
+            right: Box::new(Expression::Value(SimpleExpr::Var("y".to_string()))),
+        };
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(10));
+        map.insert("y".to_string(), crate::F::from_usize(20));
+
+        replace_vars_by_const_in_expr(&mut expr, &map);
+
+        assert_eq!(
+            expr,
+            Expression::Binary {
+                left: Box::new(Expression::scalar(10)),
+                operation: crate::ir::HighLevelOperation::Add,
+                right: Box::new(Expression::scalar(20)),
+            }
+        );
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_expr_log2_ceil() {
+        let mut expr = Expression::Log2Ceil {
+            value: Box::new(Expression::Value(SimpleExpr::Var("n".to_string()))),
+        };
+        let mut map = BTreeMap::new();
+        map.insert("n".to_string(), crate::F::from_usize(16));
+
+        replace_vars_by_const_in_expr(&mut expr, &map);
+
+        assert_eq!(
+            expr,
+            Expression::Log2Ceil {
+                value: Box::new(Expression::scalar(16)),
+            }
+        );
+    }
+
+    #[test]
+    fn test_get_function_called_empty() {
+        let lines: Vec<Line> = vec![];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert!(function_called.is_empty());
+    }
+
+    #[test]
+    fn test_get_function_called_function_call() {
+        let lines = vec![Line::FunctionCall {
+            function_name: "test_func".to_string(),
+            args: vec![],
+            return_data: vec![],
+        }];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert_eq!(function_called, vec!["test_func".to_string()]);
+    }
+
+    #[test]
+    fn test_get_function_called_multiple_calls() {
+        let lines = vec![
+            Line::FunctionCall {
+                function_name: "func1".to_string(),
+                args: vec![],
+                return_data: vec![],
+            },
+            Line::Assignment {
+                var: "x".to_string(),
+                value: Expression::scalar(42),
+            },
+            Line::FunctionCall {
+                function_name: "func2".to_string(),
+                args: vec![],
+                return_data: vec![],
+            },
+        ];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert_eq!(
+            function_called,
+            vec!["func1".to_string(), "func2".to_string()]
+        );
+    }
+
+    #[test]
+    fn test_get_function_called_if_condition() {
+        let lines = vec![Line::IfCondition {
+            condition: Boolean::Equal {
+                left: Expression::scalar(1),
+                right: Expression::scalar(1),
+            },
+            then_branch: vec![Line::FunctionCall {
+                function_name: "then_func".to_string(),
+                args: vec![],
+                return_data: vec![],
+            }],
+            else_branch: vec![Line::FunctionCall {
+                function_name: "else_func".to_string(),
+                args: vec![],
+                return_data: vec![],
+            }],
+        }];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert_eq!(
+            function_called,
+            vec!["then_func".to_string(), "else_func".to_string()]
+        );
+    }
+
+    #[test]
+    fn test_get_function_called_for_loop() {
+        let lines = vec![Line::ForLoop {
+            iterator: "i".to_string(),
+            start: Expression::scalar(0),
+            end: Expression::scalar(10),
+            body: vec![Line::FunctionCall {
+                function_name: "loop_func".to_string(),
+                args: vec![],
+                return_data: vec![],
+            }],
+            rev: false,
+            unroll: false,
+        }];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert_eq!(function_called, vec!["loop_func".to_string()]);
+    }
+
+    #[test]
+    fn test_get_function_called_match() {
+        let lines = vec![Line::Match {
+            value: Expression::scalar(1),
+            arms: vec![
+                (
+                    0,
+                    vec![Line::FunctionCall {
+                        function_name: "arm0_func".to_string(),
+                        args: vec![],
+                        return_data: vec![],
+                    }],
+                ),
+                (
+                    1,
+                    vec![Line::FunctionCall {
+                        function_name: "arm1_func".to_string(),
+                        args: vec![],
+                        return_data: vec![],
+                    }],
+                ),
+            ],
+        }];
+        let mut function_called = Vec::new();
+
+        get_function_called(&lines, &mut function_called);
+
+        assert_eq!(
+            function_called,
+            vec!["arm0_func".to_string(), "arm1_func".to_string()]
+        );
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_lines_assignment() {
+        let mut lines = vec![Line::Assignment {
+            var: "y".to_string(),
+            value: Expression::Value(SimpleExpr::Var("x".to_string())),
+        }];
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(42));
+
+        replace_vars_by_const_in_lines(&mut lines, &map);
+
+        assert_eq!(
+            lines,
+            vec![Line::Assignment {
+                var: "y".to_string(),
+                value: Expression::scalar(42),
+            }]
+        );
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_lines_function_call() {
+        let mut lines = vec![Line::FunctionCall {
+            function_name: "test".to_string(),
+            args: vec![Expression::Value(SimpleExpr::Var("x".to_string()))],
+            return_data: vec!["result".to_string()],
+        }];
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(100));
+
+        replace_vars_by_const_in_lines(&mut lines, &map);
+
+        assert_eq!(
+            lines,
+            vec![Line::FunctionCall {
+                function_name: "test".to_string(),
+                args: vec![Expression::scalar(100)],
+                return_data: vec!["result".to_string()],
+            }]
+        );
+    }
+
+    #[test]
+    fn test_replace_vars_by_const_in_lines_if_condition() {
+        let mut lines = vec![Line::IfCondition {
+            condition: Boolean::Equal {
+                left: Expression::Value(SimpleExpr::Var("x".to_string())),
+                right: Expression::scalar(10),
+            },
+            then_branch: vec![Line::Assignment {
+                var: "y".to_string(),
+                value: Expression::Value(SimpleExpr::Var("x".to_string())),
+            }],
+            else_branch: vec![],
+        }];
+        let mut map = BTreeMap::new();
+        map.insert("x".to_string(), crate::F::from_usize(5));
+
+        replace_vars_by_const_in_lines(&mut lines, &map);
+
+        assert_eq!(
+            lines,
+            vec![Line::IfCondition {
+                condition: Boolean::Equal {
+                    left: Expression::scalar(5),
+                    right: Expression::scalar(10),
+                },
+                then_branch: vec![Line::Assignment {
+                    var: "y".to_string(),
+                    value: Expression::scalar(5),
+                }],
+                else_branch: vec![],
+            }]
+        );
+    }
+}
+
+/// Extract function calls from line sequence.
+///
+/// Appends the name of every `Line::FunctionCall` found in `lines` to
+/// `function_called`, in the order the calls are visited. The walk descends
+/// into `Match` arms, both `if`/`else` branches (then before else) and loop
+/// bodies. Names are pushed once per call site, so the output may contain
+/// duplicates; entries already present in `function_called` are kept.
+pub fn get_function_called(lines: &[Line], function_called: &mut Vec<String>) {
+    for line in lines {
+        match line {
+            Line::Match { value: _, arms } => {
+                for (_, statements) in arms {
+                    get_function_called(statements, function_called);
+                }
+            }
+            Line::FunctionCall { function_name, .. } => {
+                function_called.push(function_name.clone());
+            }
+            Line::IfCondition {
+                then_branch,
+                else_branch,
+                ..
+            } => {
+                get_function_called(then_branch, function_called);
+                get_function_called(else_branch, function_called);
+            }
+            Line::ForLoop { body, .. } => {
+                get_function_called(body, function_called);
+            }
+            // None of the remaining variants holds nested lines or names a callee.
+            Line::Assignment { .. }
+            | Line::ArrayAssign { .. }
+            | Line::Assert { .. }
+            | Line::FunctionRet { .. }
+            | Line::Precompile { .. }
+            | Line::Print { .. }
+            | Line::DecomposeBits { .. }
+            | Line::CounterHint { .. }
+            | Line::MAlloc { .. }
+            | Line::Panic
+            | Line::Break
+            | Line::LocationReport { .. } => {}
+        }
+    }
+}
From 4e8b4d2407c61ee18454bbc9ee6bf3e30da7ac59 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 23:02:22 +0200 Subject: [PATCH 10/39] more tests --- crates/lean_compiler/src/ir/bytecode.rs | 335 ++++++ crates/lean_compiler/src/ir/instruction.rs | 1035 +++++++++++++++++ crates/lean_compiler/src/ir/operation.rs | 410 +++++++ crates/lean_compiler/src/ir/value.rs | 495 ++++++++ crates/lean_compiler/src/simplify/simplify.rs | 818 +++++++++++++ .../src/simplify/transformations.rs | 639 ++++++++++ 6 files changed, 3732 insertions(+) diff --git a/crates/lean_compiler/src/ir/bytecode.rs b/crates/lean_compiler/src/ir/bytecode.rs index dcfec1e9..57f90593 100644 --- a/crates/lean_compiler/src/ir/bytecode.rs +++ b/crates/lean_compiler/src/ir/bytecode.rs @@ -49,3 +49,338 @@ impl Display for IntermediateBytecode { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{instruction::IntermediateInstruction, value::IntermediateValue}; + use crate::lang::ConstExpression; + use lean_vm::Label; + + #[test] + fn
test_intermediate_bytecode_creation() { + let bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + assert!(bytecode.bytecode.is_empty()); + assert!(bytecode.match_blocks.is_empty()); + assert!(bytecode.memory_size_per_function.is_empty()); + } + + #[test] + fn test_intermediate_bytecode_with_functions() { + let mut bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let label1 = Label::function("main"); + let label2 = Label::function("helper"); + + let instructions1 = vec![ + IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }, + arg_c: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }, + }, + IntermediateInstruction::Panic, + ]; + + let instructions2 = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(4), + }, + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(5), + }, + }]; + + bytecode.bytecode.insert(label1.clone(), instructions1); + bytecode.bytecode.insert(label2.clone(), instructions2); + + bytecode + .memory_size_per_function + .insert("main".to_string(), 10); + bytecode + .memory_size_per_function + .insert("helper".to_string(), 5); + + assert_eq!(bytecode.bytecode.len(), 2); + assert_eq!(bytecode.memory_size_per_function.len(), 2); + assert!(bytecode.bytecode.contains_key(&label1)); + assert!(bytecode.bytecode.contains_key(&label2)); + assert_eq!(bytecode.memory_size_per_function.get("main"), Some(&10)); + 
assert_eq!(bytecode.memory_size_per_function.get("helper"), Some(&5)); + } + + #[test] + fn test_intermediate_bytecode_with_match_blocks() { + let mut bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let case1_instructions = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(1)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(6), + }, + }]; + + let case2_instructions = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(2)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }, + }]; + + let match_block1 = vec![case1_instructions, case2_instructions]; + + let case3_instructions = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(3)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(8), + }, + }]; + + let match_block2 = vec![case3_instructions]; + + bytecode.match_blocks.push(match_block1); + bytecode.match_blocks.push(match_block2); + + assert_eq!(bytecode.match_blocks.len(), 2); + assert_eq!(bytecode.match_blocks[0].len(), 2); // Two cases in first match + assert_eq!(bytecode.match_blocks[1].len(), 1); // One case in second match + + // Check first match block structure + assert_eq!(bytecode.match_blocks[0][0].len(), 1); // One instruction in case 1 + assert_eq!(bytecode.match_blocks[0][1].len(), 1); // One instruction in case 2 + assert_eq!(bytecode.match_blocks[1][0].len(), 1); // One instruction in case 3 + } + + 
#[test] + fn test_intermediate_bytecode_clone() { + let mut original = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let label = Label::function("test"); + let instructions = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(42)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(9), + }, + }]; + + original.bytecode.insert(label.clone(), instructions); + original + .memory_size_per_function + .insert("test".to_string(), 8); + + let cloned = original.clone(); + + // Verify clone has same content + assert_eq!(cloned.bytecode.len(), 1); + assert_eq!(cloned.memory_size_per_function.len(), 1); + assert!(cloned.bytecode.contains_key(&label)); + assert_eq!(cloned.memory_size_per_function.get("test"), Some(&8)); + + // Verify independence - modify original + original + .memory_size_per_function + .insert("new_function".to_string(), 16); + assert_eq!(original.memory_size_per_function.len(), 2); + assert_eq!(cloned.memory_size_per_function.len(), 1); + } + + #[test] + fn test_intermediate_bytecode_display_empty() { + let bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let display_output = format!("{}", bytecode); + + // Should contain the header for memory size + assert!(display_output.contains("Memory size per function:")); + + // Should be minimal content for empty bytecode + assert!(display_output.lines().count() <= 3); + } + + #[test] + fn test_intermediate_bytecode_display_with_content() { + let mut bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let label = Label::function("main"); + let instructions = vec![ + 
IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(42)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }, + }, + IntermediateInstruction::Panic, + ]; + + let match_case = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(100)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(11), + }, + }]; + let match_block = vec![match_case]; + + bytecode.bytecode.insert(label, instructions); + bytecode.match_blocks.push(match_block); + bytecode + .memory_size_per_function + .insert("main".to_string(), 16); + + let display_output = format!("{}", bytecode); + + // Check function label appears + assert!(display_output.contains("main:")); + + // Check instructions appear with proper indentation + assert!(display_output.contains("m[fp + 10] = 42 + 0")); + assert!(display_output.contains(" panic")); + + // Check match block appears + assert!(display_output.contains("Match 0:")); + assert!(display_output.contains("Case 0:")); + assert!(display_output.contains("m[fp + 11] = 100 + 0")); + + // Check memory size appears + assert!(display_output.contains("Memory size per function:")); + assert!(display_output.contains("main: 16")); + } + + #[test] + fn test_intermediate_bytecode_display_multiple_match_blocks() { + let mut bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + // First match block with 2 cases + let case1 = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(200)), + arg_c: 
IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(12), + }, + }]; + let case2 = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(300)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(13), + }, + }]; + let match_block1 = vec![case1, case2]; + + // Second match block with 1 case + let case3 = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(400)), + arg_c: IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(14), + }, + }]; + let match_block2 = vec![case3]; + + bytecode.match_blocks.push(match_block1); + bytecode.match_blocks.push(match_block2); + + let display_output = format!("{}", bytecode); + + // Check both match blocks appear + assert!(display_output.contains("Match 0:")); + assert!(display_output.contains("Match 1:")); + + // Check cases in first match + assert!(display_output.contains("Case 0:")); + assert!(display_output.contains("Case 1:")); + assert!(display_output.contains("m[fp + 12] = 200 + 0")); + assert!(display_output.contains("m[fp + 13] = 300 + 0")); + + // Check case in second match + assert!(display_output.contains("m[fp + 14] = 400 + 0")); + } + + #[test] + fn test_intermediate_bytecode_debug_format() { + let mut bytecode = IntermediateBytecode { + bytecode: BTreeMap::new(), + match_blocks: Vec::new(), + memory_size_per_function: BTreeMap::new(), + }; + + let label = Label::function("test"); + let instructions = vec![IntermediateInstruction::Computation { + operation: lean_vm::Operation::Add, + arg_a: IntermediateValue::Constant(ConstExpression::scalar(500)), + arg_c: 
IntermediateValue::Constant(ConstExpression::scalar(0)), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(15), + }, + }]; + + bytecode.bytecode.insert(label, instructions); + bytecode + .memory_size_per_function + .insert("test".to_string(), 4); + + let debug_output = format!("{:?}", bytecode); + + // Debug format should contain struct name and fields + assert!(debug_output.contains("IntermediateBytecode")); + assert!(debug_output.contains("bytecode:")); + assert!(debug_output.contains("match_blocks:")); + assert!(debug_output.contains("memory_size_per_function:")); + + // Should contain the actual data + assert!(debug_output.contains("test")); + assert!(debug_output.contains("500")); + assert!(debug_output.contains("15")); + } +} diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index a7a6f091..5b235d7d 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -226,3 +226,1038 @@ impl Display for IntermediateInstruction { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::value::{IntermediaryMemOrFpOrConstant, IntermediateValue}; + use crate::lang::ConstExpression; + use lean_vm::{Operation, SourceLineNumber}; + + #[test] + fn test_computation_instruction() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }; + let arg_c = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + + let instruction = IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: arg_a.clone(), + arg_c: arg_c.clone(), + res: res.clone(), + }; + + if let IntermediateInstruction::Computation { + operation, + arg_a: a, + arg_c: c, + res: r, + } = instruction + { + assert_eq!(operation, Operation::Add); + assert_eq!(a, arg_a); + assert_eq!(c, arg_c); + assert_eq!(r, res); + } else 
{ + panic!("Expected Computation variant"); + } + } + + #[test] + fn test_computation_add_operation() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }; + let arg_c = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + + let instruction = IntermediateInstruction::computation( + HighLevelOperation::Add, + arg_a.clone(), + arg_c.clone(), + res.clone(), + ); + + if let IntermediateInstruction::Computation { + operation, + arg_a: a, + arg_c: c, + res: r, + } = instruction + { + assert_eq!(operation, Operation::Add); + assert_eq!(a, arg_a); + assert_eq!(c, arg_c); + assert_eq!(r, res); + } else { + panic!("Expected Computation variant"); + } + } + + #[test] + fn test_computation_mul_operation() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }; + let arg_c = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + + let instruction = IntermediateInstruction::computation( + HighLevelOperation::Mul, + arg_a.clone(), + arg_c.clone(), + res.clone(), + ); + + if let IntermediateInstruction::Computation { + operation, + arg_a: a, + arg_c: c, + res: r, + } = instruction + { + assert_eq!(operation, Operation::Mul); + assert_eq!(a, arg_a); + assert_eq!(c, arg_c); + assert_eq!(r, res); + } else { + panic!("Expected Computation variant"); + } + } + + #[test] + fn test_computation_sub_operation() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }; + let arg_c = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + + let instruction = IntermediateInstruction::computation( + HighLevelOperation::Sub, + arg_a.clone(), + 
arg_c.clone(), + res.clone(), + ); + + // Sub is translated to: res = arg_a - arg_c => arg_a = res + arg_c + if let IntermediateInstruction::Computation { + operation, + arg_a: a, + arg_c: c, + res: r, + } = instruction + { + assert_eq!(operation, Operation::Add); + assert_eq!(a, res); // result becomes arg_a + assert_eq!(c, arg_c); // arg_c stays the same + assert_eq!(r, arg_a); // original arg_a becomes result + } else { + panic!("Expected Computation variant"); + } + } + + #[test] + fn test_computation_div_operation() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }; + let arg_c = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + + let instruction = IntermediateInstruction::computation( + HighLevelOperation::Div, + arg_a.clone(), + arg_c.clone(), + res.clone(), + ); + + // Div is translated to: res = arg_a / arg_c => arg_a = res * arg_c + if let IntermediateInstruction::Computation { + operation, + arg_a: a, + arg_c: c, + res: r, + } = instruction + { + assert_eq!(operation, Operation::Mul); + assert_eq!(a, res); // result becomes arg_a + assert_eq!(c, arg_c); // arg_c stays the same + assert_eq!(r, arg_a); // original arg_a becomes result + } else { + panic!("Expected Computation variant"); + } + } + + #[test] + fn test_equality_instruction() { + let left = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(5), + }; + let right = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(6), + }; + + let instruction = IntermediateInstruction::equality(left.clone(), right.clone()); + + if let IntermediateInstruction::Computation { + operation, + arg_a, + arg_c, + res, + } = instruction + { + assert_eq!(operation, Operation::Add); + assert_eq!(arg_a, left); + assert_eq!(arg_c, IntermediateValue::Constant(ConstExpression::zero())); + assert_eq!(res, right); + } else { + 
panic!("Expected Computation variant"); + } + } + + #[test] + fn test_deref_instruction() { + let shift_0 = ConstExpression::scalar(5); + let shift_1 = ConstExpression::scalar(10); + let res = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(15), + }; + + let instruction = IntermediateInstruction::Deref { + shift_0: shift_0.clone(), + shift_1: shift_1.clone(), + res: res.clone(), + }; + + if let IntermediateInstruction::Deref { + shift_0: s0, + shift_1: s1, + res: r, + } = instruction + { + assert_eq!(s0, shift_0); + assert_eq!(s1, shift_1); + assert_eq!(r, res); + } else { + panic!("Expected Deref variant"); + } + } + + #[test] + fn test_panic_instruction() { + let instruction = IntermediateInstruction::Panic; + assert!(matches!(instruction, IntermediateInstruction::Panic)); + } + + #[test] + fn test_jump_instruction_without_fp() { + let dest = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }; + let instruction = IntermediateInstruction::Jump { + dest: dest.clone(), + updated_fp: None, + }; + + if let IntermediateInstruction::Jump { + dest: d, + updated_fp, + } = instruction + { + assert_eq!(d, dest); + assert!(updated_fp.is_none()); + } else { + panic!("Expected Jump variant"); + } + } + + #[test] + fn test_jump_instruction_with_fp() { + let dest = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }; + let fp = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(8), + }; + let instruction = IntermediateInstruction::Jump { + dest: dest.clone(), + updated_fp: Some(fp.clone()), + }; + + if let IntermediateInstruction::Jump { + dest: d, + updated_fp, + } = instruction + { + assert_eq!(d, dest); + assert_eq!(updated_fp, Some(fp)); + } else { + panic!("Expected Jump variant"); + } + } + + #[test] + fn test_jump_if_not_zero_instruction() { + let condition = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(4), + }; + let dest = 
IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }; + let fp = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(8), + }; + + let instruction = IntermediateInstruction::JumpIfNotZero { + condition: condition.clone(), + dest: dest.clone(), + updated_fp: Some(fp.clone()), + }; + + if let IntermediateInstruction::JumpIfNotZero { + condition: c, + dest: d, + updated_fp, + } = instruction + { + assert_eq!(c, condition); + assert_eq!(d, dest); + assert_eq!(updated_fp, Some(fp)); + } else { + panic!("Expected JumpIfNotZero variant"); + } + } + + #[test] + fn test_poseidon2_16_instruction() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(15), + }; + let arg_b = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(16), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(9), + }; + + let instruction = IntermediateInstruction::Poseidon2_16 { + arg_a: arg_a.clone(), + arg_b: arg_b.clone(), + res: res.clone(), + }; + + if let IntermediateInstruction::Poseidon2_16 { + arg_a: a, + arg_b: b, + res: r, + } = instruction + { + assert_eq!(a, arg_a); + assert_eq!(b, arg_b); + assert_eq!(r, res); + } else { + panic!("Expected Poseidon2_16 variant"); + } + } + + #[test] + fn test_poseidon2_24_instruction() { + let arg_a = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(15), + }; + let arg_b = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(16), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(9), + }; + + let instruction = IntermediateInstruction::Poseidon2_24 { + arg_a: arg_a.clone(), + arg_b: arg_b.clone(), + res: res.clone(), + }; + + if let IntermediateInstruction::Poseidon2_24 { + arg_a: a, + arg_b: b, + res: r, + } = instruction + { + assert_eq!(a, arg_a); + assert_eq!(b, arg_b); + assert_eq!(r, res); + } else { + panic!("Expected Poseidon2_24 variant"); + } + } + + #[test] 
+ fn test_dot_product_instruction() { + let arg0 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(17), + }; + let arg1 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(18), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + let size = ConstExpression::scalar(8); + + let instruction = IntermediateInstruction::DotProduct { + arg0: arg0.clone(), + arg1: arg1.clone(), + res: res.clone(), + size: size.clone(), + }; + + if let IntermediateInstruction::DotProduct { + arg0: a0, + arg1: a1, + res: r, + size: s, + } = instruction + { + assert_eq!(a0, arg0); + assert_eq!(a1, arg1); + assert_eq!(r, res); + assert_eq!(s, size); + } else { + panic!("Expected DotProduct variant"); + } + } + + #[test] + fn test_multilinear_eval_instruction() { + let coeffs = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let point = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(11), + }; + let res = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }; + let n_vars = ConstExpression::scalar(4); + + let instruction = IntermediateInstruction::MultilinearEval { + coeffs: coeffs.clone(), + point: point.clone(), + res: res.clone(), + n_vars: n_vars.clone(), + }; + + if let IntermediateInstruction::MultilinearEval { + coeffs: c, + point: p, + res: r, + n_vars: n, + } = instruction + { + assert_eq!(c, coeffs); + assert_eq!(p, point); + assert_eq!(r, res); + assert_eq!(n, n_vars); + } else { + panic!("Expected MultilinearEval variant"); + } + } + + #[test] + fn test_inverse_instruction() { + let arg = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(12), + }; + let res_offset = 42; + + let instruction = IntermediateInstruction::Inverse { + arg: arg.clone(), + res_offset, + }; + + if let IntermediateInstruction::Inverse { + arg: a, + res_offset: r, + } = instruction + { + assert_eq!(a, arg); + assert_eq!(r, 
res_offset); + } else { + panic!("Expected Inverse variant"); + } + } + + #[test] + fn test_request_memory_instruction_non_vectorized() { + let offset = ConstExpression::scalar(10); + let size = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(13), + }; + let vectorized_len = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(14), + }; + + let instruction = IntermediateInstruction::RequestMemory { + offset: offset.clone(), + size: size.clone(), + vectorized: false, + vectorized_len: vectorized_len.clone(), + }; + + if let IntermediateInstruction::RequestMemory { + offset: o, + size: s, + vectorized, + vectorized_len: vl, + } = instruction + { + assert_eq!(o, offset); + assert_eq!(s, size); + assert!(!vectorized); + assert_eq!(vl, vectorized_len); + } else { + panic!("Expected RequestMemory variant"); + } + } + + #[test] + fn test_request_memory_instruction_vectorized() { + let offset = ConstExpression::scalar(10); + let size = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(13), + }; + let vectorized_len = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(14), + }; + + let instruction = IntermediateInstruction::RequestMemory { + offset: offset.clone(), + size: size.clone(), + vectorized: true, + vectorized_len: vectorized_len.clone(), + }; + + if let IntermediateInstruction::RequestMemory { + offset: o, + size: s, + vectorized, + vectorized_len: vl, + } = instruction + { + assert_eq!(o, offset); + assert_eq!(s, size); + assert!(vectorized); + assert_eq!(vl, vectorized_len); + } else { + panic!("Expected RequestMemory variant"); + } + } + + #[test] + fn test_decompose_bits_instruction() { + let res_offset = 20; + let to_decompose = vec![ + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(20), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(21), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(22), + }, + ]; + + let 
instruction = IntermediateInstruction::DecomposeBits { + res_offset, + to_decompose: to_decompose.clone(), + }; + + if let IntermediateInstruction::DecomposeBits { + res_offset: r, + to_decompose: td, + } = instruction + { + assert_eq!(r, res_offset); + assert_eq!(td, to_decompose); + assert_eq!(td.len(), 3); + } else { + panic!("Expected DecomposeBits variant"); + } + } + + #[test] + fn test_counter_hint_instruction() { + let res_offset = 15; + + let instruction = IntermediateInstruction::CounterHint { res_offset }; + + if let IntermediateInstruction::CounterHint { res_offset: r } = instruction { + assert_eq!(r, res_offset); + } else { + panic!("Expected CounterHint variant"); + } + } + + #[test] + fn test_print_instruction() { + let line_info = "line 42".to_string(); + let content = vec![ + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(23), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(24), + }, + ]; + + let instruction = IntermediateInstruction::Print { + line_info: line_info.clone(), + content: content.clone(), + }; + + if let IntermediateInstruction::Print { + line_info: li, + content: c, + } = instruction + { + assert_eq!(li, line_info); + assert_eq!(c, content); + assert_eq!(c.len(), 2); + } else { + panic!("Expected Print variant"); + } + } + + #[test] + fn test_location_report_instruction() { + let location: SourceLineNumber = 123; + + let instruction = IntermediateInstruction::LocationReport { location }; + + if let IntermediateInstruction::LocationReport { location: l } = instruction { + assert_eq!(l, location); + } else { + panic!("Expected LocationReport variant"); + } + } + + #[test] + fn test_instruction_clone() { + let original = IntermediateInstruction::Panic; + let cloned = original.clone(); + assert!(matches!(cloned, IntermediateInstruction::Panic)); + + let complex_original = IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + 
offset: ConstExpression::scalar(1), + }, + arg_c: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(3), + }, + }; + let complex_cloned = complex_original.clone(); + + if let ( + IntermediateInstruction::Computation { + operation: op1, + arg_a: a1, + arg_c: c1, + res: r1, + }, + IntermediateInstruction::Computation { + operation: op2, + arg_a: a2, + arg_c: c2, + res: r2, + }, + ) = (&complex_original, &complex_cloned) + { + assert_eq!(op1, op2); + assert_eq!(a1, a2); + assert_eq!(c1, c2); + assert_eq!(r1, r2); + } + } + + #[test] + fn test_instruction_debug_format() { + let instruction = IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }, + arg_c: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + }; + + let debug_output = format!("{:?}", instruction); + assert!(debug_output.contains("Computation")); + assert!(debug_output.contains("Add")); + assert!(debug_output.contains("1")); + assert!(debug_output.contains("2")); + assert!(debug_output.contains("26")); + } + + #[test] + fn test_display_computation() { + let instruction = IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }, + arg_c: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "m[fp + 26] = m[fp + 1] + m[fp + 2]"); + } + + #[test] + fn test_display_deref() { + let instruction = IntermediateInstruction::Deref { + shift_0: ConstExpression::scalar(5), + shift_1: ConstExpression::scalar(10), + res: 
IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(15), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "m[fp + 15] = m[m[fp + 5] + 10]"); + } + + #[test] + fn test_display_panic() { + let instruction = IntermediateInstruction::Panic; + let display_output = format!("{}", instruction); + assert_eq!(display_output, "panic"); + } + + #[test] + fn test_display_jump_without_fp() { + let instruction = IntermediateInstruction::Jump { + dest: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }, + updated_fp: None, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "jump m[fp + 7]"); + } + + #[test] + fn test_display_jump_with_fp() { + let instruction = IntermediateInstruction::Jump { + dest: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }, + updated_fp: Some(IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(8), + }), + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "jump m[fp + 7] with fp = m[fp + 8]"); + } + + #[test] + fn test_display_jump_if_not_zero_without_fp() { + let instruction = IntermediateInstruction::JumpIfNotZero { + condition: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(4), + }, + dest: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }, + updated_fp: None, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "jump_if_not_zero m[fp + 4] to m[fp + 7]"); + } + + #[test] + fn test_display_jump_if_not_zero_with_fp() { + let instruction = IntermediateInstruction::JumpIfNotZero { + condition: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(4), + }, + dest: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(7), + }, + updated_fp: Some(IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(8), + }), + }; + + let 
display_output = format!("{}", instruction); + assert_eq!( + display_output, + "jump_if_not_zero m[fp + 4] to m[fp + 7] with fp = m[fp + 8]" + ); + } + + #[test] + fn test_display_poseidon2_16() { + let instruction = IntermediateInstruction::Poseidon2_16 { + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(27), + }, + arg_b: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(28), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(29), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "m[fp + 29] = poseidon2_16(m[fp + 27], m[fp + 28])" + ); + } + + #[test] + fn test_display_poseidon2_24() { + let instruction = IntermediateInstruction::Poseidon2_24 { + arg_a: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(27), + }, + arg_b: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(28), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(29), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "m[fp + 29] = poseidon2_24(m[fp + 27], m[fp + 28])" + ); + } + + #[test] + fn test_display_dot_product() { + let instruction = IntermediateInstruction::DotProduct { + arg0: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(30), + }, + arg1: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(31), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + size: ConstExpression::scalar(8), + }; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "dot_product(m[fp + 30], m[fp + 31], m[fp + 26], 8)" + ); + } + + #[test] + fn test_display_multilinear_eval() { + let instruction = IntermediateInstruction::MultilinearEval { + coeffs: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(32), + }, + point: IntermediateValue::MemoryAfterFp { + 
offset: ConstExpression::scalar(33), + }, + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + n_vars: ConstExpression::scalar(4), + }; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "multilinear_eval(m[fp + 32], m[fp + 33], m[fp + 26], 4)" + ); + } + + #[test] + fn test_display_inverse() { + let instruction = IntermediateInstruction::Inverse { + arg: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(34), + }, + res_offset: 42, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "m[fp + 42] = inverse(m[fp + 34])"); + } + + #[test] + fn test_display_request_memory_non_vectorized() { + let instruction = IntermediateInstruction::RequestMemory { + offset: ConstExpression::scalar(10), + size: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(35), + }, + vectorized: false, + vectorized_len: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(36), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "m[fp + 10] = request_memory(m[fp + 35])"); + } + + #[test] + fn test_display_request_memory_vectorized() { + let instruction = IntermediateInstruction::RequestMemory { + offset: ConstExpression::scalar(10), + size: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(35), + }, + vectorized: true, + vectorized_len: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(36), + }, + }; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "m[fp + 10] = request_memory_vec(m[fp + 35], m[fp + 36])" + ); + } + + #[test] + fn test_display_decompose_bits() { + let instruction = IntermediateInstruction::DecomposeBits { + res_offset: 20, + to_decompose: vec![ + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(37), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(38), + }, + ], + 
}; + + let display_output = format!("{}", instruction); + assert_eq!( + display_output, + "m[fp + 20..] = decompose_bits(m[fp + 37], m[fp + 38])" + ); + } + + #[test] + fn test_display_counter_hint() { + let instruction = IntermediateInstruction::CounterHint { res_offset: 15 }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "m[fp + 15] = counter_hint()"); + } + + #[test] + fn test_display_print() { + let instruction = IntermediateInstruction::Print { + line_info: "line 42".to_string(), + content: vec![ + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(23), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(24), + }, + ], + }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, "print line 42: m[fp + 23], m[fp + 24]"); + } + + #[test] + fn test_display_location_report() { + let instruction = IntermediateInstruction::LocationReport { location: 123 }; + + let display_output = format!("{}", instruction); + assert_eq!(display_output, ""); + } + + #[test] + #[should_panic(expected = "unreachable")] + fn test_computation_exp_unreachable() { + IntermediateInstruction::computation( + HighLevelOperation::Exp, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + ); + } + + #[test] + #[should_panic(expected = "unreachable")] + fn test_computation_mod_unreachable() { + IntermediateInstruction::computation( + HighLevelOperation::Mod, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(2), + }, + IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(26), + }, + ); + } +} diff --git a/crates/lean_compiler/src/ir/operation.rs b/crates/lean_compiler/src/ir/operation.rs index 
39325468..676eaf12 100644 --- a/crates/lean_compiler/src/ir/operation.rs +++ b/crates/lean_compiler/src/ir/operation.rs @@ -62,3 +62,413 @@ impl TryFrom for Operation { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::F; + + #[test] + fn test_high_level_operation_enum_variants() { + let add = HighLevelOperation::Add; + let mul = HighLevelOperation::Mul; + let sub = HighLevelOperation::Sub; + let div = HighLevelOperation::Div; + let exp = HighLevelOperation::Exp; + let mod_op = HighLevelOperation::Mod; + + assert_eq!(add, HighLevelOperation::Add); + assert_eq!(mul, HighLevelOperation::Mul); + assert_eq!(sub, HighLevelOperation::Sub); + assert_eq!(div, HighLevelOperation::Div); + assert_eq!(exp, HighLevelOperation::Exp); + assert_eq!(mod_op, HighLevelOperation::Mod); + } + + #[test] + fn test_high_level_operation_clone() { + let original = HighLevelOperation::Add; + let cloned = original.clone(); + assert_eq!(original, cloned); + } + + #[test] + fn test_high_level_operation_copy() { + let original = HighLevelOperation::Mul; + let copied = original; + assert_eq!(original, copied); + } + + #[test] + fn test_high_level_operation_partial_eq() { + assert_eq!(HighLevelOperation::Add, HighLevelOperation::Add); + assert_ne!(HighLevelOperation::Add, HighLevelOperation::Mul); + } + + #[test] + fn test_high_level_operation_partial_ord() { + let add = HighLevelOperation::Add; + let mul = HighLevelOperation::Mul; + + assert!(add <= mul); + assert!(add < mul); + assert!(mul > add); + assert!(mul >= add); + } + + #[test] + fn test_high_level_operation_hash() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + let mut hasher2 = DefaultHasher::new(); + + HighLevelOperation::Add.hash(&mut hasher1); + HighLevelOperation::Add.hash(&mut hasher2); + + assert_eq!(hasher1.finish(), hasher2.finish()); + } + + #[test] + fn test_eval_add_operation() { + let op = HighLevelOperation::Add; + let a = 
F::from_usize(5); + let b = F::from_usize(3); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(8)); + } + + #[test] + fn test_eval_mul_operation() { + let op = HighLevelOperation::Mul; + let a = F::from_usize(4); + let b = F::from_usize(7); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(28)); + } + + #[test] + fn test_eval_sub_operation() { + let op = HighLevelOperation::Sub; + let a = F::from_usize(10); + let b = F::from_usize(3); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(7)); + } + + #[test] + fn test_eval_div_operation() { + let op = HighLevelOperation::Div; + let a = F::from_usize(15); + let b = F::from_usize(3); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(5)); + } + + #[test] + fn test_eval_exp_operation() { + let op = HighLevelOperation::Exp; + let a = F::from_usize(2); + let b = F::from_usize(3); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(8)); + } + + #[test] + fn test_eval_mod_operation() { + let op = HighLevelOperation::Mod; + let a = F::from_usize(17); + let b = F::from_usize(5); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(2)); + } + + #[test] + fn test_eval_add_with_zero() { + let op = HighLevelOperation::Add; + let a = F::from_usize(42); + let b = F::ZERO; + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(42)); + } + + #[test] + fn test_eval_mul_with_zero() { + let op = HighLevelOperation::Mul; + let a = F::from_usize(42); + let b = F::ZERO; + let result = op.eval(a, b); + assert_eq!(result, F::ZERO); + } + + #[test] + fn test_eval_mul_with_one() { + let op = HighLevelOperation::Mul; + let a = F::from_usize(42); + let b = F::ONE; + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(42)); + } + + #[test] + fn test_eval_exp_with_zero_exponent() { + let op = HighLevelOperation::Exp; + let a = F::from_usize(42); + let b = F::ZERO; + let result = op.eval(a, b); + assert_eq!(result, F::ONE); + } + + #[test] + fn 
test_eval_exp_with_one_exponent() { + let op = HighLevelOperation::Exp; + let a = F::from_usize(42); + let b = F::ONE; + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(42)); + } + + #[test] + fn test_display_add() { + let op = HighLevelOperation::Add; + assert_eq!(format!("{}", op), "+"); + } + + #[test] + fn test_display_mul() { + let op = HighLevelOperation::Mul; + assert_eq!(format!("{}", op), "*"); + } + + #[test] + fn test_display_sub() { + let op = HighLevelOperation::Sub; + assert_eq!(format!("{}", op), "-"); + } + + #[test] + fn test_display_div() { + let op = HighLevelOperation::Div; + assert_eq!(format!("{}", op), "/"); + } + + #[test] + fn test_display_exp() { + let op = HighLevelOperation::Exp; + assert_eq!(format!("{}", op), "**"); + } + + #[test] + fn test_display_mod() { + let op = HighLevelOperation::Mod; + assert_eq!(format!("{}", op), "%"); + } + + #[test] + fn test_try_from_add_success() { + let high_level = HighLevelOperation::Add; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_ok()); + assert_eq!(vm_op.unwrap(), Operation::Add); + } + + #[test] + fn test_try_from_mul_success() { + let high_level = HighLevelOperation::Mul; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_ok()); + assert_eq!(vm_op.unwrap(), Operation::Mul); + } + + #[test] + fn test_try_from_sub_failure() { + let high_level = HighLevelOperation::Sub; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_err()); + assert_eq!(vm_op.unwrap_err(), "Cannot convert Sub to +/x"); + } + + #[test] + fn test_try_from_div_failure() { + let high_level = HighLevelOperation::Div; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_err()); + assert_eq!(vm_op.unwrap_err(), "Cannot convert Div to +/x"); + } + + #[test] + fn test_try_from_exp_failure() { + let high_level = HighLevelOperation::Exp; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_err()); + assert_eq!(vm_op.unwrap_err(), "Cannot convert 
Exp to +/x"); + } + + #[test] + fn test_try_from_mod_failure() { + let high_level = HighLevelOperation::Mod; + let vm_op = Operation::try_from(high_level); + assert!(vm_op.is_err()); + assert_eq!(vm_op.unwrap_err(), "Cannot convert Mod to +/x"); + } + + #[test] + fn test_debug_format_add() { + let op = HighLevelOperation::Add; + assert_eq!(format!("{:?}", op), "Add"); + } + + #[test] + fn test_debug_format_mul() { + let op = HighLevelOperation::Mul; + assert_eq!(format!("{:?}", op), "Mul"); + } + + #[test] + fn test_debug_format_sub() { + let op = HighLevelOperation::Sub; + assert_eq!(format!("{:?}", op), "Sub"); + } + + #[test] + fn test_debug_format_div() { + let op = HighLevelOperation::Div; + assert_eq!(format!("{:?}", op), "Div"); + } + + #[test] + fn test_debug_format_exp() { + let op = HighLevelOperation::Exp; + assert_eq!(format!("{:?}", op), "Exp"); + } + + #[test] + fn test_debug_format_mod() { + let op = HighLevelOperation::Mod; + assert_eq!(format!("{:?}", op), "Mod"); + } + + #[test] + fn test_eval_large_numbers() { + let op = HighLevelOperation::Add; + let a = F::from_usize(1000000); + let b = F::from_usize(2000000); + let result = op.eval(a, b); + assert_eq!(result, F::from_usize(3000000)); + } + + #[test] + fn test_eval_edge_case_mod_by_one() { + let op = HighLevelOperation::Mod; + let a = F::from_usize(42); + let b = F::ONE; + let result = op.eval(a, b); + assert_eq!(result, F::ZERO); + } + + #[test] + fn test_operation_ordering_consistency() { + let operations = [ + HighLevelOperation::Add, + HighLevelOperation::Mul, + HighLevelOperation::Sub, + HighLevelOperation::Div, + HighLevelOperation::Exp, + HighLevelOperation::Mod, + ]; + + for i in 0..operations.len() { + for j in 0..operations.len() { + if i < j { + assert!(operations[i] < operations[j]); + } else if i == j { + assert!(operations[i] == operations[j]); + } else { + assert!(operations[i] > operations[j]); + } + } + } + } + + #[test] + fn test_eval_commutativity_add() { + let op = 
HighLevelOperation::Add; + let a = F::from_usize(7); + let b = F::from_usize(13); + + let result1 = op.eval(a, b); + let result2 = op.eval(b, a); + assert_eq!(result1, result2); + } + + #[test] + fn test_eval_commutativity_mul() { + let op = HighLevelOperation::Mul; + let a = F::from_usize(7); + let b = F::from_usize(13); + + let result1 = op.eval(a, b); + let result2 = op.eval(b, a); + assert_eq!(result1, result2); + } + + #[test] + fn test_eval_non_commutativity_sub() { + let op = HighLevelOperation::Sub; + let a = F::from_usize(10); + let b = F::from_usize(3); + + let result1 = op.eval(a, b); + let result2 = op.eval(b, a); + assert_ne!(result1, result2); + } + + #[test] + fn test_eval_non_commutativity_div() { + let op = HighLevelOperation::Div; + let a = F::from_usize(15); + let b = F::from_usize(3); + + let result1 = op.eval(a, b); + let result2 = op.eval(b, a); + assert_ne!(result1, result2); + } + + #[test] + fn test_all_variants_covered_by_display() { + let operations = [ + HighLevelOperation::Add, + HighLevelOperation::Mul, + HighLevelOperation::Sub, + HighLevelOperation::Div, + HighLevelOperation::Exp, + HighLevelOperation::Mod, + ]; + + let expected_displays = ["+", "*", "-", "/", "**", "%"]; + + for (op, expected) in operations.iter().zip(expected_displays.iter()) { + assert_eq!(format!("{}", op), *expected); + } + } + + #[test] + fn test_all_variants_covered_by_eval() { + let operations = [ + HighLevelOperation::Add, + HighLevelOperation::Mul, + HighLevelOperation::Sub, + HighLevelOperation::Div, + HighLevelOperation::Exp, + HighLevelOperation::Mod, + ]; + + let a = F::from_usize(10); + let b = F::from_usize(3); + + for op in operations.iter() { + let _result = op.eval(a, b); + } + } +} diff --git a/crates/lean_compiler/src/ir/value.rs b/crates/lean_compiler/src/ir/value.rs index 9a7d61b5..d9b0fa18 100644 --- a/crates/lean_compiler/src/ir/value.rs +++ b/crates/lean_compiler/src/ir/value.rs @@ -70,3 +70,498 @@ impl Display for 
IntermediaryMemOrFpOrConstant { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::ConstExpression; + use lean_vm::Label; + + #[test] + fn test_intermediate_value_constant_variant() { + let const_expr = ConstExpression::scalar(42); + let value = IntermediateValue::Constant(const_expr.clone()); + + if let IntermediateValue::Constant(inner) = &value { + assert_eq!(inner, &const_expr); + } else { + panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediate_value_fp_variant() { + let value = IntermediateValue::Fp; + assert_eq!(value, IntermediateValue::Fp); + } + + #[test] + fn test_intermediate_value_memory_after_fp_variant() { + let offset = ConstExpression::scalar(10); + let value = IntermediateValue::MemoryAfterFp { + offset: offset.clone(), + }; + + if let IntermediateValue::MemoryAfterFp { + offset: inner_offset, + } = &value + { + assert_eq!(inner_offset, &offset); + } else { + panic!("Expected MemoryAfterFp variant"); + } + } + + #[test] + fn test_intermediate_value_clone() { + let original = IntermediateValue::Constant(ConstExpression::scalar(42)); + let cloned = original.clone(); + assert_eq!(original, cloned); + } + + #[test] + fn test_intermediate_value_partial_eq() { + let value1 = IntermediateValue::Constant(ConstExpression::scalar(42)); + let value2 = IntermediateValue::Constant(ConstExpression::scalar(42)); + let value3 = IntermediateValue::Constant(ConstExpression::scalar(43)); + + assert_eq!(value1, value2); + assert_ne!(value1, value3); + assert_ne!(value1, IntermediateValue::Fp); + } + + #[test] + fn test_intermediate_value_label_method() { + let label = Label::function("test_function"); + let value = IntermediateValue::label(label.clone()); + + if let IntermediateValue::Constant(const_expr) = &value { + if let ConstExpression::Value(crate::lang::ConstantValue::Label(inner_label)) = + const_expr + { + assert_eq!(inner_label, &label); + } else { + panic!("Expected Label within ConstExpression"); + } + } else { 
+ panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediate_value_is_constant_true() { + let value = IntermediateValue::Constant(ConstExpression::scalar(42)); + assert!(value.is_constant()); + } + + #[test] + fn test_intermediate_value_is_constant_false_fp() { + let value = IntermediateValue::Fp; + assert!(!value.is_constant()); + } + + #[test] + fn test_intermediate_value_is_constant_false_memory() { + let value = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + assert!(!value.is_constant()); + } + + #[test] + fn test_intermediate_value_from_const_expression() { + let const_expr = ConstExpression::scalar(42); + let value: IntermediateValue = const_expr.clone().into(); + + if let IntermediateValue::Constant(inner) = &value { + assert_eq!(inner, &const_expr); + } else { + panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediate_value_from_label() { + let label = Label::function("test_function"); + let value: IntermediateValue = label.clone().into(); + + if let IntermediateValue::Constant(const_expr) = &value { + if let ConstExpression::Value(crate::lang::ConstantValue::Label(inner_label)) = + const_expr + { + assert_eq!(inner_label, &label); + } else { + panic!("Expected Label within ConstExpression"); + } + } else { + panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediate_value_display_constant() { + let value = IntermediateValue::Constant(ConstExpression::scalar(42)); + assert_eq!(format!("{}", value), "42"); + } + + #[test] + fn test_intermediate_value_display_fp() { + let value = IntermediateValue::Fp; + assert_eq!(format!("{}", value), "fp"); + } + + #[test] + fn test_intermediate_value_display_memory_after_fp() { + let value = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + assert_eq!(format!("{}", value), "m[fp + 10]"); + } + + #[test] + fn test_intermediate_value_debug_format() { + let value = 
IntermediateValue::Constant(ConstExpression::scalar(42)); + let debug_output = format!("{:?}", value); + + assert!(debug_output.contains("Constant")); + assert!(debug_output.contains("42")); + } + + #[test] + fn test_intermediate_value_equality_different_variants() { + let constant = IntermediateValue::Constant(ConstExpression::scalar(42)); + let fp = IntermediateValue::Fp; + let memory = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(42), + }; + + assert_ne!(constant, fp); + assert_ne!(constant, memory); + assert_ne!(fp, memory); + } + + #[test] + fn test_intermediate_value_equality_same_memory_different_offset() { + let memory1 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let memory2 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(20), + }; + + assert_ne!(memory1, memory2); + } + + #[test] + fn test_intermediate_value_equality_same_memory_same_offset() { + let memory1 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let memory2 = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + + assert_eq!(memory1, memory2); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_memory_after_fp() { + let offset = ConstExpression::scalar(10); + let value = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: offset.clone(), + }; + + if let IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: inner_offset, + } = &value + { + assert_eq!(inner_offset, &offset); + } else { + panic!("Expected MemoryAfterFp variant"); + } + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_fp() { + let value = IntermediaryMemOrFpOrConstant::Fp; + assert_eq!(value, IntermediaryMemOrFpOrConstant::Fp); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_constant() { + let const_expr = ConstExpression::scalar(42); + let value = IntermediaryMemOrFpOrConstant::Constant(const_expr.clone()); + + if let 
IntermediaryMemOrFpOrConstant::Constant(inner) = &value { + assert_eq!(inner, &const_expr); + } else { + panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_clone() { + let original = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + let cloned = original.clone(); + assert_eq!(original, cloned); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_partial_eq() { + let value1 = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + let value2 = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + let value3 = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(43)); + + assert_eq!(value1, value2); + assert_ne!(value1, value3); + assert_ne!(value1, IntermediaryMemOrFpOrConstant::Fp); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_partial_ord() { + let memory = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let fp = IntermediaryMemOrFpOrConstant::Fp; + let constant = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + + assert!(memory < fp); + assert!(fp < constant); + assert!(memory < constant); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_ord() { + use std::cmp::Ordering; + + let memory = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let fp = IntermediaryMemOrFpOrConstant::Fp; + + assert_eq!(memory.cmp(&fp), Ordering::Less); + assert_eq!(fp.cmp(&memory), Ordering::Greater); + assert_eq!(memory.cmp(&memory), Ordering::Equal); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_hash() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + let mut hasher2 = DefaultHasher::new(); + + let value = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + value.hash(&mut hasher1); + value.hash(&mut hasher2); + 
+ assert_eq!(hasher1.finish(), hasher2.finish()); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_display_memory() { + let value = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + assert_eq!(format!("{}", value), "m[fp + 10]"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_display_fp() { + let value = IntermediaryMemOrFpOrConstant::Fp; + assert_eq!(format!("{}", value), "fp"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_display_constant() { + let value = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + assert_eq!(format!("{}", value), "42"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_debug_format() { + let value = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + let debug_output = format!("{:?}", value); + + assert!(debug_output.contains("Constant")); + assert!(debug_output.contains("42")); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_equality_different_variants() { + let memory = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let fp = IntermediaryMemOrFpOrConstant::Fp; + let constant = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + + assert_ne!(memory, fp); + assert_ne!(memory, constant); + assert_ne!(fp, constant); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_equality_same_memory_different_offset() { + let memory1 = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let memory2 = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(20), + }; + + assert_ne!(memory1, memory2); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_equality_same_memory_same_offset() { + let memory1 = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let memory2 = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + 
offset: ConstExpression::scalar(10), + }; + + assert_eq!(memory1, memory2); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_ordering_consistency() { + let values = [ + IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(5), + }, + IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }, + IntermediaryMemOrFpOrConstant::Fp, + IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(1)), + IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)), + ]; + + for i in 0..values.len() { + for j in 0..values.len() { + if i < j { + assert!(values[i] < values[j], "Expected {} < {}", i, j); + } else if i == j { + assert!(values[i] == values[j], "Expected {} == {}", i, j); + } else { + assert!(values[i] > values[j], "Expected {} > {}", i, j); + } + } + } + } + + #[test] + fn test_intermediate_value_all_variants_display() { + let constant = IntermediateValue::Constant(ConstExpression::scalar(42)); + let fp = IntermediateValue::Fp; + let memory = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + + assert_eq!(format!("{}", constant), "42"); + assert_eq!(format!("{}", fp), "fp"); + assert_eq!(format!("{}", memory), "m[fp + 10]"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_all_variants_display() { + let memory = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(10), + }; + let fp = IntermediaryMemOrFpOrConstant::Fp; + let constant = IntermediaryMemOrFpOrConstant::Constant(ConstExpression::scalar(42)); + + assert_eq!(format!("{}", memory), "m[fp + 10]"); + assert_eq!(format!("{}", fp), "fp"); + assert_eq!(format!("{}", constant), "42"); + } + + #[test] + fn test_intermediate_value_label_with_complex_name() { + let label = Label::function("complex_function_name_123"); + let value = IntermediateValue::label(label.clone()); + + if let IntermediateValue::Constant(const_expr) = &value { + if let 
ConstExpression::Value(crate::lang::ConstantValue::Label(inner_label)) = + const_expr + { + assert_eq!(inner_label, &label); + } else { + panic!("Expected Label within ConstExpression"); + } + } else { + panic!("Expected Constant variant"); + } + } + + #[test] + fn test_intermediate_value_memory_with_zero_offset() { + let value = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(0), + }; + assert_eq!(format!("{}", value), "m[fp + 0]"); + } + + #[test] + fn test_intermediate_value_memory_with_large_offset() { + let value = IntermediateValue::MemoryAfterFp { + offset: ConstExpression::scalar(1000000), + }; + assert_eq!(format!("{}", value), "m[fp + 1000000]"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_memory_with_zero_offset() { + let value = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(0), + }; + assert_eq!(format!("{}", value), "m[fp + 0]"); + } + + #[test] + fn test_intermediary_mem_or_fp_or_constant_memory_with_large_offset() { + let value = IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: ConstExpression::scalar(1000000), + }; + assert_eq!(format!("{}", value), "m[fp + 1000000]"); + } + + #[test] + fn test_intermediate_value_from_implementations() { + let const_expr = ConstExpression::scalar(42); + let label = Label::function("test"); + + let from_const: IntermediateValue = const_expr.clone().into(); + let from_label: IntermediateValue = label.clone().into(); + + assert_eq!(from_const, IntermediateValue::Constant(const_expr)); + + if let IntermediateValue::Constant(ConstExpression::Value( + crate::lang::ConstantValue::Label(inner_label), + )) = from_label + { + assert_eq!(inner_label, label); + } else { + panic!("Expected Label constant"); + } + } + + #[test] + fn test_value_types_exhaustive_match_coverage() { + let intermediate_value = IntermediateValue::Fp; + match intermediate_value { + IntermediateValue::Constant(_) => {} + IntermediateValue::Fp => {} + 
IntermediateValue::MemoryAfterFp { .. } => {} + } + + let intermediary_value = IntermediaryMemOrFpOrConstant::Fp; + match intermediary_value { + IntermediaryMemOrFpOrConstant::MemoryAfterFp { .. } => {} + IntermediaryMemOrFpOrConstant::Fp => {} + IntermediaryMemOrFpOrConstant::Constant(_) => {} + } + } +} diff --git a/crates/lean_compiler/src/simplify/simplify.rs b/crates/lean_compiler/src/simplify/simplify.rs index 6b6ededa..df52b85b 100644 --- a/crates/lean_compiler/src/simplify/simplify.rs +++ b/crates/lean_compiler/src/simplify/simplify.rs @@ -756,3 +756,821 @@ fn create_recursive_function( instructions, } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ir::HighLevelOperation, lang::*, simplify::types::*}; + + fn create_test_counters() -> Counters { + Counters::default() + } + + fn create_test_array_manager() -> ArrayManager { + ArrayManager::default() + } + + fn create_test_const_malloc() -> ConstMalloc { + ConstMalloc::default() + } + + #[test] + fn test_simplify_lines_match() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Match { + value: Expression::Value(SimpleExpr::Var("x".to_string())), + arms: vec![(0, vec![Line::Panic]), (1, vec![Line::Break])], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + true, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + assert!(matches!(result[0], SimpleLine::Match { .. 
})); + + if let SimpleLine::Match { value, arms } = &result[0] { + assert_eq!(value, &SimpleExpr::Var("x".to_string())); + assert_eq!(arms.len(), 2); + assert_eq!(arms[0], vec![SimpleLine::Panic]); + assert_eq!( + arms[1], + vec![SimpleLine::FunctionRet { + return_data: vec![] + }] + ); + } + } + + #[test] + fn test_simplify_lines_assignment_value() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Assignment { + var: "x".to_string(), + value: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(42))), + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = &result[0] + { + assert_eq!(var, &VarOrConstMallocAccess::Var("x".to_string())); + assert_eq!(operation, &HighLevelOperation::Add); + assert_eq!(arg0, &SimpleExpr::Constant(ConstExpression::scalar(42))); + assert_eq!(arg1, &SimpleExpr::zero()); + } + } + + #[test] + fn test_simplify_lines_assignment_binary() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Assignment { + var: "result".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("a".to_string()))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("b".to_string()))), + }, + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = 
&result[0] + { + assert_eq!(var, &VarOrConstMallocAccess::Var("result".to_string())); + assert_eq!(operation, &HighLevelOperation::Add); + assert_eq!(arg0, &SimpleExpr::Var("a".to_string())); + assert_eq!(arg1, &SimpleExpr::Var("b".to_string())); + } + } + + #[test] + fn test_simplify_lines_assert_equal() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Assert(Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("x".to_string())), + right: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(5))), + })]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = &result[0] + { + assert_eq!(var, &VarOrConstMallocAccess::Var("x".to_string())); + assert_eq!(operation, &HighLevelOperation::Add); + assert_eq!(arg0, &SimpleExpr::Constant(ConstExpression::scalar(5))); + assert_eq!(arg1, &SimpleExpr::zero()); + } + } + + #[test] + fn test_simplify_lines_assert_different() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Assert(Boolean::Different { + left: Expression::Value(SimpleExpr::Var("x".to_string())), + right: Expression::Value(SimpleExpr::Var("y".to_string())), + })]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 2); + + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = &result[0] + { + assert!(var.to_string().starts_with("@aux_var_")); + assert_eq!(operation, 
&HighLevelOperation::Sub); + assert_eq!(arg0, &SimpleExpr::Var("x".to_string())); + assert_eq!(arg1, &SimpleExpr::Var("y".to_string())); + } + + if let SimpleLine::IfNotZero { + condition, + then_branch, + else_branch, + } = &result[1] + { + assert!(condition.to_string().starts_with("@aux_var_")); + assert_eq!(then_branch.len(), 0); + assert_eq!(else_branch, &vec![SimpleLine::Panic]); + } + } + + #[test] + fn test_simplify_lines_function_call() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::FunctionCall { + function_name: "foo".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Var("x".to_string())), + Expression::Value(SimpleExpr::Var("y".to_string())), + ], + return_data: vec!["result".to_string()], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::FunctionCall { + function_name, + args, + return_data, + } = &result[0] + { + assert_eq!(function_name, "foo"); + assert_eq!(args.len(), 2); + assert_eq!(args[0], SimpleExpr::Var("x".to_string())); + assert_eq!(args[1], SimpleExpr::Var("y".to_string())); + assert_eq!(return_data, &vec!["result".to_string()]); + } + } + + #[test] + fn test_simplify_lines_function_ret() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::FunctionRet { + return_data: vec![ + Expression::Value(SimpleExpr::Var("x".to_string())), + Expression::Value(SimpleExpr::Var("y".to_string())), + ], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + 
assert_eq!(result.len(), 1); + if let SimpleLine::FunctionRet { return_data } = &result[0] { + assert_eq!(return_data.len(), 2); + assert_eq!(return_data[0], SimpleExpr::Var("x".to_string())); + assert_eq!(return_data[1], SimpleExpr::Var("y".to_string())); + } + } + + #[test] + #[should_panic(expected = "Function return inside a loop is not currently supported")] + fn test_simplify_lines_function_ret_in_loop() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::FunctionRet { + return_data: vec![Expression::Value(SimpleExpr::Var("x".to_string()))], + }]; + + simplify_lines( + &lines, + &mut counters, + &mut new_functions, + true, // in_a_loop = true + &mut array_manager, + &mut const_malloc, + ); + } + + #[test] + fn test_simplify_lines_precompile() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Precompile { + precompile: crate::precompiles::POSEIDON_16, + args: vec![ + Expression::Value(SimpleExpr::Var("input".to_string())), + Expression::Value(SimpleExpr::Var("size".to_string())), + ], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::Precompile { precompile, args } = &result[0] { + assert_eq!(precompile, &crate::precompiles::POSEIDON_16); + assert_eq!(args.len(), 2); + assert_eq!(args[0], SimpleExpr::Var("input".to_string())); + assert_eq!(args[1], SimpleExpr::Var("size".to_string())); + } + } + + #[test] + fn test_simplify_lines_print() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = 
create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Print { + line_info: "123".to_string(), + content: vec![ + Expression::Value(SimpleExpr::Var("debug1".to_string())), + Expression::Value(SimpleExpr::Var("debug2".to_string())), + ], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::Print { line_info, content } = &result[0] { + assert_eq!(line_info, "123"); + assert_eq!(content.len(), 2); + assert_eq!(content[0], SimpleExpr::Var("debug1".to_string())); + assert_eq!(content[1], SimpleExpr::Var("debug2".to_string())); + } + } + + #[test] + fn test_simplify_lines_break() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Break]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + true, // in_a_loop = true + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::FunctionRet { return_data } = &result[0] { + assert_eq!(return_data.len(), 0); + } + } + + #[test] + #[should_panic(expected = "Break statement outside of a loop")] + fn test_simplify_lines_break_outside_loop() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Break]; + + simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, // in_a_loop = false + &mut array_manager, + &mut const_malloc, + ); + } + + #[test] + fn test_simplify_lines_decompose_bits() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = 
create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::DecomposeBits { + var: "bits".to_string(), + to_decompose: vec![ + Expression::Value(SimpleExpr::Var("value1".to_string())), + Expression::Value(SimpleExpr::Var("value2".to_string())), + ], + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::DecomposeBits { + var, + to_decompose, + label, + } = &result[0] + { + assert_eq!(var, "bits"); + assert_eq!(to_decompose.len(), 2); + assert_eq!(to_decompose[0], SimpleExpr::Var("value1".to_string())); + assert_eq!(to_decompose[1], SimpleExpr::Var("value2".to_string())); + assert_eq!(label, &0); + } + } + + #[test] + fn test_simplify_lines_counter_hint() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::CounterHint { + var: "hint_var".to_string(), + }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::CounterHint { var } = &result[0] { + assert_eq!(var, "hint_var"); + } + } + + #[test] + fn test_simplify_lines_panic() { + let mut counters = create_test_counters(); + let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::Panic]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + assert_eq!(result[0], SimpleLine::Panic); + } + + #[test] + fn test_simplify_lines_location_report() { + let mut counters = create_test_counters(); 
+ let mut new_functions = BTreeMap::new(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + let lines = vec![Line::LocationReport { location: 456 }]; + + let result = simplify_lines( + &lines, + &mut counters, + &mut new_functions, + false, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(result.len(), 1); + if let SimpleLine::LocationReport { location } = &result[0] { + assert_eq!(location, &456); + } + } + + #[test] + fn test_simplify_expr_value() { + let mut lines = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let const_malloc = create_test_const_malloc(); + + let expr = Expression::Value(SimpleExpr::Var("x".to_string())); + let result = simplify_expr( + &expr, + &mut lines, + &mut counters, + &mut array_manager, + &const_malloc, + ); + + assert_eq!(result, SimpleExpr::Var("x".to_string())); + assert_eq!(lines.len(), 0); + } + + #[test] + fn test_simplify_expr_binary_constants() { + let mut lines = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let const_malloc = create_test_const_malloc(); + + let expr = Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(5), + ))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(3), + ))), + }; + + let result = simplify_expr( + &expr, + &mut lines, + &mut counters, + &mut array_manager, + &const_malloc, + ); + + if let SimpleExpr::Constant(ConstExpression::Binary { + left, + operation, + right, + }) = result + { + assert_eq!(left.as_ref(), &ConstExpression::scalar(5)); + assert_eq!(operation, HighLevelOperation::Add); + assert_eq!(right.as_ref(), &ConstExpression::scalar(3)); + } else { + panic!("Expected constant binary expression"); + } + assert_eq!(lines.len(), 0); + } + + #[test] + fn 
test_simplify_expr_binary_variables() { + let mut lines = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let const_malloc = create_test_const_malloc(); + + let expr = Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("y".to_string()))), + }; + + let result = simplify_expr( + &expr, + &mut lines, + &mut counters, + &mut array_manager, + &const_malloc, + ); + + if let SimpleExpr::Var(var_name) = result { + assert!(var_name.starts_with("@aux_var_")); + } else { + panic!("Expected variable"); + } + assert_eq!(lines.len(), 1); + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = &lines[0] + { + assert!(var.to_string().starts_with("@aux_var_")); + assert_eq!(operation, &HighLevelOperation::Add); + assert_eq!(arg0, &SimpleExpr::Var("x".to_string())); + assert_eq!(arg1, &SimpleExpr::Var("y".to_string())); + } + } + + #[test] + fn test_simplify_expr_log2ceil() { + let mut lines = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let const_malloc = create_test_const_malloc(); + + let expr = Expression::Log2Ceil { + value: Box::new(Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(8), + ))), + }; + + let result = simplify_expr( + &expr, + &mut lines, + &mut counters, + &mut array_manager, + &const_malloc, + ); + + if let SimpleExpr::Constant(ConstExpression::Log2Ceil { value }) = result { + assert_eq!(value.as_ref(), &ConstExpression::scalar(8)); + } else { + panic!("Expected constant log2ceil expression"); + } + assert_eq!(lines.len(), 0); + } + + #[test] + fn test_handle_malloc_const_size() { + let mut res = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + 
handle_malloc( + &"array".to_string(), + &Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(10))), + false, + &Expression::Value(SimpleExpr::zero()), + &mut res, + &mut counters, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(res.len(), 1); + if let SimpleLine::ConstMalloc { var, size, label } = &res[0] { + assert_eq!(var, "array"); + assert_eq!(size, &ConstExpression::scalar(10)); + assert_eq!(label, &0); + } + assert!(const_malloc.map.contains_key("array")); + assert_eq!(const_malloc.counter, 1); + } + + #[test] + fn test_handle_malloc_variable_size() { + let mut res = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + handle_malloc( + &"array".to_string(), + &Expression::Value(SimpleExpr::Var("size_var".to_string())), + false, + &Expression::Value(SimpleExpr::zero()), + &mut res, + &mut counters, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(res.len(), 1); + if let SimpleLine::HintMAlloc { + var, + size, + vectorized, + vectorized_len, + } = &res[0] + { + assert_eq!(var, "array"); + assert_eq!(size, &SimpleExpr::Var("size_var".to_string())); + assert_eq!(vectorized, &false); + assert_eq!(vectorized_len, &SimpleExpr::zero()); + } + assert!(!const_malloc.map.contains_key("array")); + } + + #[test] + fn test_handle_malloc_vectorized() { + let mut res = Vec::new(); + let mut counters = create_test_counters(); + let mut array_manager = create_test_array_manager(); + let mut const_malloc = create_test_const_malloc(); + + handle_malloc( + &"vec_array".to_string(), + &Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(5))), + true, // vectorized + &Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(4))), + &mut res, + &mut counters, + &mut array_manager, + &mut const_malloc, + ); + + assert_eq!(res.len(), 1); + if let SimpleLine::HintMAlloc { + var, + size, + vectorized, + 
vectorized_len, + } = &res[0] + { + assert_eq!(var, "vec_array"); + assert_eq!(size, &SimpleExpr::Constant(ConstExpression::scalar(5))); + assert_eq!(vectorized, &true); + assert_eq!( + vectorized_len, + &SimpleExpr::Constant(ConstExpression::scalar(4)) + ); + } + assert!(!const_malloc.map.contains_key("vec_array")); + } + + #[test] + fn test_create_recursive_function() { + let name = "@loop_0".to_string(); + let args = vec!["i".to_string(), "x".to_string(), "y".to_string()]; + let iterator = "i".to_string(); + let end = SimpleExpr::Constant(ConstExpression::scalar(10)); + let body = vec![SimpleLine::Assignment { + var: VarOrConstMallocAccess::Var("z".to_string()), + operation: HighLevelOperation::Add, + arg0: SimpleExpr::Var("x".to_string()), + arg1: SimpleExpr::Var("y".to_string()), + }]; + let external_vars = vec!["x".to_string(), "y".to_string()]; + + let result = create_recursive_function( + name.clone(), + args.clone(), + iterator.clone(), + end, + body, + &external_vars, + ); + + assert_eq!(result.name, name); + assert_eq!(result.arguments, args); + assert_eq!(result.n_returned_vars, 0); + assert_eq!(result.instructions.len(), 2); + + // Check first instruction (comparison) + if let SimpleLine::Assignment { + var, + operation, + arg0, + arg1, + } = &result.instructions[0] + { + assert_eq!(var.to_string(), "@diff_i"); + assert_eq!(operation, &HighLevelOperation::Sub); + assert_eq!(arg0, &SimpleExpr::Var("i".to_string())); + assert_eq!(arg1, &SimpleExpr::Constant(ConstExpression::scalar(10))); + } + + // Check second instruction (conditional) + if let SimpleLine::IfNotZero { + condition, + then_branch, + else_branch, + } = &result.instructions[1] + { + assert_eq!(condition.to_string(), "@diff_i"); + assert_eq!(then_branch.len(), 4); // body + increment + recursive call + return + assert_eq!(else_branch.len(), 1); // just return + + // Check else branch (termination condition) + if let SimpleLine::FunctionRet { return_data } = &else_branch[0] { + 
assert_eq!(return_data.len(), 0); + } + } + } +} diff --git a/crates/lean_compiler/src/simplify/transformations.rs b/crates/lean_compiler/src/simplify/transformations.rs index 9347f684..72c84d1b 100644 --- a/crates/lean_compiler/src/simplify/transformations.rs +++ b/crates/lean_compiler/src/simplify/transformations.rs @@ -490,3 +490,642 @@ fn inline_simple_expr( } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ir::HighLevelOperation, lang::*}; + + fn create_test_program() -> Program { + Program { + functions: BTreeMap::new(), + } + } + + fn create_simple_function(name: &str, inlined: bool) -> Function { + Function { + name: name.to_string(), + arguments: vec![("x".to_string(), false), ("y".to_string(), false)], + inlined, + body: vec![Line::Assignment { + var: "result".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("y".to_string()))), + }, + }], + n_returned_vars: 1, + } + } + + fn create_const_function(name: &str) -> Function { + Function { + name: name.to_string(), + arguments: vec![("x".to_string(), false), ("size".to_string(), true)], + inlined: false, + body: vec![Line::Assignment { + var: "result".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Var("size".to_string()))), + }, + }], + n_returned_vars: 1, + } + } + + #[test] + fn test_handle_inlined_functions_simple_inline() { + let mut program = create_test_program(); + + // Create an inlined function + let inlined_func = create_simple_function("inline_add", true); + program + .functions + .insert("inline_add".to_string(), inlined_func); + + // Create a function that calls the inlined function + let caller_func = Function { + name: "caller".to_string(), + arguments: vec![("a".to_string(), 
false), ("b".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "inline_add".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Var("a".to_string())), + Expression::Value(SimpleExpr::Var("b".to_string())), + ], + return_data: vec!["sum".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("caller".to_string(), caller_func); + + handle_inlined_functions(&mut program); + + // The inlined function should be removed + assert!(!program.functions.contains_key("inline_add")); + + // The caller function should have the inlined code + let caller = program.functions.get("caller").unwrap(); + assert_eq!(caller.body.len(), 1); + if let Line::Assignment { var, value } = &caller.body[0] { + assert!(var.starts_with("@inlined_var_")); + if let Expression::Binary { + left, + operation, + right, + } = value + { + assert_eq!(operation, &HighLevelOperation::Add); + if let Expression::Value(SimpleExpr::Var(left_var)) = left.as_ref() { + assert_eq!(left_var, "a"); + } + if let Expression::Value(SimpleExpr::Var(right_var)) = right.as_ref() { + assert_eq!(right_var, "b"); + } + } + } + } + + #[test] + fn test_handle_inlined_functions_complex_args() { + let mut program = create_test_program(); + + // Create an inlined function + let inlined_func = create_simple_function("inline_add", true); + program + .functions + .insert("inline_add".to_string(), inlined_func); + + // Create a function that calls the inlined function with complex arguments + let caller_func = Function { + name: "caller".to_string(), + arguments: vec![("a".to_string(), false), ("b".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "inline_add".to_string(), + args: vec![ + Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("a".to_string()))), + operation: HighLevelOperation::Mul, + right: Box::new(Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(2), + ))), + }, + 
Expression::Value(SimpleExpr::Var("b".to_string())), + ], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("caller".to_string(), caller_func); + + handle_inlined_functions(&mut program); + + // Check that auxiliary variables are created for complex arguments + let caller = program.functions.get("caller").unwrap(); + assert!(caller.body.len() >= 2); // At least one aux var assignment + the inlined body + + // First instruction should be auxiliary variable assignment for complex argument + if let Line::Assignment { var, value } = &caller.body[0] { + assert!(var.starts_with("@inlined_var_")); + if let Expression::Binary { operation, .. } = value { + assert_eq!(operation, &HighLevelOperation::Mul); + } + } + } + + #[test] + fn test_handle_inlined_functions_nested_inline() { + let mut program = create_test_program(); + + // Create first inlined function + let inline1 = Function { + name: "inline1".to_string(), + arguments: vec![("x".to_string(), false)], + inlined: true, + body: vec![Line::Assignment { + var: "temp".to_string(), + value: Expression::Binary { + left: Box::new(Expression::Value(SimpleExpr::Var("x".to_string()))), + operation: HighLevelOperation::Add, + right: Box::new(Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(1), + ))), + }, + }], + n_returned_vars: 1, + }; + program.functions.insert("inline1".to_string(), inline1); + + // Create second inlined function that calls the first + let inline2 = Function { + name: "inline2".to_string(), + arguments: vec![("y".to_string(), false)], + inlined: true, + body: vec![Line::FunctionCall { + function_name: "inline1".to_string(), + args: vec![Expression::Value(SimpleExpr::Var("y".to_string()))], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("inline2".to_string(), inline2); + + // Create main function that calls inline2 + let main_func = Function { + name: "main".to_string(), + arguments: 
vec![("input".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "inline2".to_string(), + args: vec![Expression::Value(SimpleExpr::Var("input".to_string()))], + return_data: vec!["output".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("main".to_string(), main_func); + + handle_inlined_functions(&mut program); + + // All inlined functions should be removed + assert!(!program.functions.contains_key("inline1")); + assert!(!program.functions.contains_key("inline2")); + + // Main function should contain the fully inlined code + let main = program.functions.get("main").unwrap(); + assert!(!main.body.is_empty()); + + // Should have the actual computation inlined + let has_add_op = main.body.iter().any(|line| { + if let Line::Assignment { value, .. } = line { + if let Expression::Binary { operation, .. } = value { + return *operation == HighLevelOperation::Add; + } + } + false + }); + assert!(has_add_op); + } + + #[test] + fn test_handle_inlined_functions_if_condition() { + let mut program = create_test_program(); + + // Create an inlined function + let inlined_func = create_simple_function("inline_add", true); + program + .functions + .insert("inline_add".to_string(), inlined_func); + + // Create a function that calls the inlined function inside an if condition + let caller_func = Function { + name: "caller".to_string(), + arguments: vec![("condition".to_string(), false)], + inlined: false, + body: vec![Line::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("condition".to_string())), + right: Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(1))), + }, + then_branch: vec![Line::FunctionCall { + function_name: "inline_add".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(5))), + Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(3))), + ], + return_data: vec!["result".to_string()], + }], + else_branch: 
vec![], + }], + n_returned_vars: 0, + }; + program.functions.insert("caller".to_string(), caller_func); + + handle_inlined_functions(&mut program); + + // Check that the function call is inlined inside the if condition + let caller = program.functions.get("caller").unwrap(); + if let Line::IfCondition { then_branch, .. } = &caller.body[0] { + assert!(!then_branch.is_empty()); + // Should have inlined assignment + if let Line::Assignment { .. } = &then_branch[0] { + // Good, the function was inlined + } else { + panic!("Expected inlined assignment in then branch"); + } + } else { + panic!("Expected if condition"); + } + } + + #[test] + #[should_panic(expected = "Too many iterations processing inline functions")] + fn test_handle_inlined_functions_infinite_recursion() { + let mut program = create_test_program(); + + // Create an inlined function A that calls inlined function B + let func_a = Function { + name: "func_a".to_string(), + arguments: vec![("x".to_string(), false)], + inlined: true, + body: vec![Line::FunctionCall { + function_name: "func_b".to_string(), + args: vec![Expression::Value(SimpleExpr::Var("x".to_string()))], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + + // Create an inlined function B that calls inlined function A (mutual recursion) + let func_b = Function { + name: "func_b".to_string(), + arguments: vec![("x".to_string(), false)], + inlined: true, + body: vec![Line::FunctionCall { + function_name: "func_a".to_string(), + args: vec![Expression::Value(SimpleExpr::Var("x".to_string()))], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + + program.functions.insert("func_a".to_string(), func_a); + program.functions.insert("func_b".to_string(), func_b); + + // Create main function that calls func_a (which will trigger the mutual recursion) + let main_func = Function { + name: "main".to_string(), + arguments: vec![], + inlined: false, + body: vec![Line::FunctionCall { + function_name: 
"func_a".to_string(), + args: vec![Expression::Value(SimpleExpr::Constant( + ConstExpression::scalar(1), + ))], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("main".to_string(), main_func); + + // This should panic due to infinite mutual recursion detection + handle_inlined_functions(&mut program); + } + + #[test] + fn test_handle_const_arguments_simple() { + let mut program = create_test_program(); + + // Create a function with constant arguments + let const_func = create_const_function("const_func"); + program + .functions + .insert("const_func".to_string(), const_func); + + // Create a function that calls const_func with constant value + let caller_func = Function { + name: "caller".to_string(), + arguments: vec![("input".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "const_func".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Var("input".to_string())), + Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(10))), + ], + return_data: vec!["result".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("caller".to_string(), caller_func); + + handle_const_arguments(&mut program); + + // Original const function should be removed + assert!(!program.functions.contains_key("const_func")); + + // Should have a new specialized function + let specialized_name = "const_func_size=10"; + assert!(program.functions.contains_key(specialized_name)); + + // Caller should call the specialized function + let caller = program.functions.get("caller").unwrap(); + if let Line::FunctionCall { + function_name, + args, + .. 
+ } = &caller.body[0] + { + assert_eq!(function_name, specialized_name); + assert_eq!(args.len(), 1); // Only non-const arguments + } + } + + #[test] + fn test_handle_const_arguments_multiple_values() { + let mut program = create_test_program(); + + let const_func = create_const_function("const_func"); + program + .functions + .insert("const_func".to_string(), const_func); + + // Create two callers with different constant values + let caller1 = Function { + name: "caller1".to_string(), + arguments: vec![("input".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "const_func".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Var("input".to_string())), + Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(5))), + ], + return_data: vec!["result1".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("caller1".to_string(), caller1); + + let caller2 = Function { + name: "caller2".to_string(), + arguments: vec![("input".to_string(), false)], + inlined: false, + body: vec![Line::FunctionCall { + function_name: "const_func".to_string(), + args: vec![ + Expression::Value(SimpleExpr::Var("input".to_string())), + Expression::Value(SimpleExpr::Constant(ConstExpression::scalar(10))), + ], + return_data: vec!["result2".to_string()], + }], + n_returned_vars: 1, + }; + program.functions.insert("caller2".to_string(), caller2); + + handle_const_arguments(&mut program); + + // Should have two specialized functions + assert!(program.functions.contains_key("const_func_size=5")); + assert!(program.functions.contains_key("const_func_size=10")); + assert!(!program.functions.contains_key("const_func")); + + // Callers should reference the correct specialized functions + let caller1 = program.functions.get("caller1").unwrap(); + if let Line::FunctionCall { function_name, .. 
} = &caller1.body[0] { + assert_eq!(function_name, "const_func_size=5"); + } + + let caller2 = program.functions.get("caller2").unwrap(); + if let Line::FunctionCall { function_name, .. } = &caller2.body[0] { + assert_eq!(function_name, "const_func_size=10"); + } + } + + #[test] + fn test_inline_lines_assignment() { + let mut lines = vec![Line::Assignment { + var: "x".to_string(), + value: Expression::Value(SimpleExpr::Var("arg1".to_string())), + }]; + + let mut args = BTreeMap::new(); + args.insert( + "arg1".to_string(), + SimpleExpr::Constant(ConstExpression::scalar(42)), + ); + + let res = vec!["result".to_string()]; + + inline_lines(&mut lines, &args, &res, 0); + + // Variable should be renamed and argument replaced + if let Line::Assignment { var, value } = &lines[0] { + assert_eq!(var, "@inlined_var_0_x"); + if let Expression::Value(SimpleExpr::Constant(const_expr)) = value { + assert_eq!(const_expr, &ConstExpression::scalar(42)); + } else { + panic!("Expected constant value"); + } + } + } + + #[test] + fn test_inline_lines_function_return() { + let mut lines = vec![Line::FunctionRet { + return_data: vec![ + Expression::Value(SimpleExpr::Var("local_var".to_string())), + Expression::Value(SimpleExpr::Var("arg1".to_string())), + ], + }]; + + let mut args = BTreeMap::new(); + args.insert("arg1".to_string(), SimpleExpr::Var("input".to_string())); + + let res = vec!["output1".to_string(), "output2".to_string()]; + + inline_lines(&mut lines, &args, &res, 1); + + // Function return should be converted to assignments + assert_eq!(lines.len(), 2); + + if let Line::Assignment { var, value } = &lines[0] { + assert_eq!(var, "output1"); + if let Expression::Value(SimpleExpr::Var(var_name)) = value { + assert_eq!(var_name, "@inlined_var_1_local_var"); + } else { + panic!("Expected variable value in first assignment"); + } + } else { + panic!("Expected assignment in first line"); + } + + if let Line::Assignment { var, value } = &lines[1] { + assert_eq!(var, "output2"); + if 
let Expression::Value(SimpleExpr::Var(var_name)) = value { + assert_eq!(var_name, "input"); + } else { + panic!("Expected variable value in second assignment"); + } + } else { + panic!("Expected assignment in second line"); + } + } + + #[test] + fn test_inline_lines_if_condition() { + let mut lines = vec![Line::IfCondition { + condition: Boolean::Equal { + left: Expression::Value(SimpleExpr::Var("arg1".to_string())), + right: Expression::Value(SimpleExpr::Var("local_var".to_string())), + }, + then_branch: vec![Line::Assignment { + var: "then_var".to_string(), + value: Expression::Value(SimpleExpr::Var("arg1".to_string())), + }], + else_branch: vec![Line::Assignment { + var: "else_var".to_string(), + value: Expression::Value(SimpleExpr::Var("local_var".to_string())), + }], + }]; + + let mut args = BTreeMap::new(); + args.insert("arg1".to_string(), SimpleExpr::Var("input".to_string())); + + let res = vec![]; + + inline_lines(&mut lines, &args, &res, 2); + + if let Line::IfCondition { + condition, + then_branch, + else_branch, + } = &lines[0] + { + // Condition variables should be inlined + if let Boolean::Equal { left, right } = condition { + if let Expression::Value(SimpleExpr::Var(left_var)) = left { + assert_eq!(left_var, "input"); + } else { + panic!("Expected variable in left side of condition"); + } + if let Expression::Value(SimpleExpr::Var(right_var)) = right { + assert_eq!(right_var, "@inlined_var_2_local_var"); + } else { + panic!("Expected variable in right side of condition"); + } + } else { + panic!("Expected Equal condition"); + } + + // Variables in branches should be renamed + if let Line::Assignment { var, value } = &then_branch[0] { + assert_eq!(var, "@inlined_var_2_then_var"); + if let Expression::Value(SimpleExpr::Var(val_var)) = value { + assert_eq!(val_var, "input"); + } else { + panic!("Expected variable value in then branch assignment"); + } + } else { + panic!("Expected assignment in then branch"); + } + + if let Line::Assignment { var, value 
} = &else_branch[0] { + assert_eq!(var, "@inlined_var_2_else_var"); + if let Expression::Value(SimpleExpr::Var(val_var)) = value { + assert_eq!(val_var, "@inlined_var_2_local_var"); + } else { + panic!("Expected variable value in else branch assignment"); + } + } else { + panic!("Expected assignment in else branch"); + } + } else { + panic!("Expected if condition"); + } + } + + #[test] + fn test_inline_lines_for_loop() { + let mut lines = vec![Line::ForLoop { + iterator: "i".to_string(), + start: Expression::Value(SimpleExpr::Var("arg1".to_string())), + end: Expression::Value(SimpleExpr::Var("local_end".to_string())), + body: vec![Line::Assignment { + var: "loop_var".to_string(), + value: Expression::Value(SimpleExpr::Var("i".to_string())), + }], + rev: false, + unroll: false, + }]; + + let mut args = BTreeMap::new(); + args.insert( + "arg1".to_string(), + SimpleExpr::Constant(ConstExpression::scalar(0)), + ); + + let res = vec![]; + + inline_lines(&mut lines, &args, &res, 3); + + if let Line::ForLoop { + iterator, + start, + end, + body, + .. + } = &lines[0] + { + // Iterator should be renamed + assert_eq!(iterator, "@inlined_var_3_i"); + + // Start expression should use inlined argument + if let Expression::Value(SimpleExpr::Constant(const_expr)) = start { + assert_eq!(const_expr, &ConstExpression::scalar(0)); + } else { + panic!("Expected constant start value in for loop"); + } + + // End variable should be renamed + if let Expression::Value(SimpleExpr::Var(end_var)) = end { + assert_eq!(end_var, "@inlined_var_3_local_end"); + } else { + panic!("Expected variable end value in for loop"); + } + + // Body variables should be renamed + if let Line::Assignment { var, .. 
} = &body[0] { + assert_eq!(var, "@inlined_var_3_loop_var"); + } else { + panic!("Expected assignment in for loop body"); + } + } else { + panic!("Expected for loop"); + } + } +} From c1f2973e533e0474c416b6e2ae20d2dad380ca74 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 26 Sep 2025 23:15:10 +0200 Subject: [PATCH 11/39] small touchups --- crates/lean_compiler/src/simplify/mod.rs | 4 ++-- crates/lean_compiler/src/simplify/simplify.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/lean_compiler/src/simplify/mod.rs b/crates/lean_compiler/src/simplify/mod.rs index 41672e1e..d8f00fb6 100644 --- a/crates/lean_compiler/src/simplify/mod.rs +++ b/crates/lean_compiler/src/simplify/mod.rs @@ -32,10 +32,10 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { .arguments .iter() .map(|(v, is_const)| { - assert!(!is_const,); + assert!(!is_const); v.clone() }) - .collect::>(); + .collect(); new_functions.insert( name.clone(), SimpleFunction { diff --git a/crates/lean_compiler/src/simplify/simplify.rs b/crates/lean_compiler/src/simplify/simplify.rs index df52b85b..6d045e0a 100644 --- a/crates/lean_compiler/src/simplify/simplify.rs +++ b/crates/lean_compiler/src/simplify/simplify.rs @@ -219,7 +219,7 @@ pub fn simplify_lines( let simplified_content = content .iter() .map(|var| simplify_expr(var, &mut res, counters, array_manager, const_malloc)) - .collect::>(); + .collect(); res.push(SimpleLine::Print { line_info: line_info.clone(), content: simplified_content, @@ -255,7 +255,7 @@ pub fn simplify_lines( .map(|expr| { simplify_expr(expr, &mut res, counters, array_manager, const_malloc) }) - .collect::>(); + .collect(); let label = const_malloc.counter; const_malloc.counter += 1; const_malloc.map.insert(var.clone(), label); From 7dc7a74e5e7743ca6bb4a959c19083237d904be9 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 27 Sep 2025 00:10:30 +0200 Subject: [PATCH 12/39] more reorg and fix end to end compiler tests --- 
crates/lean_compiler/src/codegen/compiler.rs | 182 +++++++++++ crates/lean_compiler/src/codegen/function.rs | 88 ++++++ .../instruction.rs} | 297 +++--------------- crates/lean_compiler/src/codegen/memory.rs | 208 ++++++++++++ crates/lean_compiler/src/codegen/mod.rs | 14 + .../lean_compiler/src/codegen/validation.rs | 141 +++++++++ crates/lean_compiler/src/ir/operation.rs | 1 + crates/lean_compiler/src/ir/value.rs | 151 ++++++++- crates/lean_compiler/src/lang/ast/expr.rs | 89 ++++++ crates/lean_compiler/src/lib.rs | 29 +- .../src/parser/parsers/statement.rs | 5 +- crates/lean_compiler/src/simplify/mod.rs | 4 +- .../simplify/{simplify.rs => processor.rs} | 6 +- .../src/simplify/transformations.rs | 8 +- .../lean_compiler/src/simplify/utilities.rs | 78 ++--- 15 files changed, 981 insertions(+), 320 deletions(-) create mode 100644 crates/lean_compiler/src/codegen/compiler.rs create mode 100644 crates/lean_compiler/src/codegen/function.rs rename crates/lean_compiler/src/{b_compile_intermediate.rs => codegen/instruction.rs} (67%) create mode 100644 crates/lean_compiler/src/codegen/memory.rs create mode 100644 crates/lean_compiler/src/codegen/mod.rs create mode 100644 crates/lean_compiler/src/codegen/validation.rs rename crates/lean_compiler/src/simplify/{simplify.rs => processor.rs} (99%) diff --git a/crates/lean_compiler/src/codegen/compiler.rs b/crates/lean_compiler/src/codegen/compiler.rs new file mode 100644 index 00000000..c71970c8 --- /dev/null +++ b/crates/lean_compiler/src/codegen/compiler.rs @@ -0,0 +1,182 @@ +/// Clean, structured compiler state following best practices from production compilers. +/// +/// This module provides the core compiler structure that manages compilation state +/// in a clean, organized manner similar to LLVM's Module/Function hierarchy. +use crate::{ir::*, lang::*, simplify::*}; +use lean_vm::*; +use std::collections::BTreeMap; + +/// Main compiler state container. 
+/// +/// Manages all compilation state including bytecode generation, memory layout, +/// and variable tracking. Follows the exact structure from the original working compiler. +#[derive(Debug, Default)] +pub struct Compiler { + /// Generated bytecode organized by function labels. + pub bytecode: BTreeMap>, + /// Match statement bytecode blocks (each match = many bytecode blocks, each bytecode block = many instructions). + pub match_blocks: Vec>>, + /// Counter for generating unique if statement labels. + pub if_counter: usize, + /// Counter for generating unique function call labels. + pub call_counter: usize, + /// Name of the current function being compiled. + pub func_name: String, + /// Variable positions in the stack frame (var -> memory offset from fp). + pub var_positions: BTreeMap, + /// Number of function arguments for current function. + pub args_count: usize, + /// Total stack frame size for current function. + pub stack_size: usize, + /// Const malloc allocations (const_malloc_label -> start = memory offset from fp). + pub const_mallocs: BTreeMap, +} + +impl Compiler { + /// Creates a new compiler instance. + pub fn new() -> Self { + Self::default() + } + + /// Compiles a complete program to intermediate bytecode. + pub fn compile_program(&mut self, simple_program: SimpleProgram) -> Result { + let mut memory_sizes = BTreeMap::new(); + + for function in simple_program.functions.values() { + let instructions = crate::codegen::function::compile_function(function, self)?; + self.bytecode.insert(Label::function(&function.name), instructions); + memory_sizes.insert(function.name.clone(), self.stack_size); + } + + Ok(IntermediateBytecode { + bytecode: self.bytecode.clone(), + match_blocks: self.match_blocks.clone(), + memory_size_per_function: memory_sizes, + }) + } + + /// Gets the memory offset for a variable or const malloc access. 
+ pub fn get_offset(&self, var: &VarOrConstMallocAccess) -> ConstExpression { + match var { + VarOrConstMallocAccess::Var(var) => (*self + .var_positions + .get(var) + .unwrap_or_else(|| panic!("Variable {var} not in scope"))) + .into(), + VarOrConstMallocAccess::ConstMallocAccess { + malloc_label, + offset, + } => ConstExpression::Binary { + left: Box::new( + self.const_mallocs + .get(malloc_label) + .copied() + .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) + .into(), + ), + operation: HighLevelOperation::Add, + right: Box::new(offset.clone()), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::simplify::{SimpleFunction, SimpleLine, VarOrConstMallocAccess}; + use crate::lang::SimpleExpr; + + #[test] + fn test_compiler_creation() { + let compiler = Compiler::new(); + assert!(compiler.bytecode.is_empty()); + assert!(compiler.match_blocks.is_empty()); + assert_eq!(compiler.if_counter, 0); + assert_eq!(compiler.call_counter, 0); + } + + #[test] + fn test_compiler_compile_function() { + let mut compiler = Compiler::new(); + let function = SimpleFunction { + name: "test".to_string(), + arguments: vec!["x".to_string(), "y".to_string()], + instructions: vec![SimpleLine::Assignment { + var: VarOrConstMallocAccess::Var("result".to_string()), + operation: HighLevelOperation::Add, + arg0: SimpleExpr::Var("x".to_string()), + arg1: SimpleExpr::Var("y".to_string()), + }], + n_returned_vars: 1, + }; + + let result = crate::codegen::function::compile_function(&function, &mut compiler); + assert!(result.is_ok()); + + assert_eq!(compiler.func_name, "test"); + assert_eq!(compiler.args_count, 2); + assert!(compiler.var_positions.contains_key("x")); + assert!(compiler.var_positions.contains_key("y")); + assert!(compiler.var_positions.contains_key("result")); + assert!(compiler.stack_size > 2); // At least pc + fp + args + } + + #[test] + fn test_compiler_counter_access() { + let mut compiler = Compiler::new(); + + // Test counter 
initialization + assert_eq!(compiler.if_counter, 0); + assert_eq!(compiler.call_counter, 0); + + // Test counter modification + compiler.if_counter = 5; + compiler.call_counter = 3; + assert_eq!(compiler.if_counter, 5); + assert_eq!(compiler.call_counter, 3); + } + + #[test] + fn test_compiler_offset_calculation() { + let mut compiler = Compiler::new(); + compiler.func_name = "test".to_string(); + compiler.args_count = 0; + compiler.stack_size = 10; + + compiler.var_positions.insert("x".to_string(), 5); + compiler.const_mallocs.insert(0, 8); + + // Test variable offset + let var_access = VarOrConstMallocAccess::Var("x".to_string()); + let offset = compiler.get_offset(&var_access); + assert_eq!(offset, ConstExpression::scalar(5)); + + // Test const malloc offset + let malloc_access = VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: 0, + offset: ConstExpression::scalar(3), + }; + let offset = compiler.get_offset(&malloc_access); + if let ConstExpression::Binary { operation, left, right } = offset { + assert_eq!(operation, HighLevelOperation::Add); + assert_eq!(left.as_ref(), &ConstExpression::scalar(8)); + assert_eq!(right.as_ref(), &ConstExpression::scalar(3)); + } else { + panic!("Expected binary expression"); + } + } + + #[test] + fn test_compiler_stack_allocation() { + let mut compiler = Compiler::new(); + compiler.func_name = "test".to_string(); + compiler.args_count = 0; + compiler.stack_size = 10; + + // Test that we can modify stack size + let initial_size = compiler.stack_size; + compiler.stack_size += 5; + assert_eq!(compiler.stack_size, initial_size + 5); + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/codegen/function.rs b/crates/lean_compiler/src/codegen/function.rs new file mode 100644 index 00000000..bc0b22ef --- /dev/null +++ b/crates/lean_compiler/src/codegen/function.rs @@ -0,0 +1,88 @@ +/// Function compilation module. 
+/// +/// This module handles the compilation of individual functions from the simplified +/// AST to intermediate bytecode, managing function-specific state and control flow. +use crate::{codegen::*, ir::*, lang::*, simplify::*}; +use std::collections::{BTreeSet, BTreeMap}; + +/// Compiles a single function to intermediate bytecode. +/// +/// This function handles the complete compilation of a function body, +/// including variable declaration tracking and instruction generation. +pub fn compile_function( + function: &SimpleFunction, + compiler: &mut Compiler, +) -> Result, String> { + let mut internal_vars = crate::codegen::memory::find_internal_vars(&function.instructions); + internal_vars.retain(|var| !function.arguments.contains(var)); + + // memory layout: pc, fp, args, return_vars, internal_vars + let mut stack_pos = 2; // Reserve space for pc and fp + let mut var_positions = BTreeMap::new(); + + for (i, var) in function.arguments.iter().enumerate() { + var_positions.insert(var.clone(), stack_pos + i); + } + stack_pos += function.arguments.len(); + + stack_pos += function.n_returned_vars; + + for (i, var) in internal_vars.iter().enumerate() { + var_positions.insert(var.clone(), stack_pos + i); + } + stack_pos += internal_vars.len(); + + compiler.func_name = function.name.clone(); + compiler.var_positions = var_positions; + compiler.stack_size = stack_pos; + compiler.args_count = function.arguments.len(); + + let mut declared_vars: BTreeSet = function.arguments.iter().cloned().collect(); + crate::codegen::instruction::compile_lines( + &function.instructions, + compiler, + None, + &mut declared_vars, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::SimpleExpr; + use crate::simplify::VarOrConstMallocAccess; + + #[test] + fn test_compile_simple_function() { + let function = SimpleFunction { + name: "add".to_string(), + arguments: vec!["a".to_string(), "b".to_string()], + instructions: vec![ + SimpleLine::Assignment { + var: 
VarOrConstMallocAccess::Var("result".to_string()), + operation: HighLevelOperation::Add, + arg0: SimpleExpr::Var("a".to_string()), + arg1: SimpleExpr::Var("b".to_string()), + }, + SimpleLine::FunctionRet { + return_data: vec![SimpleExpr::Var("result".to_string())], + }, + ], + n_returned_vars: 1, + }; + + let mut compiler = Compiler::new(); + + let result = compile_function(&function, &mut compiler); + assert!(result.is_ok()); + + let instructions = result.unwrap(); + assert!(!instructions.is_empty()); + + // Should contain at least the assignment and return + assert!(instructions.iter().any(|inst| matches!( + inst, + IntermediateInstruction::Computation { .. } + ))); + } +} \ No newline at end of file diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/codegen/instruction.rs similarity index 67% rename from crates/lean_compiler/src/b_compile_intermediate.rs rename to crates/lean_compiler/src/codegen/instruction.rs index ae9096e9..28ef3ed0 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/codegen/instruction.rs @@ -1,161 +1,17 @@ -use crate::{F, ir::*, lang::*, precompiles::*, simplify::*}; +/// Instruction compilation module. +/// +/// This module handles the compilation of individual SimpleLine instructions +/// to intermediate bytecode, managing control flow and variable state. 
+use crate::{codegen::*, ir::*, lang::*, simplify::*}; use lean_vm::*; -use p3_field::Field; -use std::{ - borrow::Borrow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeSet; use utils::ToUsize; -#[derive(Default)] -struct Compiler { - bytecode: BTreeMap>, - match_blocks: Vec>>, // each match = many bytecode blocks, each bytecode block = many instructions - if_counter: usize, - call_counter: usize, - func_name: String, - var_positions: BTreeMap, // var -> memory offset from fp - args_count: usize, - stack_size: usize, - const_mallocs: BTreeMap, // const_malloc_label -> start = memory offset from fp -} - -impl Compiler { - fn get_offset(&self, var: &VarOrConstMallocAccess) -> ConstExpression { - match var { - VarOrConstMallocAccess::Var(var) => (*self - .var_positions - .get(var) - .unwrap_or_else(|| panic!("Variable {var} not in scope"))) - .into(), - VarOrConstMallocAccess::ConstMallocAccess { - malloc_label, - offset, - } => ConstExpression::Binary { - left: Box::new( - self.const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, - } - } -} - -impl SimpleExpr { - fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediaryMemOrFpOrConstant { - match self { - Self::Var(var) => IntermediaryMemOrFpOrConstant::MemoryAfterFp { - offset: compiler.get_offset(&var.clone().into()), - }, - Self::Constant(c) => IntermediaryMemOrFpOrConstant::Constant(c.clone()), - Self::ConstMallocAccess { - malloc_label, - offset, - } => IntermediaryMemOrFpOrConstant::MemoryAfterFp { - offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *malloc_label, - offset: offset.clone(), - }), - }, - } - } -} - -impl IntermediateValue { - fn from_simple_expr(expr: &SimpleExpr, compiler: &Compiler) -> Self { - match expr { - SimpleExpr::Var(var) => Self::MemoryAfterFp { - 
offset: compiler.get_offset(&var.clone().into()), - }, - SimpleExpr::Constant(c) => Self::Constant(c.clone()), - SimpleExpr::ConstMallocAccess { - malloc_label, - offset, - } => Self::MemoryAfterFp { - offset: ConstExpression::Binary { - left: Box::new( - compiler - .const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, - }, - } - } - - fn from_var_or_const_malloc_access( - var_or_const: &VarOrConstMallocAccess, - compiler: &Compiler, - ) -> Self { - Self::from_simple_expr(&var_or_const.clone().into(), compiler) - } -} - -pub fn compile_to_intermediate_bytecode( - simple_program: SimpleProgram, -) -> Result { - let mut compiler = Compiler::default(); - let mut memory_sizes = BTreeMap::new(); - - for function in simple_program.functions.values() { - let instructions = compile_function(function, &mut compiler)?; - compiler - .bytecode - .insert(Label::function(&function.name), instructions); - memory_sizes.insert(function.name.clone(), compiler.stack_size); - } - - Ok(IntermediateBytecode { - bytecode: compiler.bytecode, - match_blocks: compiler.match_blocks, - memory_size_per_function: memory_sizes, - }) -} - -fn compile_function( - function: &SimpleFunction, - compiler: &mut Compiler, -) -> Result, String> { - let mut internal_vars = find_internal_vars(&function.instructions); - - internal_vars.retain(|var| !function.arguments.contains(var)); - - // memory layout: pc, fp, args, return_vars, internal_vars - let mut stack_pos = 2; // Reserve space for pc and fp - let mut var_positions = BTreeMap::new(); - - for (i, var) in function.arguments.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += function.arguments.len(); - - stack_pos += function.n_returned_vars; - - for (i, var) in internal_vars.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - 
stack_pos += internal_vars.len(); - - compiler.func_name = function.name.clone(); - compiler.var_positions = var_positions; - compiler.stack_size = stack_pos; - compiler.args_count = function.arguments.len(); - - let mut declared_vars: BTreeSet = function.arguments.iter().cloned().collect(); - compile_lines(&function.instructions, compiler, None, &mut declared_vars) -} - -fn compile_lines( +/// Compiles a sequence of SimpleLine instructions to intermediate bytecode. +/// +/// This function is the core of the instruction compiler, handling all +/// SimpleLine variants and managing control flow between them. +pub fn compile_lines( lines: &[SimpleLine], compiler: &mut Compiler, final_jump: Option