From f83c6b2f33a95317a055b52860589b449c8bd52e Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 25 Sep 2025 16:49:29 +0200 Subject: [PATCH 1/3] compiler: fix Non exhaustive conditions in inlined functions bug --- TODO.md | 25 -- crates/lean_compiler/src/a_simplify_lang.rs | 309 ++++++++++++++++++++ crates/lean_compiler/tests/test_compiler.rs | 91 ++++++ 3 files changed, 400 insertions(+), 25 deletions(-) diff --git a/TODO.md b/TODO.md index d627eaf5..c0a6f412 100644 --- a/TODO.md +++ b/TODO.md @@ -87,28 +87,3 @@ But we reduce proof size a lot using instead (TODO): is a valid memory access (i.e. the index is < M the memory size), but currently the DEREF instruction forces us to 'store' the result, in m[fp + i] (resp m[fp + k]). TLDR: adding a new encoding field for DEREF would save 2 memory cells / range check. If this can also increase perf in alternative scenario (other instructions for isntance), potentially we should consider it. - -## Known leanISA compiler bugs: - -### Non exhaustive conditions in inlined functions - -Currently, to inline functions we simply replace the "return" keyword by some variable assignment, i.e. -we do not properly handle conditions, we would need to add some JUMPs ... - -``` -fn works(x) inline -> 1 { - if x == 0 { - return 0; - } else { - return 1; - } -} - -fn doesnt_work(x) inline -> 1 { - if x == 0 { - return 0; // will be compiled to `res = 0`; - } // the bug: we do not JUMP here, when inlined - return 1; // will be compiled to `res = 1`; -> invalid (res = 0 and = 1 at the same time) -} -``` - diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 2a028d71..b1bdfd28 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -915,6 +915,51 @@ pub fn inline_lines( args: &BTreeMap, res: &[Var], inlining_count: usize, +) { + // First, check if this function has multiple return paths + let return_count = count_returns(lines); + + if return_count <= 1 { + // Simple case: use the original logic for single returns + inline_lines_simple(lines, args, res, inlining_count); + } else { + // Complex case: handle multiple returns with jump-based control flow + inline_lines_with_jumps(lines, args, res, inlining_count); + } +} + +fn count_returns(lines: &[Line]) -> usize { + let mut count = 0; + for line in lines { + match line { + Line::FunctionRet { .. } => count += 1, + Line::Match { arms, .. } => { + for (_, arm_lines) in arms { + count += count_returns(arm_lines); + } + } + Line::IfCondition { + then_branch, + else_branch, + .. + } => { + count += count_returns(then_branch); + count += count_returns(else_branch); + } + Line::ForLoop { body, .. } => { + count += count_returns(body); + } + _ => {} + } + } + count +} + +fn inline_lines_simple( + lines: &mut Vec, + args: &BTreeMap, + res: &[Var], + inlining_count: usize, ) { let inline_condition = |condition: &mut Boolean| { let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; @@ -974,6 +1019,8 @@ pub fn inline_lines( for expr in return_data.iter_mut() { inline_expr(expr, args, inlining_count); } + + // Simple case: direct assignment is safe since we know there's only one return lines_to_replace.push(( i, res.iter() @@ -1041,6 +1088,268 @@ pub fn inline_lines( } } +fn inline_lines_with_jumps( + lines: &mut Vec, + args: &BTreeMap, + res: &[Var], + inlining_count: usize, +) { + // Convert non-exhaustive conditions to exhaustive ones + // + // Find if-statements that contain returns and don't have else clauses + // Move all subsequent statements into the else clause + make_non_exhaustive_exhaustive(lines); + + // Now apply the standard inlining + inline_lines_simple(lines, args, res, inlining_count); +} + +fn make_non_exhaustive_exhaustive(lines: &mut Vec) { + let mut i = 0; + while i < lines.len() { + // Check if we need to restructure at this position + let should_restructure = if let Line::IfCondition { + condition: _, + then_branch, + else_branch, + } = &lines[i] + { + else_branch.is_empty() && has_return(then_branch) + } else { + false + }; + + if should_restructure { + // Extract the if condition and split the remaining statements + let mut subsequent_statements = lines.split_off(i + 1); + if let Line::IfCondition { + condition: _, + then_branch, + else_branch, + } = &mut lines[i] + { + // Recursively process the then branch + make_non_exhaustive_exhaustive(then_branch); + + // Move subsequent statements to else branch + else_branch.append(&mut subsequent_statements); + + // Recursively process the new else branch + make_non_exhaustive_exhaustive(else_branch); + } + break; // We've restructured, so we're done with this level + } else { + // Process nested conditions + if let Line::IfCondition { + condition: _, + then_branch, + else_branch, + } = &mut lines[i] + { + make_non_exhaustive_exhaustive(then_branch); + make_non_exhaustive_exhaustive(else_branch); + } + } + i += 1; + } +} + +fn has_return(lines: &[Line]) -> bool { + for line in lines { + match line { + Line::FunctionRet { .. } => return true, + Line::IfCondition { + then_branch, + else_branch, + .. + } => { + if has_return(then_branch) || has_return(else_branch) { + return true; + } + } + Line::Match { arms, .. } => { + for (_, arm_lines) in arms { + if has_return(arm_lines) { + return true; + } + } + } + Line::ForLoop { body, .. } => { + if has_return(body) { + return true; + } + } + _ => {} + } + } + false +} + +fn transform_returns_with_flag( + lines: &mut Vec, + args: &BTreeMap, + res: &[Var], + inlining_count: usize, + returned_flag: &str, +) { + let inline_condition = |condition: &mut Boolean| { + let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; + inline_expr(left, args, inlining_count); + inline_expr(right, args, inlining_count); + }; + + let inline_internal_var = |var: &mut Var| { + assert!( + !args.contains_key(var), + "Variable {var} is both an argument and assigned in the inlined function" + ); + *var = format!("@inlined_var_{inlining_count}_{var}"); + }; + + let mut lines_to_replace = vec![]; + + for (i, line) in lines.iter_mut().enumerate() { + match line { + Line::Match { value, arms } => { + inline_expr(value, args, inlining_count); + for (_, statements) in arms { + transform_returns_with_flag( + statements, + args, + res, + inlining_count, + returned_flag, + ); + } + } + Line::Assignment { var, value } => { + inline_expr(value, args, inlining_count); + inline_internal_var(var); + } + Line::IfCondition { + condition, + then_branch, + else_branch, + } => { + inline_condition(condition); + transform_returns_with_flag(then_branch, args, res, inlining_count, returned_flag); + transform_returns_with_flag(else_branch, args, res, inlining_count, returned_flag); + } + Line::FunctionCall { + args: func_args, + return_data, + .. + } => { + for arg in func_args { + inline_expr(arg, args, inlining_count); + } + for return_var in return_data { + inline_internal_var(return_var); + } + } + Line::Assert(condition) => { + inline_condition(condition); + } + Line::FunctionRet { return_data } => { + assert_eq!(return_data.len(), res.len()); + + for expr in return_data.iter_mut() { + inline_expr(expr, args, inlining_count); + } + + // For multiple returns, we need to use conditional assignment to prevent SSA violations + // Only assign to result variables if we haven't already returned + let mut new_lines = vec![]; + + // Check if we haven't returned yet + let condition = Boolean::Equal { + left: Expression::Value(SimpleExpr::Var(returned_flag.to_string())), + right: Expression::Value(SimpleExpr::scalar(0)), + }; + + // Create assignments inside an if condition + let assignments = res + .iter() + .zip(return_data) + .map(|(res_var, expr)| Line::Assignment { + var: res_var.clone(), + value: expr.clone(), + }) + .collect::>(); + + // Add assignment to set the returned flag + let mut then_branch = assignments; + then_branch.push(Line::Assignment { + var: returned_flag.to_string(), + value: Expression::Value(SimpleExpr::scalar(1)), + }); + + new_lines.push(Line::IfCondition { + condition, + then_branch, + else_branch: vec![], // Empty else branch + }); + + lines_to_replace.push((i, new_lines)); + } + Line::MAlloc { var, size, .. } => { + inline_expr(size, args, inlining_count); + inline_internal_var(var); + } + Line::Precompile { + precompile: _, + args: precompile_args, + } => { + for arg in precompile_args { + inline_expr(arg, args, inlining_count); + } + } + Line::ForLoop { + iterator, + start, + end, + body, + rev: _, + unroll: _, + } => { + transform_returns_with_flag(body, args, res, inlining_count, returned_flag); + inline_internal_var(iterator); + inline_expr(start, args, inlining_count); + inline_expr(end, args, inlining_count); + } + Line::Print { content, .. } => { + for var in content { + inline_expr(var, args, inlining_count); + } + } + Line::DecomposeBits { var, to_decompose } => { + for expr in to_decompose { + inline_expr(expr, args, inlining_count); + } + inline_internal_var(var); + } + Line::CounterHint { var } => { + inline_internal_var(var); + } + Line::ArrayAssign { + array, + index, + value, + } => { + inline_simple_expr(array, args, inlining_count); + inline_expr(index, args, inlining_count); + inline_expr(value, args, inlining_count); + } + Line::Panic | Line::Break | Line::LocationReport { .. } => {} + } + } + + // Apply the replacements + for (i, new_lines) in lines_to_replace.into_iter().rev() { + lines.splice(i..=i, new_lines); + } +} + fn vars_in_expression(expr: &Expression) -> BTreeSet { let mut vars = BTreeSet::new(); match expr { diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index a8265248..ae96e705 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -514,3 +514,94 @@ fn test_nested_inline_functions() { compile_and_run(program, &[], &[], false); } + +#[test] +fn test_inline_multiple_returns() { + let program = r#" + fn main() { + result = conditional_return(5); + print(result); + return; + } + + fn conditional_return(x) inline -> 1 { + if x == 5 { + return 100; + } else { + if x == 3 { + return 200; + } else { + return 300; + } + } + } + "#; + + compile_and_run(program, &[], &[], false); +} + +#[test] +fn test_inline_complex_multiple_returns() { + let program = r#" + fn main() { + result1 = complex_return(1, 2); + result2 = complex_return(3, 4); + print(result1); + print(result2); + return; + } + + fn complex_return(a, b) inline -> 1 { + sum = a + b; + if sum == 3 { + return 100; + } else { + if sum == 7 { + return 200; + } else { + if a == 1 { + return 150; + } else { + return 300; + } + } + } + } + "#; + + compile_and_run(program, &[], &[], false); +} + +#[test] +fn test_inline_non_exhaustive_conditions() { + let program = r#" + fn main() { + result1 = works(0); + result2 = works(1); + result3 = doesnt_work(0); + result4 = doesnt_work(1); + print(result1); + print(result2); + print(result3); + print(result4); + return; + } + + fn works(x) inline -> 1 { + if x == 0 { + return 0; + } else { + return 1; + } + } + + fn doesnt_work(x) inline -> 1 { + if x == 0 { + return 0; + } + return 1; + } + "#; + + compile_and_run(program, &[], &[], false); +} From b23a803f9d3e4af604a6fb957e12c7d7dd827770 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 25 Sep 2025 18:33:08 +0200 Subject: [PATCH 2/3] fix direct return --- crates/lean_compiler/src/a_simplify_lang.rs | 51 ++++++++++++++++++++- crates/lean_compiler/tests/test_compiler.rs | 28 +++++++++-- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index b1bdfd28..dc9ed4f2 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -1114,7 +1114,11 @@ fn make_non_exhaustive_exhaustive(lines: &mut Vec) { else_branch, } = &lines[i] { - else_branch.is_empty() && has_return(then_branch) + // Only restructure if: + // 1. The else branch is empty AND + // 2. The then branch has a direct return (not nested) AND + // 3. There are subsequent statements + else_branch.is_empty() && has_direct_return(then_branch) && i + 1 < lines.len() } else { false }; @@ -1128,7 +1132,7 @@ fn make_non_exhaustive_exhaustive(lines: &mut Vec) { else_branch, } = &mut lines[i] { - // Recursively process the then branch + // Recursively process the then branch first make_non_exhaustive_exhaustive(then_branch); // Move subsequent statements to else branch @@ -1154,6 +1158,49 @@ fn make_non_exhaustive_exhaustive(lines: &mut Vec) { } } +// Only look for direct returns, not nested ones +fn has_direct_return(lines: &[Line]) -> bool { + for line in lines { + match line { + Line::FunctionRet { .. } => return true, + _ => {} // Don't recurse into nested structures + } + } + false +} + +// More comprehensive return detection that looks for returns anywhere in nested structures +fn has_return_anywhere(lines: &[Line]) -> bool { + for line in lines { + match line { + Line::FunctionRet { .. } => return true, + Line::IfCondition { + then_branch, + else_branch, + .. + } => { + if has_return_anywhere(then_branch) || has_return_anywhere(else_branch) { + return true; + } + } + Line::Match { arms, .. } => { + for (_, arm_lines) in arms { + if has_return_anywhere(arm_lines) { + return true; + } + } + } + Line::ForLoop { body, .. } => { + if has_return_anywhere(body) { + return true; + } + } + _ => {} + } + } + false +} + fn has_return(lines: &[Line]) -> bool { for line in lines { match line { diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index ae96e705..3891fadf 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -578,8 +578,8 @@ fn test_inline_non_exhaustive_conditions() { fn main() { result1 = works(0); result2 = works(1); - result3 = doesnt_work(0); - result4 = doesnt_work(1); + result3 = should_work(0); + result4 = should_work(1); print(result1); print(result2); print(result3); @@ -595,7 +595,7 @@ fn test_inline_non_exhaustive_conditions() { } } - fn doesnt_work(x) inline -> 1 { + fn should_work(x) inline -> 1 { if x == 0 { return 0; } @@ -605,3 +605,25 @@ fn test_inline_non_exhaustive_conditions() { compile_and_run(program, &[], &[], false); } + +#[test] +fn test_inline_of_todo() { + let program = r#" + fn main() { + b = should_work(1); + print(b); + return; + } + + fn should_work(x) inline -> 1 { + if x == 1 { + if x == 0 { + return 100; + } + } + return 200; + } + "#; + + compile_and_run(program, &[], &[], false); +} From 0a6b816b3e863d2f324a3748a016c4d9c1a60cd1 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 25 Sep 2025 18:35:24 +0200 Subject: [PATCH 3/3] fix clippy --- crates/lean_compiler/src/a_simplify_lang.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index dc9ed4f2..39977053 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -1161,9 +1161,9 @@ fn make_non_exhaustive_exhaustive(lines: &mut Vec) { // Only look for direct returns, not nested ones fn has_direct_return(lines: &[Line]) -> bool { for line in lines { - match line { - Line::FunctionRet { .. } => return true, - _ => {} // Don't recurse into nested structures + // Don't recurse into nested structures + if let Line::FunctionRet { .. } = line { + return true; } } false