Skip to content

Commit 4f38cb6

Browse files
authored
Merge pull request RustPython#4449 from harupy/fix-dict-spread-in-dict
Fix AST generated from a dict literal containing dict unpacking
2 parents 62aa942 + 88e3c83 commit 4f38cb6

File tree

8 files changed

+162
-56
lines changed

8 files changed

+162
-56
lines changed

compiler/ast/asdl_rs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,20 +227,25 @@ def visitConstructor(self, cons, parent, depth):
227227
if cons.fields:
228228
self.emit(f"{cons.name} {{", depth)
229229
for f in cons.fields:
230-
self.visit(f, parent, "", depth + 1)
230+
self.visit(f, parent, "", depth + 1, cons.name)
231231
self.emit("},", depth)
232232
else:
233233
self.emit(f"{cons.name},", depth)
234234

235-
def visitField(self, field, parent, vis, depth):
235+
def visitField(self, field, parent, vis, depth, constructor=None):
236236
typ = get_rust_type(field.type)
237237
fieldtype = self.typeinfo.get(field.type)
238238
if fieldtype and fieldtype.has_userdata:
239239
typ = f"{typ}<U>"
240240
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
241241
if fieldtype and fieldtype.boxed and (not (parent.product or field.seq) or field.opt):
242242
typ = f"Box<{typ}>"
243-
if field.opt:
243+
if field.opt or (
244+
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
245+
# the expression to be unpacked goes in `values` with a `None` at the corresponding
246+
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
247+
constructor == "Dict" and field.name == "keys"
248+
):
244249
typ = f"Option<{typ}>"
245250
if field.seq:
246251
typ = f"Vec<{typ}>"

compiler/ast/src/ast_gen.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ pub enum ExprKind<U = ()> {
195195
orelse: Box<Expr<U>>,
196196
},
197197
Dict {
198-
keys: Vec<Expr<U>>,
198+
keys: Vec<Option<Expr<U>>>,
199199
values: Vec<Expr<U>>,
200200
},
201201
Set {

compiler/ast/src/unparse.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ impl<'a> Unparser<'a> {
152152
let (packed, unpacked) = values.split_at(keys.len());
153153
for (k, v) in keys.iter().zip(packed) {
154154
self.p_delim(&mut first, ", ")?;
155-
write!(self, "{}: {}", *k, *v)?;
155+
if let Some(k) = k {
156+
write!(self, "{}: {}", *k, *v)?;
157+
} else {
158+
write!(self, "**{}", *v)?;
159+
}
156160
}
157161
for d in unpacked {
158162
self.p_delim(&mut first, ", ")?;

compiler/codegen/src/compile.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,18 +1986,24 @@ impl Compiler {
19861986
Ok(())
19871987
}
19881988

1989-
fn compile_dict(&mut self, keys: &[ast::Expr], values: &[ast::Expr]) -> CompileResult<()> {
1989+
fn compile_dict(
1990+
&mut self,
1991+
keys: &[Option<ast::Expr>],
1992+
values: &[ast::Expr],
1993+
) -> CompileResult<()> {
19901994
let mut size = 0;
1991-
1992-
let (packed_values, unpacked_values) = values.split_at(keys.len());
1993-
for (key, value) in keys.iter().zip(packed_values) {
1994-
self.compile_expression(key)?;
1995+
let (packed, unpacked): (Vec<_>, Vec<_>) = keys
1996+
.iter()
1997+
.zip(values.iter())
1998+
.partition(|(k, _)| k.is_some());
1999+
for (key, value) in packed {
2000+
self.compile_expression(key.as_ref().unwrap())?;
19952001
self.compile_expression(value)?;
19962002
size += 1;
19972003
}
19982004
emit!(self, Instruction::BuildMap { size });
19992005

2000-
for value in unpacked_values {
2006+
for (_, value) in unpacked {
20012007
self.compile_expression(value)?;
20022008
emit!(self, Instruction::DictUpdate);
20032009
}

compiler/codegen/src/symboltable.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -885,13 +885,10 @@ impl SymbolTableBuilder {
885885
self.scan_expression(value, ExpressionContext::Load)?;
886886
}
887887
Dict { keys, values } => {
888-
let (packed, unpacked) = values.split_at(keys.len());
889-
for (key, value) in keys.iter().zip(packed) {
890-
self.scan_expression(key, context)?;
891-
self.scan_expression(value, context)?;
892-
}
893-
for value in unpacked {
894-
// dict unpacking marker
888+
for (key, value) in keys.iter().zip(values.iter()) {
889+
if let Some(key) = key {
890+
self.scan_expression(key, context)?;
891+
}
895892
self.scan_expression(value, context)?;
896893
}
897894
}

compiler/parser/python.lalrpop

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,44 +1136,11 @@ Atom<Goal>: ast::Expr = {
11361136
}.into())
11371137
},
11381138
<location:@L> "{" <e:DictLiteralValues?> "}" <end_location:@R> => {
1139-
let pairs = e.unwrap_or_default();
1140-
1141-
let (keys, values) = match pairs.iter().position(|(k,_)| k.is_none()) {
1142-
Some(unpack_idx) => {
1143-
let mut pairs = pairs;
1144-
let (keys, mut values): (_, Vec<_>) = pairs.drain(..unpack_idx).map(|(k, v)| (*k.unwrap(), v)).unzip();
1145-
1146-
fn build_map(items: &mut Vec<(ast::Expr, ast::Expr)>) -> ast::Expr {
1147-
let location = items[0].0.location;
1148-
let end_location = items[0].0.end_location;
1149-
let (keys, values) = items.drain(..).unzip();
1150-
ast::Expr {
1151-
location,
1152-
end_location,
1153-
custom: (),
1154-
node: ast::ExprKind::Dict { keys, values }
1155-
}
1156-
}
1157-
1158-
let mut items = Vec::new();
1159-
for (key, value) in pairs.into_iter() {
1160-
if let Some(key) = key {
1161-
items.push((*key, value));
1162-
continue;
1163-
}
1164-
if !items.is_empty() {
1165-
values.push(build_map(&mut items));
1166-
}
1167-
values.push(value);
1168-
}
1169-
if !items.is_empty() {
1170-
values.push(build_map(&mut items));
1171-
}
1172-
(keys, values)
1173-
},
1174-
None => pairs.into_iter().map(|(k, v)| (*k.unwrap(), v)).unzip()
1175-
};
1176-
1139+
let (keys, values) = e
1140+
.unwrap_or_default()
1141+
.into_iter()
1142+
.map(|(k, v)| (k.map(|x| *x), v))
1143+
.unzip();
11771144
ast::Expr {
11781145
location,
11791146
end_location: Some(end_location),

compiler/parser/src/parser.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,4 +309,10 @@ with (0 as a, 1 as b,): pass
309309
assert!(parse_program(source, "<test>").is_err());
310310
}
311311
}
312+
313+
#[test]
314+
fn test_dict_unpacking() {
315+
let parse_ast = parse_expression(r#"{"a": "b", **c, "d": "e"}"#, "<test>").unwrap();
316+
insta::assert_debug_snapshot!(parse_ast);
317+
}
312318
}

compiler/parser/src/snapshots/rustpython_parser__parser__tests__dict_unpacking.snap

Lines changed: 121 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)