Skip to content

Commit 8faae77

Browse files
committed
feat(code-analysis): support @return_overload flow narrowing
Add `---@return_overload` support end-to-end and use it to narrow correlated multi-return values in condition flow. - Parse/AST: - add `return_overload` token/syntax kinds - parse `---@return_overload <type>(, <type>)*` - add doc AST node and lexer/tag wiring - Analyzer/index: - handle `LuaDocTag::ReturnOverload` - store overload rows on `LuaSignature` - compute return type from overload rows (slot-wise union, variadic-aware) - Flow infra: - track `decl -> (call_expr, return_index)` mappings in binder/flow tree - clear stale mappings on reassignment - Narrowing: - add overload condition model (truthy/falsy/eq/neq) - narrow target vars using discriminant vars from the same call - support swapped equality operand order in binary condition flow - Stdlib/semantic tokens: - annotate `pcall` with `@return_overload` rows - highlight `TkTagReturnOverload` as a doc tag
1 parent 25fa503 commit 8faae77

22 files changed

Lines changed: 953 additions & 70 deletions

File tree

crates/emmylua_code_analysis/resources/std/global.lua

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,11 @@ function pairs(t) end
253253
--- boolean), which is true if the call succeeds without errors. In such case,
254254
--- `pcall` also returns all results from the call, after this first result. In
255255
--- case of any error, `pcall` returns **false** plus the error message.
256-
---@generic T, R, R1
257-
---@param f sync fun(...: T...): R1, R...
256+
---@generic T, R
257+
---@param f sync fun(...: T...): R...
258258
---@param ... T...
259-
---@return boolean, R1|string, R...
259+
---@return_overload true, R...
260+
---@return_overload false, string
260261
function pcall(f, ...) end
261262

262263
---

crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,33 @@ pub fn analyze_param(analyzer: &mut DocAnalyzer, tag: LuaDocTagParam) -> Option<
245245
}
246246

