diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index e39c508e8..551ea6d22 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -253,10 +253,11 @@ function pairs(t) end --- boolean), which is true if the call succeeds without errors. In such case, --- `pcall` also returns all results from the call, after this first result. In --- case of any error, `pcall` returns **false** plus the error message. ----@generic T, R, R1 ----@param f sync fun(...: T...): R1, R... +---@generic T, R +---@param f sync fun(...: T...): R... ---@param ... T... ----@return boolean, R1|string, R... +---@return_overload true, R... +---@return_overload false, string function pcall(f, ...) end --- diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index cb75dc834..69fb91855 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -245,12 +245,33 @@ pub fn analyze_param(analyzer: &mut DocAnalyzer, tag: LuaDocTagParam) -> Option< } pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Option<()> { + let is_return_overload = tag + .token_by_kind(LuaTokenKind::TkTagReturnOverload) + .is_some(); let description = tag .get_description() .map(|des| preprocess_description(&des.get_description_text(), None)); if let Some(closure) = find_owner_closure_or_report(analyzer, &tag) { let signature_id = LuaSignatureId::from_closure(analyzer.file_id, &closure); + if is_return_overload { + let overload_types = tag + .get_types() + .map(|doc_type| infer_type(analyzer, doc_type)) + .collect::>(); + if overload_types.is_empty() { + return Some(()); + } + + let signature = analyzer + .db + .get_signature_index_mut() + .get_or_create(signature_id); + signature.return_overloads.push(overload_types); + signature.resolve_return = SignatureReturnStatus::DocResolve; + return Some(()); + } + let returns = tag.get_info_list(); for (doc_type, name_token) in returns { let name = name_token.map(|name| name.get_name_text().to_string()); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs index 871658f3a..9ef2a76c5 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs @@ -5,7 +5,8 @@ use emmylua_parser::{ }; use crate::{ - AnalyzeError, DiagnosticCode, FlowId, FlowNodeKind, LuaClosureId, LuaDeclId, + AnalyzeError, DeclMultiReturnRef, DeclMultiReturnRefAt, DiagnosticCode, FlowId, FlowNodeKind, + LuaClosureId, LuaDeclId, compilation::analyzer::flow::{ bind_analyze::{ bind_block, bind_each_child, bind_node, @@ -33,13 +34,25 @@ pub fn bind_local_stat( } } - for value in values { + for value in &values { // If there are more values than names, we still need to bind the values bind_expr(binder, value.clone(), current); } + let decl_ids = local_names + .iter() + .map(|name| Some(LuaDeclId::new(binder.file_id, name.get_position()))) + .collect::>(); + let local_flow_id = binder.create_decl(local_stat.get_position()); binder.add_antecedent(local_flow_id, current); + bind_multi_return_refs( + binder, + &decl_ids, + &values, + local_stat.get_position(), + local_flow_id, + ); local_flow_id } @@ -88,13 +101,72 @@ pub fn bind_assign_stat( } } + let decl_ids = vars + .iter() + .map(|var| { + binder + .db + .get_reference_index() + .get_var_reference_decl(&binder.file_id, var.get_range()) + }) + .collect::>(); + let assignment_kind = FlowNodeKind::Assignment(assign_stat.to_ptr()); let flow_id = binder.create_node(assignment_kind); binder.add_antecedent(flow_id, current); + bind_multi_return_refs( + binder, + &decl_ids, + &values, + assign_stat.get_position(), + flow_id, + ); flow_id } +fn bind_multi_return_refs( + binder: &mut FlowBinder, + decl_ids: &[Option], + values: &[LuaExpr], + position: rowan::TextSize, + flow_id: FlowId, +) { + let tail_call = values.last().and_then(|value| { + if let LuaExpr::CallExpr(call_expr) = value { + Some((values.len() - 1, call_expr.to_ptr())) + } else { + None + } + }); + + for (i, decl_id) in decl_ids.iter().enumerate() { + let Some(decl_id) = decl_id else { + continue; + }; + + let reference = tail_call.as_ref().and_then(|(last_value_idx, call_expr)| { + if i < *last_value_idx { + return None; + } + Some(DeclMultiReturnRef { + call_expr: call_expr.clone(), + return_index: i - *last_value_idx, + }) + }); + + binder + .decl_multi_return_ref + .entry(*decl_id) + .or_default() + .push(DeclMultiReturnRefAt { + position, + flow_id, + reference, + }); + } +} + pub fn bind_call_expr_stat( binder: &mut FlowBinder, call_expr_stat: LuaCallExprStat, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs index e58ed48e4..73e6e0119 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs @@ -6,8 +6,8 @@ use rowan::TextSize; use smol_str::SmolStr; use crate::{ - AnalyzeError, DbIndex, FileId, FlowAntecedent, FlowId, FlowNode, FlowNodeKind, FlowTree, - LuaClosureId, LuaDeclId, + AnalyzeError, DbIndex, DeclMultiReturnRefAt, FileId, FlowAntecedent, FlowId, FlowNode, + FlowNodeKind, FlowTree, LuaClosureId, LuaDeclId, }; #[derive(Debug)] @@ -15,6 +15,7 @@ pub struct FlowBinder<'a> { pub db: &'a mut DbIndex, pub file_id: FileId, pub decl_bind_expr_ref: HashMap>, + pub decl_multi_return_ref: HashMap>, pub start: FlowId, pub unreachable: FlowId, pub loop_label: FlowId, @@ -36,6 +37,7 @@ impl<'a> FlowBinder<'a> { flow_nodes: Vec::new(), multiple_antecedents: Vec::new(), decl_bind_expr_ref: HashMap::new(), + decl_multi_return_ref: HashMap::new(), labels: HashMap::new(), start: FlowId::default(), unreachable: FlowId::default(), @@ -189,6 +191,7 @@ impl<'a> FlowBinder<'a> { pub fn finish(self) -> FlowTree { FlowTree::new( self.decl_bind_expr_ref, + self.decl_multi_return_ref, self.flow_nodes, self.multiple_antecedents, // self.labels, diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 32228bae1..7963edbb2 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -1710,6 +1710,382 @@ _2 = a[1] assert_eq!(e, e_expected); } + #[test] + fn test_return_overload_narrow_after_not() { + // Boolean guard on discriminant should narrow correlated result slot. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + if not ok then + error(result) + end + + a = result + "#, + ); + + let a = ws.expr_ty("a"); + let expected = ws.ty("integer"); + assert_eq!(a, expected); + } + + #[test] + fn test_return_overload_narrow_with_swapped_operand_eq() { + // Equality narrowing should work when literal is on the left side. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return "ok"|"err" + ---@return T|E + ---@return_overload "ok", T + ---@return_overload "err", E + local function pick(ok, success, failure) + if ok then + return "ok", success + end + return "err", failure + end + + local cond ---@type boolean + local tag, result = pick(cond, 1, "error") + + if "err" == tag then + error(result) + end + + d = result + "#, + ); + + let d = ws.expr_ty("d"); + let expected = ws.ty("integer"); + assert_eq!(d, expected); + } + + #[test] + fn test_swapped_literal_eq_narrow_without_return_overload() { + // Baseline: swapped literal equality still narrows regular unions. + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@return "x" + local function test() + local a ---@type "x"|nil + if "x" == a then + return a + end + return "x" + end + "#, + )); + } + + #[test] + fn test_var_eq_var_narrow_right_operand_without_return_overload() { + // `a == b` should narrow `b` in the true branch even when `b` is on the right. + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@return "x" + local function test() + local a ---@type "x" + local b ---@type "x"|nil + if a == b then + return b + end + return "x" + end + "#, + )); + } + + #[test] + fn test_return_overload_reassign_clears_multi_return_mapping() { + // Reassignment should break call-slot correlation, preventing stale narrowing. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local random ---@type boolean + local ok, result = pick(cond, 1, "error") + result = random and 1 or "override" + + if not ok then + error(result) + end + + f = result + "#, + ); + + let f = ws.expr_ty("f"); + let expected = ws.ty("integer|string"); + assert_eq!(f, expected); + } + + #[test] + fn test_return_overload_narrow_with_mixed_rhs_calls() { + // In mixed RHS assignment, only the trailing call's expanded slots should correlate. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local left_ok, right_ok, right_result = pick(cond, "left-ok", "left-err"), pick(cond, 1, "right-err") + + -- `left_ok` belongs to the first call; it must not affect second-call result. + if not left_ok then + error("left failed") + end + a = right_result + + -- `right_ok` and `right_result` come from the same trailing call, so narrowing applies. + if not right_ok then + error(right_result) + end + b = right_result + "#, + ); + + // No cross-call correlation from first call discriminant. + let a = ws.expr_ty("a"); + let a_expected = ws.ty("integer|string"); + assert_eq!(a, a_expected); + + // Same-call correlation from trailing call discriminant. + let b = ws.expr_ty("b"); + let b_expected = ws.ty("integer"); + assert_eq!(b, b_expected); + } + + #[test] + fn test_pcall_return_overload_narrow_after_error_guard() { + // Stdlib `pcall` overload rows should narrow result after error branch exits. + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@return integer + local function foo() + return 2 + end + + local ok, result = pcall(foo) + + if not ok then + error(result) + end + + a = result + "#, + ); + + let a = ws.expr_ty("a"); + let expected = ws.ty("integer"); + assert_eq!(a, expected); + } + + #[test] + fn test_return_overload_late_discriminant_rebind_does_not_affect_prior_narrowing() { + // Rebinding the discriminant later must not erase narrowing at earlier flow points. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + if not ok then + error(result) + end + + a = result + + -- Later write to `ok` should not retroactively change type at `a = result`. + ok = cond + "#, + ); + + let a = ws.expr_ty("a"); + let expected = ws.ty("integer"); + assert_eq!(a, expected); + } + + #[test] + fn test_return_overload_branch_reassign_should_not_override_join_mapping() { + // Correlation must follow reachable flow, not just latest source position. + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + + local ok, result = pick(cond, 1, "left-err") + if branch then + ok, result = pick(cond, "branch-ok", false) + end + + if not ok then + error(result) + end + + a = result + "#, + ); + + let a = ws.expr_ty("a"); + let expected = ws.ty("integer|string"); + assert_eq!(a, expected); + } + + #[test] + fn test_return_overload_join_with_noncorrelated_origin_keeps_extra_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false, string + local function pick(ok) + if ok then + return true, 1 + end + return false, "err" + end + + ---@return false + local function as_false() + return false + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond) + + -- Path A: correlated values from the same `pick` call. + -- ok=true => result=integer + -- ok=false => result=string + -- + -- Path B: `result` is from a different call than the discriminant (`ok`). + if branch then + ok, result = true, as_false() + end + + -- At this join, `result` can come from: + -- Path A: integer|string + -- Path B: false + at_join = result + + if not ok then + -- Only Path A with `ok=false` should reach here, so this is `string`. + in_error_path = result + error(result) + end + + -- Surviving paths: + -- Path A with `ok=true` => result=integer + -- Path B with `ok=true` => result=false + -- Expected: false|integer + after_guard = result + "#, + ); + + let in_error_path_ty = ws.expr_ty("in_error_path"); + assert!(ws.humanize_type(in_error_path_ty).contains("string")); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|integer")); + } + #[test] fn test_issue_868() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs index 4bbb6878d..70977913f 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs @@ -228,4 +228,207 @@ mod test { // so that variadic spreading continues to work as expected assert_eq!(ws.humanize_type(v_ty), "string"); } + + #[test] + fn test_higher_order_generic_return_infer_without_return_overload() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = wrap(wrap, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("status"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_higher_order_generic_return_infer_with_return_overload() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return_overload true, R... + ---@return_overload false, string + local function wrap(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = wrap(wrap, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("status"), ws.ty("true|false|string")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); + } + + #[test] + fn test_higher_order_return_infer_non_generic_callable_with_unresolved_remaining_arg() { + // Regression: inferring callable return from remaining args should not + // collapse the whole higher-order call to `any` when remaining args + // include unresolved field access. + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + return true, f(...) + end + + ---@param x integer + ---@return integer + local function take_int(x) + return x + end + + ---@class Box + ---@field value integer + local box + + ok, payload = wrap(take_int, box.missing) + "#, + ); + + // `ok` comes from wrap's declared boolean return. + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + // `payload` should still track `take_int` return, not degrade to unknown/any. + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_higher_order_return_infer_unresolved_callable_tpl_uses_constraint() { + // Regression: unresolved callable template return should fall back to + // its generic constraint (`U: string`) when direct substitution is absent. + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R + ---@param ... T... + ---@return R + local function call_once(f, ...) + return f(...) + end + + ---@generic U: string + ---@param n integer + ---@return U + local function constrained_return(n) + end + + result = call_once(constrained_return, 1) + "#, + ); + + // `U` is unconstrained by call args but constrained in declaration. + assert_eq!(ws.expr_ty("result"), ws.ty("string")); + } + + #[test] + fn test_return_overload_variadic_tail_keeps_deep_slots() { + // `R...` in overload rows should keep trailing slots, not stop at one extra slot. + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return_overload true, R... + ---@return_overload false, string + local function wrap(f, ...) + return true, f(...) + end + + ---@param n integer + ---@return integer, string, boolean + local function produce(n) + return n, tostring(n), n > 0 + end + + ok, first, second, third = wrap(produce, 1) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("first"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("second"), ws.ty("string")); + assert_eq!(ws.expr_ty("third"), ws.ty("boolean")); + } + + #[test] + fn test_return_overload_short_row_keeps_nil_in_missing_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false + local function maybe(ok) + if ok then + return true, 1 + end + return false + end + + local cond ---@type boolean + status, value = maybe(cond) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("value"), ws.ty("integer|nil")); + } + + #[test] + fn test_return_overload_concrete_variadic_tail_keeps_unbounded_slots() { + // Concrete variadic tails (`integer...`) should keep producing deep slots. + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer... + ---@return_overload false, string + local function wrap(ok) + if ok then + return true, 1, 2, 3, 4 + end + return false, "err" + end + + local cond ---@type boolean + status, first, second, third, fourth = wrap(cond) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("first"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("second"), ws.ty("integer")); + assert_eq!(ws.expr_ty("third"), ws.ty("integer")); + assert_eq!(ws.expr_ty("fourth"), ws.ty("integer")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 74d0e485a..362ca39eb 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -45,4 +45,76 @@ mod test { "# )); } + + #[test] + fn test_nested_pcall_higher_order_return_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@return integer + local function f() + return 1 + end + + ok, status, payload = pcall(pcall, f) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("true|false|string")); + assert_eq!(ws.expr_ty("payload"), ws.ty("string|integer")); + } + + #[test] + fn test_nested_pcall_like_without_return_overload() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function safe_call(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = safe_call(safe_call, produce) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_nested_pcall_like_without_return_overload2() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, R, R1 + ---@param f sync fun(...: T...): R1, R... + ---@param ... T... + ---@return boolean, R1|string, R... + local function pcall_like(f, ...) end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = pcall_like(pcall_like, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs b/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs index 0e006f038..8c3ad0506 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs @@ -1,12 +1,14 @@ use std::collections::HashMap; -use emmylua_parser::{LuaAstPtr, LuaExpr, LuaSyntaxId}; +use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaExpr, LuaSyntaxId}; +use rowan::TextSize; -use crate::{FlowId, FlowNode, LuaDeclId}; +use crate::{FlowAntecedent, FlowId, FlowNode, LuaDeclId}; #[derive(Debug)] pub struct FlowTree { decl_bind_expr_ref: HashMap>, + decl_multi_return_ref: HashMap>, flow_nodes: Vec, multiple_antecedents: Vec>, // labels: HashMap>, @@ -16,6 +18,7 @@ pub struct FlowTree { impl FlowTree { pub fn new( decl_bind_expr_ref: HashMap>, + decl_multi_return_ref: HashMap>, flow_nodes: Vec, multiple_antecedents: Vec>, // labels: HashMap>, @@ -23,6 +26,7 @@ impl FlowTree { ) -> Self { Self { decl_bind_expr_ref, + decl_multi_return_ref, flow_nodes, multiple_antecedents, bindings, @@ -46,4 +50,80 @@ impl FlowTree { pub fn get_decl_ref_expr(&self, decl_id: &LuaDeclId) -> Option> { self.decl_bind_expr_ref.get(decl_id).cloned() } + + pub fn get_decl_multi_return_refs_at( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + ) -> Vec { + let mut refs = Vec::new(); + self.collect_decl_multi_return_refs_at(decl_id, position, flow_id, &mut refs); + refs + } + + fn collect_decl_multi_return_refs_at( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + refs: &mut Vec, + ) { + if let Some(at) = self.get_decl_multi_return_ref_on_flow(decl_id, position, flow_id) { + if let Some(reference) = &at.reference { + refs.push(reference.clone()); + } + return; + } + + let Some(flow_node) = self.get_flow_node(flow_id) else { + return; + }; + let Some(antecedent) = flow_node.antecedent.as_ref() else { + return; + }; + match antecedent { + FlowAntecedent::Single(next_flow_id) => { + self.collect_decl_multi_return_refs_at(decl_id, position, *next_flow_id, refs); + } + FlowAntecedent::Multiple(multi_id) => { + if let Some(multi_antecedents) = self.get_multi_antecedents(*multi_id) { + for &next_flow_id in multi_antecedents { + self.collect_decl_multi_return_refs_at( + decl_id, + position, + next_flow_id, + refs, + ); + } + } + } + } + } + + fn get_decl_multi_return_ref_on_flow( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + ) -> Option<&DeclMultiReturnRefAt> { + self.decl_multi_return_ref + .get(decl_id)? + .iter() + .rev() + .find(|r| r.position <= position && r.flow_id == flow_id) + } +} + +#[derive(Debug, Clone)] +pub struct DeclMultiReturnRef { + pub call_expr: LuaAstPtr, + pub return_index: usize, +} + +#[derive(Debug, Clone)] +pub struct DeclMultiReturnRefAt { + pub position: TextSize, + pub flow_id: FlowId, + pub reference: Option, } diff --git a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs index 95ef92cd4..2e54f0c81 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use crate::{FileId, LuaSignatureId}; use emmylua_parser::{LuaAstPtr, LuaDocOpType}; pub use flow_node::*; -pub use flow_tree::FlowTree; +pub use flow_tree::{DeclMultiReturnRef, DeclMultiReturnRefAt, FlowTree}; pub use signature_cast::LuaSignatureCast; use super::traits::LuaIndex; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index 36aa471f2..c77035e15 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -11,7 +11,10 @@ use crate::{ FileId, db_index::{LuaFunctionType, LuaType}, }; -use crate::{LuaAttributeUse, SemanticModel, VariadicType, first_param_may_not_self}; +use crate::{ + LuaAliasCallKind, LuaAliasCallType, LuaAttributeUse, SemanticModel, VariadicType, + first_param_may_not_self, +}; #[derive(Debug)] pub struct LuaSignature { @@ -20,6 +23,7 @@ pub struct LuaSignature { pub param_docs: HashMap, pub params: Vec, pub return_docs: Vec, + pub return_overloads: Vec>, pub resolve_return: SignatureReturnStatus, pub is_colon_define: bool, pub async_state: AsyncState, @@ -47,6 +51,7 @@ impl LuaSignature { param_docs: HashMap::new(), params: Vec::new(), return_docs: Vec::new(), + return_overloads: Vec::new(), resolve_return: SignatureReturnStatus::UnResolve, is_colon_define: false, async_state: AsyncState::None, @@ -111,6 +116,10 @@ impl LuaSignature { } pub fn get_return_type(&self) -> LuaType { + if !self.return_overloads.is_empty() { + return self.get_return_type_by_overloads(); + } + match self.return_docs.len() { 0 => LuaType::Nil, 1 => self.return_docs[0].type_ref.clone(), @@ -126,6 +135,133 @@ impl LuaSignature { } } + fn get_return_type_by_overloads(&self) -> LuaType { + let Some(base_max_len) = self.return_overloads.iter().map(Vec::len).max() else { + return LuaType::Nil; + }; + + if base_max_len == 0 { + return LuaType::Nil; + } + + let (has_variadic_tail, has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = + self.return_overloads.iter().fold( + (false, false, false), + |(has_var, has_unbounded, has_tpl_unbounded), row| { + let Some(last) = row.last() else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + let LuaType::Variadic(variadic) = last else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + + let has_unbounded_row = variadic.get_max_len().is_none(); + ( + true, + has_unbounded || has_unbounded_row, + has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), + ) + }, + ); + let max_len = if has_variadic_tail { + base_max_len + 1 + } else { + base_max_len + }; + + let mut types = Vec::with_capacity(max_len); + for idx in 0..max_len { + let slot_types = self + .return_overloads + .iter() + .filter_map(|row| { + Self::get_overload_row_slot_if_present(row, idx) + .or((idx < base_max_len).then_some(LuaType::Nil)) + }) + .collect(); + types.push(LuaType::from_vec(slot_types)); + } + if has_unbounded_variadic_tail + && !has_tpl_unbounded_variadic_tail + && let Some(last) = types.last_mut() + && !matches!(last, LuaType::Variadic(_)) + { + *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); + } + + if types.len() == 1 { + types.pop().unwrap_or(LuaType::Nil) + } else { + LuaType::Variadic(VariadicType::Multi(types).into()) + } + } + + pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { + Self::get_overload_row_slot_if_present(row, idx).unwrap_or(LuaType::Nil) + } + + fn overload_row_tpl_slot( + call_kind: LuaAliasCallKind, + variadic: &std::sync::Arc, + index: i64, + ) -> LuaType { + LuaType::Call( + LuaAliasCallType::new( + call_kind, + vec![ + LuaType::Variadic(variadic.clone()), + LuaType::IntegerConst(index), + ], + ) + .into(), + ) + } + + fn get_overload_row_slot_if_present(row: &[LuaType], idx: usize) -> Option { + let row_len = row.len(); + if row_len == 0 { + return None; + } + + if idx + 1 < row_len { + return Some(row[idx].clone()); + } + + let last_idx = row_len - 1; + let last_ty = &row[last_idx]; + let offset = idx - last_idx; + if let LuaType::Variadic(variadic) = last_ty { + if let Some(slot) = variadic.get_type(offset).cloned() { + if slot.contain_tpl() { + if offset > 0 && matches!(variadic.as_ref(), VariadicType::Base(_)) { + return Some(Self::overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )); + } + + return Some(Self::overload_row_tpl_slot( + LuaAliasCallKind::Index, + variadic, + offset as i64, + )); + } + return Some(slot); + } + + Some(Self::overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )) + } else if offset == 0 { + Some(last_ty.clone()) + } else { + None + } + } + pub fn is_method(&self, semantic_model: &SemanticModel, owner_type: Option<&LuaType>) -> bool { if self.is_colon_define { return true; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 5c2c2889c..2c7fa1f5c 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -15,11 +15,11 @@ use crate::{ instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ - multi_param_tpl_pattern_match_multi_return, tpl_pattern_match, - variadic_tpl_pattern_match, + multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, + tpl_pattern_match, tpl_pattern_match_args, variadic_tpl_pattern_match, }, }, - infer::InferFailReason, + infer::{InferFailReason, infer_expr_list_types}, infer_expr, }, }; @@ -166,11 +166,24 @@ fn infer_generic_types_from_call( Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), }; + + if let Some(return_pattern) = + as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) + && let Some(inferred_return_type) = + infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? + { + return_type_pattern_match_target_type(context, &return_pattern, &inferred_return_type)?; + } + match (func_param_type, &arg_type) { (LuaType::Variadic(variadic), _) => { let mut arg_types = vec![]; for arg_expr in &arg_exprs[i..] { - let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; + let arg_type = match infer_expr(db, context.cache, arg_expr.clone()) { + Ok(t) => t, + Err(InferFailReason::FieldNotFound) => LuaType::Nil, + Err(e) => return Err(e), + }; arg_types.push(arg_type); } variadic_tpl_pattern_match(context, variadic, &arg_types)?; @@ -186,7 +199,21 @@ fn infer_generic_types_from_call( break; } _ => { - tpl_pattern_match(context, func_param_type, &arg_type)?; + if let Err(err) = tpl_pattern_match(context, func_param_type, &arg_type) { + let ignore_err = matches!(arg_type, LuaType::Signature(_)) + && matches!( + func_param_type, + LuaType::DocFunction(_) | LuaType::Signature(_) + ) + && matches!( + err, + InferFailReason::UnResolveSignatureReturn(_) + | InferFailReason::FieldNotFound + ); + if !ignore_err { + return Err(err); + } + } } } } @@ -202,6 +229,73 @@ fn infer_generic_types_from_call( Ok(()) } +fn infer_callable_return_from_remaining_args( + context: &mut TplContext, + callable_type: &LuaType, + arg_exprs: &[LuaExpr], +) -> Result, InferFailReason> { + if arg_exprs.is_empty() { + return Ok(None); + } + + let Some(callable) = as_doc_function_type(context.db, callable_type)? else { + return Ok(None); + }; + + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |t| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = t { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + if callable_tpls.is_empty() { + return Ok(Some(callable.get_ret().clone())); + } + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + let infer_return_from_callable = |substitutor: &TypeSubstitutor| { + let instantiated = instantiate_doc_function(context.db, &callable, substitutor); + match instantiated { + LuaType::DocFunction(func) => func.get_ret().clone(), + _ => callable.get_ret().clone(), + } + }; + + let call_arg_types = + match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { + Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), + Err(_) => return Ok(Some(infer_return_from_callable(&callable_substitutor))), + }; + if call_arg_types.is_empty() { + return Ok(None); + } + + let callable_param_types = callable + .get_params() + .iter() + .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) + .collect::>(); + + let mut callable_context = TplContext { + db: context.db, + cache: context.cache, + substitutor: &mut callable_substitutor, + call_expr: context.call_expr.clone(), + }; + if tpl_pattern_match_args( + &mut callable_context, + &callable_param_types, + &call_arg_types, + ) + .is_err() + { + return Ok(Some(infer_return_from_callable(&callable_substitutor))); + } + + Ok(Some(infer_return_from_callable(&callable_substitutor))) +} + pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { match self_type { LuaType::Def(id) | LuaType::Ref(id) => { @@ -268,18 +362,8 @@ fn check_expr_can_later_infer( func_param_type: &LuaType, call_arg_expr: &LuaExpr, ) -> Result { - let doc_function = match func_param_type { - LuaType::DocFunction(doc_func) => doc_func.clone(), - LuaType::Signature(sig_id) => { - let sig = context - .db - .get_signature_index() - .get(sig_id) - .ok_or(InferFailReason::None)?; - - sig.to_doc_func_type() - } - _ => return Ok(false), + let Some(doc_function) = as_doc_function_type(context.db, func_param_type)? else { + return Ok(false); }; if let LuaExpr::ClosureExpr(_) = call_arg_expr { @@ -300,3 +384,19 @@ fn check_expr_can_later_infer( Ok(variadic_count > 1) } + +fn as_doc_function_type( + db: &DbIndex, + callable_type: &LuaType, +) -> Result>, InferFailReason> { + Ok(match callable_type { + LuaType::DocFunction(doc_func) => Some(doc_func.clone()), + LuaType::Signature(sig_id) => Some( + db.get_signature_index() + .get(sig_id) + .ok_or(InferFailReason::None)? + .to_doc_func_type(), + ), + _ => None, + }) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 754974211..88e695057 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -688,7 +688,7 @@ fn param_type_list_pattern_match_type_list( Ok(()) } -fn return_type_pattern_match_target_type( +pub(crate) fn return_type_pattern_match_target_type( context: &mut TplContext, source: &LuaType, target: &LuaType, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index a96f7a131..2e8d38cd3 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -754,4 +754,33 @@ mod tests { assert!(!matches!(second, Err(InferFailReason::RecursiveInfer))); } + + #[test] + fn test_higher_order_call_with_unresolved_remaining_arg_should_not_hard_fail() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) end + + ---@generic U: string + ---@param x U + ---@return U + local function id(x) end + + ---@class Box + ---@field value integer + ---@type Box + local box + + ok, payload = wrap(id, box.missing) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index 1653f5112..9cd1de985 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -11,7 +11,10 @@ use crate::{ infer_index::infer_member_by_member_key, narrow::{ ResultTypeOrContinue, - condition_flow::{InferConditionFlow, call_flow::get_type_at_call_expr}, + condition_flow::{ + InferConditionFlow, always_literal_equal, call_flow::get_type_at_call_expr, + narrow_var_from_return_overload_condition, + }, get_single_antecedent, get_type_at_flow::get_type_at_flow, get_var_ref_type, narrow_down_type, @@ -102,7 +105,7 @@ fn try_get_at_eq_or_neq_expr( right_expr: LuaExpr, condition_flow: InferConditionFlow, ) -> Result { - let mut result_type = maybe_type_guard_binary( + if let ResultTypeOrContinue::Result(result_type) = maybe_type_guard_binary( db, tree, cache, @@ -112,12 +115,11 @@ fn try_get_at_eq_or_neq_expr( left_expr.clone(), right_expr.clone(), condition_flow, - )?; - if let ResultTypeOrContinue::Result(result_type) = result_type { + )? { return Ok(ResultTypeOrContinue::Result(result_type)); } - result_type = maybe_field_literal_eq_narrow( + if let ResultTypeOrContinue::Result(result_type) = maybe_field_literal_eq_narrow( db, tree, cache, @@ -127,12 +129,22 @@ fn try_get_at_eq_or_neq_expr( left_expr.clone(), right_expr.clone(), condition_flow, - )?; - - if let ResultTypeOrContinue::Result(result_type) = result_type { + )? { return Ok(ResultTypeOrContinue::Result(result_type)); } + let (left_expr, right_expr) = if !matches!( + left_expr, + LuaExpr::NameExpr(_) | LuaExpr::CallExpr(_) | LuaExpr::IndexExpr(_) | LuaExpr::UnaryExpr(_) + ) && matches!( + right_expr, + LuaExpr::NameExpr(_) | LuaExpr::CallExpr(_) | LuaExpr::IndexExpr(_) | LuaExpr::UnaryExpr(_) + ) { + (right_expr, left_expr) + } else { + (left_expr, right_expr) + }; + maybe_var_eq_narrow( db, tree, @@ -358,8 +370,22 @@ fn maybe_var_eq_narrow( }; if maybe_ref_id != *var_ref_id { - // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + let Some(discriminant_decl_id) = maybe_ref_id.get_decl_id_ref() else { + return Ok(ResultTypeOrContinue::Continue); + }; + let right_expr_type = infer_expr(db, cache, right_expr)?; + return narrow_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + discriminant_decl_id, + left_name_expr.get_position(), + Some(&right_expr_type), + condition_flow, + ); } let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; @@ -564,7 +590,7 @@ fn maybe_field_literal_eq_narrow( Ok(member_type) => member_type, Err(_) => continue, // If we cannot infer the member type, skip this type }; - if const_type_eq(&member_type, &right_type) { + if always_literal_equal(&member_type, &right_type) { // If the right type matches the member type, we can narrow it opt_result = Some(i); } @@ -586,23 +612,3 @@ fn maybe_field_literal_eq_narrow( Ok(ResultTypeOrContinue::Continue) } - -fn const_type_eq(left_type: &LuaType, right_type: &LuaType) -> bool { - if left_type == right_type { - return true; - } - - match (left_type, right_type) { - ( - LuaType::StringConst(l) | LuaType::DocStringConst(l), - LuaType::StringConst(r) | LuaType::DocStringConst(r), - ) => l == r, - (LuaType::FloatConst(l), LuaType::FloatConst(r)) => l == r, - (LuaType::BooleanConst(l), LuaType::BooleanConst(r)) => l == r, - ( - LuaType::IntegerConst(l) | LuaType::DocIntegerConst(l), - LuaType::IntegerConst(r) | LuaType::DocIntegerConst(r), - ) => l == r, - _ => false, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index b922d827a..1bd8881f1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -2,10 +2,13 @@ mod binary_flow; mod call_flow; mod index_flow; +use std::collections::HashSet; + use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, + DbIndex, FlowNode, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, LuaInferCache, + LuaType, TypeOps, VariadicType, infer_expr, instantiate_func_generic, semantic::infer::{ VarRefId, narrow::{ @@ -189,6 +192,21 @@ fn get_type_at_name_ref( return Ok(ResultTypeOrContinue::Continue); }; + if let ResultTypeOrContinue::Result(result_type) = narrow_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + decl_id, + name_expr.get_position(), + None, + condition_flow, + )? { + return Ok(ResultTypeOrContinue::Result(result_type)); + } + let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { return Ok(ResultTypeOrContinue::Continue); }; @@ -209,6 +227,352 @@ fn get_type_at_name_ref( ) } +/// Narrows `var_ref_id` by correlating it with a discriminant variable from the same multi-return call. +#[allow(clippy::too_many_arguments)] +pub(super) fn narrow_var_from_return_overload_condition( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + flow_node: &FlowNode, + discriminant_decl_id: LuaDeclId, + condition_position: rowan::TextSize, + expected_discriminant: Option<&LuaType>, + condition_flow: InferConditionFlow, +) -> Result { + // We only narrow concrete declarations. + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(ResultTypeOrContinue::Continue); + }; + + // Evaluate correlation against the single antecedent flow where the branch starts. + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let discriminant_refs = tree.get_decl_multi_return_refs_at( + &discriminant_decl_id, + condition_position, + antecedent_flow_id, + ); + if discriminant_refs.is_empty() { + return Ok(ResultTypeOrContinue::Continue); + } + let target_refs = + tree.get_decl_multi_return_refs_at(&target_decl_id, condition_position, antecedent_flow_id); + if target_refs.is_empty() { + return Ok(ResultTypeOrContinue::Continue); + } + + // Phase 1: collect rows where discriminant and target are correlated + // (same call expression origin). + let (condition_types, correlated_universe_types, correlated_target_call_expr_ids) = + collect_correlated_condition_types( + db, + cache, + root, + &discriminant_refs, + &target_refs, + expected_discriminant, + condition_flow, + )?; + + // If any target origin is uncorrelated, we must preserve it when merging. + let has_non_correlated_origin = target_refs.iter().any(|target_ref| { + !correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) + }); + + if condition_types.is_empty() { + return Ok(ResultTypeOrContinue::Continue); + } + + let condition_type = LuaType::from_vec(condition_types); + let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + let narrowed_correlated = TypeOps::Intersect.apply(db, &antecedent_type, &condition_type); + if narrowed_correlated.is_never() { + return Ok(ResultTypeOrContinue::Continue); + } + + if !has_non_correlated_origin { + return Ok(result_if_changed(narrowed_correlated, &antecedent_type)); + } + + // Phase 2: merge back non-correlated origins. + let non_correlated_types = collect_non_correlated_origin_types( + db, + cache, + root, + &target_refs, + &correlated_target_call_expr_ids, + )?; + + if non_correlated_types.is_empty() { + // If those origins cannot be inferred, preserve antecedent members outside the + // correlated universe as a conservative fallback. + let correlated_universe_type = LuaType::from_vec(correlated_universe_types); + let fallback_non_correlated = + TypeOps::Remove.apply(db, &antecedent_type, &correlated_universe_type); + let merged = if fallback_non_correlated.is_never() { + narrowed_correlated.clone() + } else { + LuaType::from_vec(vec![narrowed_correlated.clone(), fallback_non_correlated]) + }; + return Ok(result_if_changed(merged, &antecedent_type)); + } + + let non_correlated_type = LuaType::from_vec(non_correlated_types); + let merged = if non_correlated_type.is_never() { + narrowed_correlated + } else { + LuaType::from_vec(vec![narrowed_correlated, non_correlated_type]) + }; + Ok(result_if_changed(merged, &antecedent_type)) +} + +/// Collects candidate target-slot types from overload rows that share call-site origin with the discriminant. +#[allow(clippy::too_many_arguments)] +fn collect_correlated_condition_types( + db: &DbIndex, + cache: &mut LuaInferCache, + root: &LuaChunk, + discriminant_refs: &[crate::DeclMultiReturnRef], + target_refs: &[crate::DeclMultiReturnRef], + expected_discriminant: Option<&LuaType>, + condition_flow: InferConditionFlow, +) -> Result< + ( + Vec, + Vec, + HashSet, + ), + InferFailReason, +> { + let mut condition_types = Vec::new(); + let mut correlated_universe_types = Vec::new(); + let mut correlated_target_call_expr_ids = HashSet::new(); + + for discriminant_ref in discriminant_refs { + let Some((call_expr, signature)) = + infer_signature_for_call_ptr(db, cache, root, &discriminant_ref.call_expr)? + else { + continue; + }; + if signature.return_overloads.is_empty() { + continue; + } + + let overload_rows = instantiate_return_overload_rows(db, cache, call_expr, signature); + let discriminant_call_expr_id = discriminant_ref.call_expr.get_syntax_id(); + + for target_ref in target_refs { + // Correlation is only valid when both values are produced by the same call site. + if target_ref.call_expr.get_syntax_id() != discriminant_call_expr_id { + continue; + } + correlated_target_call_expr_ids.insert(target_ref.call_expr.get_syntax_id()); + correlated_universe_types.extend(overload_rows.iter().map(|overload| { + crate::LuaSignature::get_overload_row_slot(overload, target_ref.return_index) + })); + condition_types.extend(overload_rows.iter().filter_map(|overload| { + let discriminant_type = crate::LuaSignature::get_overload_row_slot( + overload, + discriminant_ref.return_index, + ); + if overload_row_matches_discriminant( + db, + &discriminant_type, + expected_discriminant, + condition_flow, + ) { + return Some(crate::LuaSignature::get_overload_row_slot( + overload, + target_ref.return_index, + )); + } + None + })); + } + } + + Ok(( + condition_types, + correlated_universe_types, + correlated_target_call_expr_ids, + )) +} + +/// Collects target-slot types from origins that are not correlated with the current discriminant. +fn collect_non_correlated_origin_types( + db: &DbIndex, + cache: &mut LuaInferCache, + root: &LuaChunk, + target_refs: &[crate::DeclMultiReturnRef], + correlated_target_call_expr_ids: &HashSet, +) -> Result, InferFailReason> { + let mut non_correlated_types = Vec::new(); + + for target_ref in target_refs { + let target_call_expr_id = target_ref.call_expr.get_syntax_id(); + if correlated_target_call_expr_ids.contains(&target_call_expr_id) { + continue; + } + + let Some((call_expr, signature)) = + infer_signature_for_call_ptr(db, cache, root, &target_ref.call_expr)? + else { + continue; + }; + let overload_rows = signature_rows_at_call(db, cache, call_expr, signature); + non_correlated_types.extend(overload_rows.iter().map(|overload| { + crate::LuaSignature::get_overload_row_slot(overload, target_ref.return_index) + })); + } + + Ok(non_correlated_types) +} + +fn overload_row_matches_discriminant( + db: &DbIndex, + discriminant_type: &LuaType, + expected_discriminant: Option<&LuaType>, + condition_flow: InferConditionFlow, +) -> bool { + match expected_discriminant { + None => match condition_flow { + InferConditionFlow::TrueCondition => !discriminant_type.is_always_falsy(), + InferConditionFlow::FalseCondition => !discriminant_type.is_always_truthy(), + }, + Some(expected) => match condition_flow { + InferConditionFlow::TrueCondition => !TypeOps::Intersect + .apply(db, discriminant_type, expected) + .is_never(), + InferConditionFlow::FalseCondition => { + !always_literal_equal(discriminant_type, expected) + } + }, + } +} + +fn infer_signature_for_call_ptr<'a>( + db: &'a DbIndex, + cache: &mut LuaInferCache, + root: &LuaChunk, + call_expr_ptr: &emmylua_parser::LuaAstPtr, +) -> Result, InferFailReason> { + let Some(call_expr) = call_expr_ptr.to_node(root) else { + return Ok(None); + }; + let Some(prefix_expr) = call_expr.get_prefix_expr() else { + return Ok(None); + }; + let signature_id = match infer_expr(db, cache, prefix_expr)? { + LuaType::Signature(signature_id) => signature_id, + _ => return Ok(None), + }; + let Some(signature) = db.get_signature_index().get(&signature_id) else { + return Ok(None); + }; + + Ok(Some((call_expr, signature))) +} + +fn signature_rows_at_call( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: emmylua_parser::LuaCallExpr, + signature: &crate::LuaSignature, +) -> Vec> { + if signature.return_overloads.is_empty() { + vec![normalize_return_type_to_row(signature.get_return_type())] + } else { + instantiate_return_overload_rows(db, cache, call_expr, signature) + } +} + +/// Instantiates and normalizes each overload return row so per-slot lookups can be applied. +fn instantiate_return_overload_rows( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: emmylua_parser::LuaCallExpr, + signature: &crate::LuaSignature, +) -> Vec> { + let mut rows = Vec::with_capacity(signature.return_overloads.len()); + for overload in &signature.return_overloads { + // Convert a row into a function-style return type for generic instantiation. + let overload_return_type = match overload.len() { + 0 => LuaType::Nil, + 1 => overload[0].clone(), + _ => LuaType::Variadic(VariadicType::Multi(overload.to_vec()).into()), + }; + let instantiated_return_type = if overload_return_type.contain_tpl() { + let overload_func = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + overload_return_type.clone(), + ); + match instantiate_func_generic(db, cache, &overload_func, call_expr.clone()) { + Ok(instantiated) => instantiated.get_ret().clone(), + // Keep the original row shape if generic instantiation fails. + Err(_) => overload_return_type, + } + } else { + overload_return_type + }; + + // Normalize back to row slots for discriminant/target slot lookup. + rows.push(normalize_return_type_to_row(instantiated_return_type)); + } + + rows +} + +/// Converts a return type into a single overload-like row shape. +fn normalize_return_type_to_row(return_type: LuaType) -> Vec { + match return_type { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Multi(types) => types.clone(), + VariadicType::Base(_) => vec![LuaType::Variadic(variadic)], + }, + typ => vec![typ], + } +} + +/// Converts a candidate narrowed type into `Result` only if it actually changed. +fn result_if_changed(result_type: LuaType, antecedent_type: &LuaType) -> ResultTypeOrContinue { + if result_type == *antecedent_type { + ResultTypeOrContinue::Continue + } else { + ResultTypeOrContinue::Result(result_type) + } +} + +/// Equality helper for literal-like types used by condition narrowing. +pub(super) fn always_literal_equal(left: &LuaType, right: &LuaType) -> bool { + match (left, right) { + (LuaType::Union(union), other) => union + .into_vec() + .into_iter() + .all(|candidate| always_literal_equal(&candidate, other)), + (other, LuaType::Union(union)) => union + .into_vec() + .into_iter() + .all(|candidate| always_literal_equal(other, &candidate)), + ( + LuaType::StringConst(l) | LuaType::DocStringConst(l), + LuaType::StringConst(r) | LuaType::DocStringConst(r), + ) => l == r, + ( + LuaType::BooleanConst(l) | LuaType::DocBooleanConst(l), + LuaType::BooleanConst(r) | LuaType::DocBooleanConst(r), + ) => l == r, + ( + LuaType::IntegerConst(l) | LuaType::DocIntegerConst(l), + LuaType::IntegerConst(r) | LuaType::DocIntegerConst(r), + ) => l == r, + _ => left == right, + } +} + #[allow(clippy::too_many_arguments)] fn get_type_at_unary_flow( db: &DbIndex, diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 264767dee..8b72f38e0 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -201,6 +201,7 @@ fn build_tokens_semantic_token( | LuaTokenKind::TkTagUsing | LuaTokenKind::TkTagSource | LuaTokenKind::TkTagReturnCast + | LuaTokenKind::TkTagReturnOverload | LuaTokenKind::TkTagExport | LuaTokenKind::TkLanguage | LuaTokenKind::TkTagAttribute diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 81e70deb4..07bdb82a9 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -90,4 +90,19 @@ m.foo() Ok(()) } + + #[gtest] + fn test_return_overload_tag_is_documentation_keyword() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let data = ws.get_semantic_token_data( + r#"---@return_overload true, integer +"#, + )?; + let tokens = decode(&data); + let keyword = token_type_index(SemanticTokenType::KEYWORD); + let doc = modifier_bitset(&[SemanticTokenModifier::DOCUMENTATION]); + + verify_that!(&tokens, contains(eq(&(0, 4, 15, keyword, doc))))?; + Ok(()) + } } diff --git a/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs b/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs index cd25ed071..d046257ad 100644 --- a/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs +++ b/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs @@ -36,7 +36,8 @@ mod tests { pcall(readFile, ) "#, VirtualSignatureHelp { - target_label: "pcall(f: sync fun(path: string), path: string): boolean".to_string(), + target_label: + "pcall(f: sync fun(path: string), path: string): (true|false)".to_string(), active_signature: 0, active_parameter: 1, }, diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 5f1cad9f0..1d2759b5c 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -38,7 +38,7 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkTagField => parse_tag_field(p), LuaTokenKind::TkTagType => parse_tag_type(p), LuaTokenKind::TkTagParam => parse_tag_param(p), - LuaTokenKind::TkTagReturn => parse_tag_return(p), + LuaTokenKind::TkTagReturn | LuaTokenKind::TkTagReturnOverload => parse_tag_return(p), LuaTokenKind::TkTagReturnCast => parse_tag_return_cast(p), // other tag LuaTokenKind::TkTagModule => parse_tag_module(p), diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 2b2a60edd..89bdb7d8b 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -132,6 +132,18 @@ Syntax(Chunk)@0..163 assert_ast_eq!(code, result); } + #[test] + fn test_return_overload_tag() { + // Ensure lexer+parser route `@return_overload` to the dedicated syntax node/token. + let code = r#" + ---@return_overload true, integer + "#; + let tree = LuaParser::parse(code, ParserConfig::default()); + let result = format!("{:#?}", tree.get_red_root()); + assert!(result.contains("Syntax(DocTagReturn)")); + assert!(result.contains("Token(TkTagReturnOverload)")); + } + #[test] fn test_class_doc() { let code = r#" diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index f3a6a09c1..0f5e3ce1f 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -109,35 +109,36 @@ pub enum LuaTokenKind { TkTagAlias, // alias TkTagModule, // module - TkTagField, // field - TkTagType, // type - TkTagParam, // param - TkTagReturn, // return - TkTagOverload, // overload - TkTagGeneric, // generic - TkTagSee, // see - TkTagDeprecated, // deprecated - TkTagAsync, // async - TkTagCast, // cast - TkTagOther, // other - TkTagVisibility, // public private protected package - TkTagReadonly, // readonly - TkTagDiagnostic, // diagnostic - TkTagMeta, // meta - TkTagVersion, // version - TkTagAs, // as - TkTagNodiscard, // nodiscard - TkTagOperator, // operator - TkTagMapping, // mapping - TkTagNamespace, // namespace - TkTagUsing, // using - TkTagSource, // source - TkTagReturnCast, // return cast - TkTagExport, // export - TkLanguage, // language - TKTagSchema, // schema - TkTagAttribute, // attribute - TkCallGeneric, // call generic. function_name--[[@]](...) + TkTagField, // field + TkTagType, // type + TkTagParam, // param + TkTagReturn, // return + TkTagOverload, // overload + TkTagGeneric, // generic + TkTagSee, // see + TkTagDeprecated, // deprecated + TkTagAsync, // async + TkTagCast, // cast + TkTagOther, // other + TkTagVisibility, // public private protected package + TkTagReadonly, // readonly + TkTagDiagnostic, // diagnostic + TkTagMeta, // meta + TkTagVersion, // version + TkTagAs, // as + TkTagNodiscard, // nodiscard + TkTagOperator, // operator + TkTagMapping, // mapping + TkTagNamespace, // namespace + TkTagUsing, // using + TkTagSource, // source + TkTagReturnCast, // return cast + TkTagReturnOverload, // return overload + TkTagExport, // export + TkLanguage, // language + TKTagSchema, // schema + TkTagAttribute, // attribute + TkCallGeneric, // call generic. function_name--[[@]](...) TkDocOr, // | TkDocAnd, // & diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 7accb73cc..fbf364f06 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -709,6 +709,7 @@ fn to_tag(text: &str) -> LuaTokenKind { "param" => LuaTokenKind::TkTagParam, "return" => LuaTokenKind::TkTagReturn, "return_cast" => LuaTokenKind::TkTagReturnCast, + "return_overload" => LuaTokenKind::TkTagReturnOverload, "generic" => LuaTokenKind::TkTagGeneric, "see" => LuaTokenKind::TkTagSee, "overload" => LuaTokenKind::TkTagOverload, diff --git a/docs/emmylua_doc/annotations_CN/README.md b/docs/emmylua_doc/annotations_CN/README.md index 3db3e0741..07e15ddce 100644 --- a/docs/emmylua_doc/annotations_CN/README.md +++ b/docs/emmylua_doc/annotations_CN/README.md @@ -27,6 +27,7 @@ ### 函数注解 - [`@param`](./param.md) - 参数定义 - [`@return`](./return.md) - 返回值定义 +- `@return_overload`(见 [`@return`](./return.md))- 关联返回元组 - [`@overload`](./overload.md) - 函数重载 - [`@async`](./async.md) - 异步函数标记 - [`@nodiscard`](./nodiscard.md) - 不可忽略返回值 @@ -122,6 +123,7 @@ end | `@field` | 字段定义 | `---@field name string` | | `@param` | 参数定义 | `---@param name string` | | `@return` | 返回值定义 | `---@return boolean` | +| `@return_overload` | 关联返回元组 | `---@return_overload true, T` | | `@type` | 类型声明 | `---@type string` | | `@generic` | 泛型定义 | `---@generic T` | | `@overload` | 函数重载 | `---@overload fun(x: number): number` | diff --git a/docs/emmylua_doc/annotations_CN/return.md b/docs/emmylua_doc/annotations_CN/return.md index a10da1c9a..2d922de5d 100644 --- a/docs/emmylua_doc/annotations_CN/return.md +++ b/docs/emmylua_doc/annotations_CN/return.md @@ -14,6 +14,9 @@ -- 多返回值 ---@return <类型1> [名称1] [描述1] ---@return <类型2> [名称2] [描述2] + +-- 关联返回行(每一行代表一种返回元组) +---@return_overload <类型1>, <类型2>[, <类型3>...] ``` ## 示例 @@ -174,6 +177,53 @@ for id, userName in iterateUsers() do end ``` +## 返回重载行(`@return_overload`) + +`@return_overload` 用于定义“关联”的多返回值行。每一条注解代表一种可能的返回元组。 +这对状态/结果模式(例如 `pcall` 风格代码)非常有用。 + +当多个局部变量来自同一次函数调用时,对某个返回槽位的条件判断 +(真假判断或字面量相等判断)会联动收窄同一返回行中的其他槽位。 + +```lua +---@generic T, E +---@param ok boolean +---@param success T +---@param failure E +---@return boolean +---@return T|E +---@return_overload true, T +---@return_overload false, E +local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure +end + +local cond ---@type boolean +local ok, result = pick(cond, 1, "error") + +if not ok then + error(result) -- result: string +end + +local value = result -- value: integer +``` + +`@return_overload` 同样支持泛型和可变尾部: + +```lua +---@generic T, R +---@param f fun(...: T...): R... +---@param ... T... +---@return_overload true, R... +---@return_overload false, string +local function wrap(f, ...) end +``` + +你可以保留 `@return` 作为宽泛声明,再用 `@return_overload` 提供关联敏感的推断信息。 + ## 特性 1. **多返回值支持** @@ -182,3 +232,4 @@ end 4. **函数返回值** 5. **异步返回值** 6. **条件返回值** +7. **关联返回行收窄(`@return_overload`)** diff --git a/docs/emmylua_doc/annotations_EN/README.md b/docs/emmylua_doc/annotations_EN/README.md index e7f7434dc..355d04b7c 100644 --- a/docs/emmylua_doc/annotations_EN/README.md +++ b/docs/emmylua_doc/annotations_EN/README.md @@ -29,6 +29,7 @@ The following notation symbols are used in annotation syntax descriptions: ### Function Annotations - [`@param`](./param.md) - Parameter definition - [`@return`](./return.md) - Return value definition +- `@return_overload` (see [`@return`](./return.md)) - Correlated return tuples - [`@overload`](./overload.md) - Function overload - [`@async`](./async.md) - Async function marker - [`@nodiscard`](./nodiscard.md) - Non-discardable return value @@ -124,6 +125,7 @@ end | `@field` | Field definition | `---@field name string` | | `@param` | Parameter definition | `---@param name string` | | `@return` | Return value definition | `---@return boolean` | +| `@return_overload` | Correlated return tuples | `---@return_overload true, T` | | `@type` | Type declaration | `---@type string` | | `@generic` | Generic definition | `---@generic T` | | `@overload` | Function overload | `---@overload fun(x: number): number` | diff --git a/docs/emmylua_doc/annotations_EN/return.md b/docs/emmylua_doc/annotations_EN/return.md index 6eaeb6366..4a2e0c66f 100644 --- a/docs/emmylua_doc/annotations_EN/return.md +++ b/docs/emmylua_doc/annotations_EN/return.md @@ -14,6 +14,9 @@ Define return value types and description information for functions. -- Multiple return values ---@return [name1] [description1] ---@return [name2] [description2] + +-- Correlated return rows (one row per possible return tuple) +---@return_overload , [, ...] ``` ## Examples @@ -174,6 +177,53 @@ for id, userName in iterateUsers() do end ``` +## Return Overload Rows (`@return_overload`) + +`@return_overload` defines correlated multi-return rows. Each annotation line represents one possible return tuple. +This is useful for status/result APIs (for example `pcall`-style code). + +When multiple local variables are assigned from the same call, condition checks on one return slot +(truthy/falsy checks or literal equality checks) narrow correlated slots from the same row. + +```lua +---@generic T, E +---@param ok boolean +---@param success T +---@param failure E +---@return boolean +---@return T|E +---@return_overload true, T +---@return_overload false, E +local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure +end + +local cond ---@type boolean +local ok, result = pick(cond, 1, "error") + +if not ok then + error(result) -- result: string +end + +local value = result -- value: integer +``` + +`@return_overload` also supports generic and variadic tails: + +```lua +---@generic T, R +---@param f fun(...: T...): R... +---@param ... T... +---@return_overload true, R... +---@return_overload false, string +local function wrap(f, ...) end +``` + +You can keep `@return` as the broad declaration and add `@return_overload` rows for correlation-sensitive inference. + ## Features 1. **Multiple return value support** @@ -182,3 +232,4 @@ end 4. **Function return values** 5. **Async return values** 6. **Conditional return values** +7. **Correlated return row narrowing (`@return_overload`)**