Skip to content

Commit 178cd3e

Browse files
feat(compiler): CPython parity for min/max key=, enumerate start=, next default, iter sentinel; #[repr(u8)] on OpCode
1 parent dec5d2d commit 178cd3e

7 files changed

Lines changed: 104 additions & 41 deletions

File tree

compiler/src/modules/parser/literals.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ impl<'src, I: Iterator<Item = Token>> Parser<'src, I> {
379379
return true;
380380
}
381381

382-
// min()/max() take a `default=` keyword, so keep positional and keyword counts distinct.
383-
if name == "min" || name == "max" {
384-
let op = if name == "min" { OpCode::CallMin } else { OpCode::CallMax };
382+
// min()/max() (`default=`/`key=`) and enumerate() (`start=`) take keywords, so keep positional and keyword counts distinct.
383+
if name == "min" || name == "max" || name == "enumerate" {
384+
let op = match name.as_str() { "min" => OpCode::CallMin, "max" => OpCode::CallMax, _ => OpCode::CallEnumerate };
385385
let (pos, kw) = self.parse_args();
386386
self.chunk.emit(op, ((kw & 0xFF) << 8) | (pos & 0xFF));
387387
self.chunk.record_call_pos(call_pos);

compiler/src/modules/parser/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) const MAX_EXPR_DEPTH: usize = 200;
88
pub(crate) const MAX_INSTRUCTIONS: usize = 65_535;
99

1010
#[derive(Debug, Clone, Copy, PartialEq)]
11+
#[repr(u8)] // <256 variants; guarantees a 1-byte tag for stable bytecode / transmute / jump-table dispatch
1112
pub enum OpCode {
1213
LoadConst, LoadName, StoreName, Call, PopTop, ReturnValue, BuildString, CallPrint, CallLen,
1314
FormatValue, CallAbs, Minus, CallStr, CallInt, CallRange, Phi, CallChr, CallType, MakeFunction,

compiler/src/modules/vm/builtins/numeric.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -215,18 +215,19 @@ impl<'a> VM<'a> {
215215
self.push(v); Ok(())
216216
}
217217

218-
pub fn call_min(&mut self, op: u16) -> Result<(), VmErr> { self.call_minmax(op, true) }
219-
pub fn call_max(&mut self, op: u16) -> Result<(), VmErr> { self.call_minmax(op, false) }
218+
pub fn call_min(&mut self, op: u16, chunk: &crate::modules::parser::SSAChunk, slots: &mut [Val]) -> Result<(), VmErr> { self.call_minmax(op, true, chunk, slots) }
219+
pub fn call_max(&mut self, op: u16, chunk: &crate::modules::parser::SSAChunk, slots: &mut [Val]) -> Result<(), VmErr> { self.call_minmax(op, false, chunk, slots) }
220220

221-
fn call_minmax(&mut self, op: u16, is_min: bool) -> Result<(), VmErr> {
221+
fn call_minmax(&mut self, op: u16, is_min: bool, chunk: &crate::modules::parser::SSAChunk, slots: &mut [Val]) -> Result<(), VmErr> {
222222
let (positional, kw_flat, _np, _nk) = self.parse_call_args(op)?;
223-
// Optional `default=` (returned when a single iterable is empty).
223+
// Optional `default=` (returned when a single iterable is empty) and `key=` (compare by key(x)).
224224
let mut default: Option<Val> = None;
225+
let mut key: Option<Val> = None;
225226
for pair in kw_flat.chunks_exact(2) {
226-
if matches!(self.heap.try_get(pair[0]), Some(HeapObj::Str(s)) if s == "default") {
227-
default = Some(pair[1]);
228-
} else {
229-
return Err(cold_type("min()/max() got an unexpected keyword argument"));
227+
match self.heap.try_get(pair[0]) {
228+
Some(HeapObj::Str(s)) if s == "default" => default = Some(pair[1]),
229+
Some(HeapObj::Str(s)) if s == "key" => { if !pair[1].is_none() { key = Some(pair[1]); } }
230+
_ => return Err(cold_type("min()/max() got an unexpected keyword argument")),
230231
}
231232
}
232233
// One arg iterable; many args are values.
@@ -235,11 +236,21 @@ impl<'a> VM<'a> {
235236
if items.is_empty() {
236237
return match default { Some(d) => { self.push(d); Ok(()) }, None => Err(cold_value(label)) };
237238
}
238-
let m = items[1..].iter().try_fold(items[0], |m, &x| {
239-
let (l, r) = if is_min { (x, m) } else { (m, x) };
240-
self.lt_vals(l, r).map(|lt| if lt { x } else { m })
241-
})?;
242-
self.push(m); Ok(())
239+
// Without a key, compare elements directly; with one, compare key(x) but return the winning element.
240+
let keys: Vec<Val> = match key {
241+
None => items.clone(),
242+
Some(k) => {
243+
let mut ks = Vec::with_capacity(items.len());
244+
for &x in &items { self.push(k); self.push(x); self.exec_call(1, chunk, slots)?; ks.push(self.pop()?); }
245+
ks
246+
}
247+
};
248+
let mut best = 0;
249+
for i in 1..items.len() {
250+
let (l, r) = if is_min { (keys[i], keys[best]) } else { (keys[best], keys[i]) };
251+
if self.lt_vals(l, r)? { best = i; }
252+
}
253+
self.push(items[best]); Ok(())
243254
}
244255

245256
pub fn call_sum(&mut self, op: u16) -> Result<(), VmErr> {

compiler/src/modules/vm/builtins/sequence.rs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,28 @@ impl<'a> VM<'a> {
221221
self.alloc_and_push_list(items)
222222
}
223223

224-
pub fn call_enumerate(&mut self) -> Result<(), VmErr> {
225-
let o = self.pop()?;
226-
let src = self.extract_iter(o, false)?;
224+
pub fn call_enumerate(&mut self, op: u16) -> Result<(), VmErr> {
225+
let (positional, kw_flat, _np, _nk) = self.parse_call_args(op)?;
226+
if positional.is_empty() || positional.len() > 2 {
227+
return Err(cold_type("enumerate() takes 1 or 2 positional arguments"));
228+
}
229+
// `start` is positional (`enumerate(xs, 5)`) or keyword (`enumerate(xs, start=5)`); default 0.
230+
let mut start = if positional.len() == 2 { positional[1] } else { Val::int(0) };
231+
for pair in kw_flat.chunks_exact(2) {
232+
match self.heap.try_get(pair[0]) {
233+
Some(HeapObj::Str(s)) if s == "start" => start = pair[1],
234+
_ => return Err(cold_type("enumerate() got an unexpected keyword argument")),
235+
}
236+
}
237+
let start = match self.as_i128(start) {
238+
Some(n) => n,
239+
None => return Err(cold_type("enumerate() start must be an integer")),
240+
};
241+
let src = self.extract_iter(positional[0], false)?;
227242
let mut pairs: Vec<Val> = Vec::with_capacity(src.len());
228243
for (i, x) in src.into_iter().enumerate() {
229-
let t = self.heap.alloc(HeapObj::Tuple(vec![Val::int(i as i64), x]))?;
244+
let idx = self.int_to_val(start.checked_add(i as i128))?;
245+
let t = self.heap.alloc(HeapObj::Tuple(vec![idx, x]))?;
230246
pairs.push(t);
231247
}
232248
self.alloc_and_push_list(pairs)
@@ -335,21 +351,43 @@ impl<'a> VM<'a> {
335351
self.extract_iter(o, true)
336352
}
337353

338-
/* `iter(x)`, eager flatten into a fresh List drained front-to-back by `next()`. Original isn't touched. Mirrors the universal ABI's `Op::Iter`. */
339-
pub fn call_iter(&mut self) -> Result<(), VmErr> {
354+
/* `iter(x)`, eager flatten into a fresh List drained front-to-back by `next()`. Original isn't touched. Mirrors the universal ABI's `Op::Iter`. The 2-arg form `iter(callable, sentinel)` calls `callable()` until it returns `sentinel`, eagerly. */
355+
pub fn call_iter(&mut self, argc: u16, chunk: &crate::modules::parser::SSAChunk, slots: &mut [Val]) -> Result<(), VmErr> {
356+
if argc == 2 {
357+
let sentinel = self.pop()?;
358+
let callable = self.pop()?;
359+
let mut items: Vec<Val> = Vec::new();
360+
loop {
361+
self.charge_step()?; // bound the call loop against the op budget
362+
self.push(callable);
363+
self.exec_call(0, chunk, slots)?;
364+
let v = self.pop()?;
365+
if eq_vals_with_heap(v, sentinel, &self.heap) { break; }
366+
if items.len() >= self.heap.limit() { return Err(cold_heap()); }
367+
items.push(v);
368+
}
369+
return self.alloc_and_push_list(items);
370+
}
371+
if argc != 1 { return Err(cold_type("iter() takes 1 or 2 arguments")); }
340372
let o = self.pop()?;
341373
let items = self.iter_to_vec_general(o)?;
342374
self.alloc_and_push_list(items)
343375
}
344376

345-
pub fn call_next(&mut self) -> Result<(), VmErr> {
377+
pub fn call_next(&mut self, argc: u16) -> Result<(), VmErr> {
378+
if argc == 0 || argc > 2 { return Err(cold_type("next() takes 1 or 2 arguments")); }
379+
// `next(it, default)`: the 2nd arg is returned instead of raising StopIteration on exhaustion.
380+
let default = if argc == 2 { Some(self.pop()?) } else { None };
346381
let o = self.pop()?;
347382
if !o.is_heap() { return Err(cold_type("next() requires an iterator")); }
348383
// List path mirrors the ABI's IterNext op so script `next()` and host `Op::IterNext` match.
349384
if let HeapObj::List(rc) = self.heap.get(o) {
350385
let rc = rc.clone();
351386
let mut v = rc.borrow_mut();
352-
if v.is_empty() { return Err(VmErr::Raised(s!("StopIteration"))); }
387+
if v.is_empty() {
388+
drop(v);
389+
return match default { Some(d) => { self.push(d); Ok(()) }, None => Err(VmErr::Raised(s!("StopIteration"))) };
390+
}
353391
let item = v.remove(0);
354392
drop(v);
355393
self.push(item);
@@ -365,7 +403,7 @@ impl<'a> VM<'a> {
365403
self.push(result);
366404
Ok(())
367405
} else {
368-
Err(VmErr::Runtime("StopIteration"))
406+
match default { Some(d) => { self.push(d); Ok(()) }, None => Err(VmErr::Runtime("StopIteration")) }
369407
}
370408
}
371409

compiler/src/modules/vm/handlers/function.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ impl<'a> VM<'a> {
4747
OpCode::CallSorted => self.call_sorted(false),
4848
OpCode::CallList => self.call_list(chunk, slots),
4949
OpCode::CallTuple => self.call_tuple(chunk, slots),
50-
OpCode::CallEnumerate => self.call_enumerate(),
50+
OpCode::CallEnumerate => self.call_enumerate(operand),
5151
OpCode::CallIsInstance => self.call_isinstance(),
5252
OpCode::CallRange => self.call_range(operand),
5353
OpCode::CallRound => self.call_round(operand),
54-
OpCode::CallMin => self.call_min(operand),
55-
OpCode::CallMax => self.call_max(operand),
54+
OpCode::CallMin => self.call_min(operand, chunk, slots),
55+
OpCode::CallMax => self.call_max(operand, chunk, slots),
5656
OpCode::CallSum => self.call_sum(operand),
5757
OpCode::CallZip => self.call_zip(operand),
5858
OpCode::CallDict => self.call_dict(operand),
@@ -667,9 +667,10 @@ impl<'a> VM<'a> {
667667
let expected: Option<u16> = match id {
668668
Input | Receive => Some(0),
669669
Len | Abs | Str | Int | Float | Bool | Type | Chr | Ord
670-
| Sorted | Enumerate | List | Tuple | Bin | Oct | Hex
671-
| Repr | Reversed | Callable | Id | Hash | Next | Sleep
672-
| Iter => Some(1),
670+
| Sorted | List | Tuple | Bin | Oct | Hex
671+
| Repr | Reversed | Callable | Id | Hash | Sleep => Some(1),
672+
// Enumerate (start), Next (default), Iter (sentinel) accept an optional 2nd arg; validated in their handlers.
673+
Enumerate | Next | Iter => None,
673674
Divmod | IsInstance | IsSubclass | HasAttr | Map | Filter | DelAttr => Some(2),
674675
SetAttr => Some(3),
675676
WithTimeout => Some(2),
@@ -702,8 +703,8 @@ impl<'a> VM<'a> {
702703
}
703704
Range => self.call_range(argc),
704705
Round => self.call_round(argc),
705-
Min => self.call_min(argc),
706-
Max => self.call_max(argc),
706+
Min => self.call_min(argc, chunk, slots),
707+
Max => self.call_max(argc, chunk, slots),
707708
Sum => self.call_sum(argc),
708709
Zip => self.call_zip(argc),
709710
Dict => self.call_dict(argc),
@@ -725,7 +726,7 @@ impl<'a> VM<'a> {
725726
Chr => self.call_chr(),
726727
Ord => self.call_ord(),
727728
Sorted => self.call_sorted_with_key(sort_key, sort_reverse, chunk, slots),
728-
Enumerate => self.call_enumerate(),
729+
Enumerate => self.call_enumerate(argc),
729730
List => self.call_list(chunk, slots),
730731
Tuple => self.call_tuple(chunk, slots),
731732
Bin => self.call_bin(),
@@ -740,14 +741,14 @@ impl<'a> VM<'a> {
740741
IsInstance => self.call_isinstance(),
741742
IsSubclass => self.call_issubclass(),
742743
HasAttr => self.call_hasattr(),
743-
Next => self.call_next(),
744+
Next => self.call_next(argc),
744745
Run => self.call_run(argc),
745746
Sleep => self.call_sleep(),
746747
Frame => self.call_frame(),
747748
Receive => self.call_receive(),
748749
Map => self.call_map(chunk, slots),
749750
Filter => self.call_filter(chunk, slots),
750-
Iter => self.call_iter(),
751+
Iter => self.call_iter(argc, chunk, slots),
751752
Bytes => self.call_bytes(argc),
752753
Slice => self.call_slice(argc),
753754
Vars => self.call_vars(),

compiler/tests/cases/vm.json

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2283,5 +2283,17 @@
22832283
{"src": "def add(xs):\n xs += [9]\nys = [1]\nadd(ys)\nprint(ys)", "output": ["[1, 9]"]},
22842284
{"src": "i = 5\ni += 3\nprint(i)", "output": ["8"]},
22852285
{"src": "s = \"a\"\nt = s\ns += \"b\"\nprint(s, t)", "output": ["ab a"]},
2286-
{"src": "acc = []\nfor i in range(3):\n acc += [i]\nprint(acc)", "output": ["[0, 1, 2]"]}
2286+
{"src": "acc = []\nfor i in range(3):\n acc += [i]\nprint(acc)", "output": ["[0, 1, 2]"]},
2287+
{"src": "print(min([1, -5, 3], key=abs))", "output": ["1"]},
2288+
{"src": "print(max([\"a\", \"bbb\", \"cc\"], key=len))", "output": ["bbb"]},
2289+
{"src": "print(min([3, 1, 2], key=lambda x: -x))", "output": ["3"]},
2290+
{"src": "print(max([1, 2, 3], default=0, key=lambda x: -x))", "output": ["1"]},
2291+
{"src": "print(max([], default=-1, key=abs))", "output": ["-1"]},
2292+
{"src": "print(list(enumerate([\"a\", \"b\"], 5)))", "output": ["[(5, 'a'), (6, 'b')]"]},
2293+
{"src": "print(list(enumerate([\"a\", \"b\"], start=10)))", "output": ["[(10, 'a'), (11, 'b')]"]},
2294+
{"src": "for i, c in enumerate(\"xy\", 1):\n print(i, c)", "output": ["1 x", "2 y"]},
2295+
{"src": "it = iter([1])\nprint(next(it))\nprint(next(it, \"done\"))", "output": ["1", "done"]},
2296+
{"src": "print(next(iter([]), -1))", "output": ["-1"]},
2297+
{"src": "it = iter([5, 6])\nprint(next(it, 0), next(it, 0), next(it, 0))", "output": ["5 6 0"]},
2298+
{"src": "def make():\n n = [0]\n def f():\n n[0] += 1\n return n[0]\n return f\ng = make()\nprint(list(iter(g, 4)))", "output": ["[1, 2, 3]"]}
22872299
]

docs/content/reference/builtins.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ print(round(1.55, 1))
8282

8383
### min, max
8484

85-
Variadic or single iterable. Accept a `default=` returned when a single iterable is empty; without it an empty input raises `ValueError`. No `key=` (transform inline). Ordering follows `<`: numbers, strings, bytes, and tuples/lists (lexicographic).
85+
Variadic or single iterable. Accept a `default=` returned when a single iterable is empty; without it an empty input raises `ValueError`. A `key=` function selects the comparison value (the original element is returned). Ordering follows `<`: numbers, strings, bytes, and tuples/lists (lexicographic).
8686

8787
```python
8888
print(min(3, 1, 4))
@@ -346,7 +346,7 @@ print(reversed("abc"))
346346

347347
### enumerate
348348

349-
Pairs each element with its index -> list of `(i, value)` tuples. No `start=`, add the offset yourself.
349+
Pairs each element with its index -> list of `(i, value)` tuples. A second argument (positional or `start=`) sets the first index.
350350

351351
```python
352352
for i, v in enumerate(["a", "b", "c"]):
@@ -379,7 +379,7 @@ print(list(zip([1, 2], [3, 4], [5, 6])))
379379

380380
### next
381381

382-
`next(iterator)` -> next item. Exhausted -> `StopIteration`. Two-arg `next(it, default)` not supported.
382+
`next(iterator)` -> next item. Exhausted -> `StopIteration`. Two-arg `next(it, default)` returns `default` instead of raising on exhaustion.
383383

384384
```python
385385
it = iter([10, 20, 30])
@@ -396,7 +396,7 @@ print(next(it))
396396

397397
### iter
398398

399-
`iter(x)` returns a fresh iterator over any iterable (list, tuple, set, dict, range, str, bytes, frozenset). Materialises a snapshot, original never mutated. Two-arg `iter(callable, sentinel)` not supported.
399+
`iter(x)` returns a fresh iterator over any iterable (list, tuple, set, dict, range, str, bytes, frozenset). Materialises a snapshot, original never mutated. Two-arg `iter(callable, sentinel)` calls `callable()` until it returns `sentinel`, eagerly.
400400

401401
```python
402402
it = iter([1, 2, 3])

0 commit comments

Comments
 (0)