247247
pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Option<()> {
248+
let is_return_overload = tag
249+
.token_by_kind(LuaTokenKind::TkTagReturnOverload)
250+
.is_some();
248251
let description = tag
249252
.get_description()
250253
.map(|des| preprocess_description(&des.get_description_text(), None));
251254

252255
if let Some(closure) = find_owner_closure_or_report(analyzer, &tag) {
253256
let signature_id = LuaSignatureId::from_closure(analyzer.file_id, &closure);
257+
if is_return_overload {
258+
let overload_types = tag
259+
.get_types()
260+
.map(|doc_type| infer_type(analyzer, doc_type))
261+
.collect::<Vec<_>>();
262+
if overload_types.is_empty() {
263+
return Some(());
264+
}
265+
266+
let signature = analyzer
267+
.db
268+
.get_signature_index_mut()
269+
.get_or_create(signature_id);
270+
signature.return_overloads.push(overload_types);
271+
signature.resolve_return = SignatureReturnStatus::DocResolve;
272+
return Some(());
273+
}
274+
254275
let returns = tag.get_info_list();
255276
for (doc_type, name_token) in returns {
256277
let name = name_token.map(|name| name.get_name_text().to_string());

crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use emmylua_parser::{
55
};
66

77
use crate::{
8-
AnalyzeError, DiagnosticCode, FlowId, FlowNodeKind, LuaClosureId, LuaDeclId,
8+
AnalyzeError, DeclMultiReturnRef, DiagnosticCode, FlowId, FlowNodeKind, LuaClosureId,
9+
LuaDeclId,
910
compilation::analyzer::flow::{
1011
bind_analyze::{
1112
bind_block, bind_each_child, bind_node,
@@ -33,11 +34,19 @@ pub fn bind_local_stat(
3334
}
3435
}
3536

36-
for value in values {
37+
for value in &values {
3738
// If there are more values than names, we still need to bind the values
3839
bind_expr(binder, value.clone(), current);
3940
}
4041

42+
let decl_ids = local_names
43+
.iter()
44+
.map(|name| Some(LuaDeclId::new(binder.file_id, name.get_position())))
45+
.collect::<Vec<_>>();
46+
47+
// Track `local a, b = call()` slot ownership for later flow narrowing.
48+
bind_multi_return_refs(binder, &decl_ids, &values);
49+
4150
let local_flow_id = binder.create_decl(local_stat.get_position());
4251
binder.add_antecedent(local_flow_id, current);
4352
local_flow_id
@@ -88,13 +97,80 @@ pub fn bind_assign_stat(
8897
}
8998
}
9099

100+
let decl_ids = vars
101+
.iter()
102+
.map(|var| {
103+
binder
104+
.db
105+
.get_reference_index()
106+
.get_var_reference_decl(&binder.file_id, var.get_range())
107+
})
108+
.collect::<Vec<_>>();
109+
110+
// Track `a, b = call()` slot ownership for later flow narrowing.
111+
bind_multi_return_refs(binder, &decl_ids, &values);
112+
91113
let assignment_kind = FlowNodeKind::Assignment(assign_stat.to_ptr());
92114
let flow_id = binder.create_node(assignment_kind);
93115
binder.add_antecedent(flow_id, current);
94116

95117
flow_id
96118
}
97119

120+
/// Binds declaration IDs to call-return slots in assignment/local statements.
121+
///
122+
/// This lets condition flow recover which variable is the discriminant/result from the same call.
123+
fn bind_multi_return_refs(
124+
binder: &mut FlowBinder,
125+
decl_ids: &[Option<LuaDeclId>],
126+
values: &[LuaExpr],
127+
) {
128+
// Rebinding invalidates previous call-slot links.
129+
for decl_id in decl_ids.iter().flatten() {
130+
binder.decl_multi_return_ref.remove(&decl_id);
131+
}
132+
133+
if values.is_empty() {
134+
return;
135+
}
136+
137+
let min_len = decl_ids.len().min(values.len());
138+
for i in 0..min_len {
139+
let Some(decl_id) = decl_ids[i] else {
140+
continue;
141+
};
142+
let LuaExpr::CallExpr(call_expr) = &values[i] else {
143+
continue;
144+
};
145+
binder.decl_multi_return_ref.insert(
146+
decl_id,
147+
DeclMultiReturnRef {
148+
// Direct pair uses the first return slot.
149+
call_expr: call_expr.to_ptr(),
150+
return_index: 0,
151+
},
152+
);
153+
}
154+
155+
if decl_ids.len() > values.len()
156+
&& let Some(LuaExpr::CallExpr(call_expr)) = values.last()
157+
{
158+
for i in values.len()..decl_ids.len() {
159+
let Some(decl_id) = decl_ids[i] else {
160+
continue;
161+
};
162+
binder.decl_multi_return_ref.insert(
163+
decl_id,
164+
DeclMultiReturnRef {
165+
call_expr: call_expr.to_ptr(),
166+
// Remaining LHS names consume expanded trailing return slots.
167+
return_index: i - values.len() + 1,
168+
},
169+
);
170+
}
171+
}
172+
}
173+
98174
pub fn bind_call_expr_stat(
99175
binder: &mut FlowBinder,
100176
call_expr_stat: LuaCallExprStat,

crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ use rowan::TextSize;
66
use smol_str::SmolStr;
77

88
use crate::{
9-
AnalyzeError, DbIndex, FileId, FlowAntecedent, FlowId, FlowNode, FlowNodeKind, FlowTree,
10-
LuaClosureId, LuaDeclId,
9+
AnalyzeError, DbIndex, DeclMultiReturnRef, FileId, FlowAntecedent, FlowId, FlowNode,
10+
FlowNodeKind, FlowTree, LuaClosureId, LuaDeclId,
1111
};
1212

1313
#[derive(Debug)]
1414
pub struct FlowBinder<'a> {
1515
pub db: &'a mut DbIndex,
1616
pub file_id: FileId,
1717
pub decl_bind_expr_ref: HashMap<LuaDeclId, LuaAstPtr<LuaExpr>>,
18+
pub decl_multi_return_ref: HashMap<LuaDeclId, DeclMultiReturnRef>,
1819
pub start: FlowId,
1920
pub unreachable: FlowId,
2021
pub loop_label: FlowId,
@@ -36,6 +37,7 @@ impl<'a> FlowBinder<'a> {
3637
flow_nodes: Vec::new(),
3738
multiple_antecedents: Vec::new(),
3839
decl_bind_expr_ref: HashMap::new(),
40+
decl_multi_return_ref: HashMap::new(),
3941
labels: HashMap::new(),
4042
start: FlowId::default(),
4143
unreachable: FlowId::default(),
@@ -189,6 +191,7 @@ impl<'a> FlowBinder<'a> {
189191
pub fn finish(self) -> FlowTree {
190192
FlowTree::new(
191193
self.decl_bind_expr_ref,
194+
self.decl_multi_return_ref,
192195
self.flow_nodes,
193196
self.multiple_antecedents,
194197
// self.labels,

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,169 @@ _2 = a[1]
17101710
assert_eq!(e, e_expected);
17111711
}
17121712

1713+
#[test]
1714+
fn test_return_overload_narrow_after_not() {
1715+
// Boolean guard on discriminant should narrow correlated result slot.
1716+
let mut ws = VirtualWorkspace::new();
1717+
1718+
ws.def(
1719+
r#"
1720+
---@generic T, E
1721+
---@param ok boolean
1722+
---@param success T
1723+
---@param failure E
1724+
---@return boolean
1725+
---@return T|E
1726+
---@return_overload true, T
1727+
---@return_overload false, E
1728+
local function pick(ok, success, failure)
1729+
if ok then
1730+
return true, success
1731+
end
1732+
return false, failure
1733+
end
1734+
1735+
local cond ---@type boolean
1736+
local ok, result = pick(cond, 1, "error")
1737+
1738+
if not ok then
1739+
error(result)
1740+
end
1741+
1742+
a = result
1743+
"#,
1744+
);
1745+
1746+
let a = ws.expr_ty("a");
1747+
let expected = ws.ty("integer");
1748+
assert_eq!(a, expected);
1749+
}
1750+
1751+
#[test]
1752+
fn test_return_overload_narrow_with_swapped_operand_eq() {
1753+
// Equality narrowing should work when literal is on the left side.
1754+
let mut ws = VirtualWorkspace::new();
1755+
1756+
ws.def(
1757+
r#"
1758+
---@generic T, E
1759+
---@param ok boolean
1760+
---@param success T
1761+
---@param failure E
1762+
---@return "ok"|"err"
1763+
---@return T|E
1764+
---@return_overload "ok", T
1765+
---@return_overload "err", E
1766+
local function pick(ok, success, failure)
1767+
if ok then
1768+
return "ok", success
1769+
end
1770+
return "err", failure
1771+
end
1772+
1773+
local cond ---@type boolean
1774+
local tag, result = pick(cond, 1, "error")
1775+
1776+
if "err" == tag then
1777+
error(result)
1778+
end
1779+
1780+
d = result
1781+
"#,
1782+
);
1783+
1784+
let d = ws.expr_ty("d");
1785+
let expected = ws.ty("integer");
1786+
assert_eq!(d, expected);
1787+
}
1788+
1789+
#[test]
1790+
fn test_swapped_literal_eq_narrow_without_return_overload() {
1791+
// Baseline: swapped literal equality still narrows regular unions.
1792+
let mut ws = VirtualWorkspace::new();
1793+
1794+
assert!(!ws.check_code_for(
1795+
DiagnosticCode::ReturnTypeMismatch,
1796+
r#"
1797+
---@return "x"
1798+
local function test()
1799+
local a ---@type "x"|nil
1800+
if "x" == a then
1801+
return a
1802+
end
1803+
return "x"
1804+
end
1805+
"#,
1806+
));
1807+
}
1808+
1809+
#[test]
1810+
fn test_return_overload_reassign_clears_multi_return_mapping() {
1811+
// Reassignment should break call-slot correlation, preventing stale narrowing.
1812+
let mut ws = VirtualWorkspace::new();
1813+
1814+
ws.def(
1815+
r#"
1816+
---@generic T, E
1817+
---@param ok boolean
1818+
---@param success T
1819+
---@param failure E
1820+
---@return boolean
1821+
---@return T|E
1822+
---@return_overload true, T
1823+
---@return_overload false, E
1824+
local function pick(ok, success, failure)
1825+
if ok then
1826+
return true, success
1827+
end
1828+
return false, failure
1829+
end
1830+
1831+
local cond ---@type boolean
1832+
local random ---@type boolean
1833+
local ok, result = pick(cond, 1, "error")
1834+
result = random and 1 or "override"
1835+
1836+
if not ok then
1837+
error(result)
1838+
end
1839+
1840+
f = result
1841+
"#,
1842+
);
1843+
1844+
let f = ws.expr_ty("f");
1845+
let expected = ws.ty("integer|string");
1846+
assert_eq!(f, expected);
1847+
}
1848+
1849+
#[test]
1850+
fn test_pcall_return_overload_narrow_after_error_guard() {
1851+
// Stdlib `pcall` overload rows should narrow result after error branch exits.
1852+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
1853+
1854+
ws.def(
1855+
r#"
1856+
---@return integer
1857+
local function foo()
1858+
return 2
1859+
end
1860+
1861+
local ok, result = pcall(foo)
1862+
1863+
if not ok then
1864+
error(result)
1865+
end
1866+
1867+
a = result
1868+
"#,
1869+
);
1870+
1871+
let a = ws.expr_ty("a");
1872+
let expected = ws.ty("integer");
1873+
assert_eq!(a, expected);
1874+
}
1875+
17131876
#[test]
17141877
fn test_issue_868() {
17151878
let mut ws = VirtualWorkspace::new();

0 commit comments

Comments
 (0)