Skip to content

Commit 6165aad

Browse files
authored
Merge pull request RustPython#3695 from rebunto/optional-key-argument-for-_bisect-functions
Optional key argument for bisect functions
2 parents a019cbf + cc83db3 commit 6165aad

File tree

3 files changed

+144
-44
lines changed

3 files changed

+144
-44
lines changed

Lib/bisect.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
"""Bisection algorithms."""
22

3-
def insort_right(a, x, lo=0, hi=None):
3+
4+
def insort_right(a, x, lo=0, hi=None, *, key=None):
45
"""Insert item x in list a, and keep it sorted assuming a is sorted.
56
67
If x is already in a, insert it to the right of the rightmost x.
78
89
Optional args lo (default 0) and hi (default len(a)) bound the
910
slice of a to be searched.
1011
"""
11-
12-
if lo < 0:
13-
raise ValueError('lo must be non-negative')
14-
if hi is None:
15-
hi = len(a)
16-
while lo < hi:
17-
mid = (lo+hi)//2
18-
if x < a[mid]: hi = mid
19-
else: lo = mid+1
12+
if key is None:
13+
lo = bisect_right(a, x, lo, hi)
14+
else:
15+
lo = bisect_right(a, key(x), lo, hi, key=key)
2016
a.insert(lo, x)
2117

22-
def bisect_right(a, x, lo=0, hi=None):
18+
19+
def bisect_right(a, x, lo=0, hi=None, *, key=None):
2320
"""Return the index where to insert item x in list a, assuming a is sorted.
2421
2522
The return value i is such that all e in a[:i] have e <= x, and all e in
26-
a[i:] have e > x. So if x already appears in the list, a.insert(x) will
23+
a[i:] have e > x. So if x already appears in the list, a.insert(i, x) will
2724
insert just after the rightmost x already there.
2825
2926
Optional args lo (default 0) and hi (default len(a)) bound the
@@ -34,13 +31,26 @@ def bisect_right(a, x, lo=0, hi=None):
3431
raise ValueError('lo must be non-negative')
3532
if hi is None:
3633
hi = len(a)
37-
while lo < hi:
38-
mid = (lo+hi)//2
39-
if x < a[mid]: hi = mid
40-
else: lo = mid+1
34+
# Note, the comparison uses "<" to match the
35+
# __lt__() logic in list.sort() and in heapq.
36+
if key is None:
37+
while lo < hi:
38+
mid = (lo + hi) // 2
39+
if x < a[mid]:
40+
hi = mid
41+
else:
42+
lo = mid + 1
43+
else:
44+
while lo < hi:
45+
mid = (lo + hi) // 2
46+
if x < key(a[mid]):
47+
hi = mid
48+
else:
49+
lo = mid + 1
4150
return lo
4251

43-
def insort_left(a, x, lo=0, hi=None):
52+
53+
def insort_left(a, x, lo=0, hi=None, *, key=None):
4454
"""Insert item x in list a, and keep it sorted assuming a is sorted.
4555
4656
If x is already in a, insert it to the left of the leftmost x.
@@ -49,22 +59,17 @@ def insort_left(a, x, lo=0, hi=None):
4959
slice of a to be searched.
5060
"""
5161

52-
if lo < 0:
53-
raise ValueError('lo must be non-negative')
54-
if hi is None:
55-
hi = len(a)
56-
while lo < hi:
57-
mid = (lo+hi)//2
58-
if a[mid] < x: lo = mid+1
59-
else: hi = mid
62+
if key is None:
63+
lo = bisect_left(a, x, lo, hi)
64+
else:
65+
lo = bisect_left(a, key(x), lo, hi, key=key)
6066
a.insert(lo, x)
6167

62-
63-
def bisect_left(a, x, lo=0, hi=None):
68+
def bisect_left(a, x, lo=0, hi=None, *, key=None):
6469
"""Return the index where to insert item x in list a, assuming a is sorted.
6570
6671
The return value i is such that all e in a[:i] have e < x, and all e in
67-
a[i:] have e >= x. So if x already appears in the list, a.insert(x) will
72+
a[i:] have e >= x. So if x already appears in the list, a.insert(i, x) will
6873
insert just before the leftmost x already there.
6974
7075
Optional args lo (default 0) and hi (default len(a)) bound the
@@ -75,17 +80,31 @@ def bisect_left(a, x, lo=0, hi=None):
7580
raise ValueError('lo must be non-negative')
7681
if hi is None:
7782
hi = len(a)
78-
while lo < hi:
79-
mid = (lo+hi)//2
80-
if a[mid] < x: lo = mid+1
81-
else: hi = mid
83+
# Note, the comparison uses "<" to match the
84+
# __lt__() logic in list.sort() and in heapq.
85+
if key is None:
86+
while lo < hi:
87+
mid = (lo + hi) // 2
88+
if a[mid] < x:
89+
lo = mid + 1
90+
else:
91+
hi = mid
92+
else:
93+
while lo < hi:
94+
mid = (lo + hi) // 2
95+
if key(a[mid]) < x:
96+
lo = mid + 1
97+
else:
98+
hi = mid
8299
return lo
83100

84-
# Overwrite above definitions with a fast Rust implementation
101+
102+
# Overwrite above definitions with a fast C implementation
85103
try:
86104
from _bisect import *
87105
except ImportError:
88106
pass
89107

90-
bisect = bisect_right # backward compatibility
91-
insort = insort_right # backward compatibility
108+
# Create aliases
109+
bisect = bisect_right
110+
insort = insort_right

Lib/test/test_bisect.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import sys
22
import unittest
3-
from test import support
43
from test.support import import_helper
54
from collections import UserList
65

6+
77
py_bisect = import_helper.import_fresh_module('bisect', blocked=['_bisect'])
88
c_bisect = import_helper.import_fresh_module('bisect', fresh=['bisect'])
99

@@ -200,6 +200,63 @@ def test_keyword_args(self):
200200
self.module.insort(a=data, x=25, lo=1, hi=3)
201201
self.assertEqual(data, [10, 20, 25, 25, 25, 30, 40, 50])
202202

203+
def test_lookups_with_key_function(self):
204+
mod = self.module
205+
206+
# Invariant: Index with a keyfunc on an array
207+
# should match the index on an array where
208+
# key function has already been applied.
209+
210+
keyfunc = abs
211+
arr = sorted([2, -4, 6, 8, -10], key=keyfunc)
212+
precomputed_arr = list(map(keyfunc, arr))
213+
for x in precomputed_arr:
214+
self.assertEqual(
215+
mod.bisect_left(arr, x, key=keyfunc),
216+
mod.bisect_left(precomputed_arr, x)
217+
)
218+
self.assertEqual(
219+
mod.bisect_right(arr, x, key=keyfunc),
220+
mod.bisect_right(precomputed_arr, x)
221+
)
222+
223+
keyfunc = str.casefold
224+
arr = sorted('aBcDeEfgHhiIiij', key=keyfunc)
225+
precomputed_arr = list(map(keyfunc, arr))
226+
for x in precomputed_arr:
227+
self.assertEqual(
228+
mod.bisect_left(arr, x, key=keyfunc),
229+
mod.bisect_left(precomputed_arr, x)
230+
)
231+
self.assertEqual(
232+
mod.bisect_right(arr, x, key=keyfunc),
233+
mod.bisect_right(precomputed_arr, x)
234+
)
235+
236+
def test_insort(self):
237+
from random import shuffle
238+
mod = self.module
239+
240+
# Invariant: As random elements are inserted in
241+
# a target list, the targetlist remains sorted.
242+
keyfunc = abs
243+
data = list(range(-10, 11)) + list(range(-20, 20, 2))
244+
shuffle(data)
245+
target = []
246+
for x in data:
247+
mod.insort_left(target, x, key=keyfunc)
248+
self.assertEqual(
249+
sorted(target, key=keyfunc),
250+
target
251+
)
252+
target = []
253+
for x in data:
254+
mod.insort_right(target, x, key=keyfunc)
255+
self.assertEqual(
256+
sorted(target, key=keyfunc),
257+
target
258+
)
259+
203260
class TestBisectPython(TestBisect, unittest.TestCase):
204261
module = py_bisect
205262

stdlib/src/bisect.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ mod _bisect {
1414
lo: OptionalArg<PyObjectRef>,
1515
#[pyarg(any, optional)]
1616
hi: OptionalArg<PyObjectRef>,
17+
#[pyarg(named, default)]
18+
key: Option<PyObjectRef>,
1719
}
1820

1921
// Handles objects that implement __index__ and makes sure index fits in needed isize.
@@ -66,17 +68,21 @@ mod _bisect {
6668
#[inline]
6769
#[pyfunction]
6870
fn bisect_left(
69-
BisectArgs { a, x, lo, hi }: BisectArgs,
71+
BisectArgs { a, x, lo, hi, key }: BisectArgs,
7072
vm: &VirtualMachine,
7173
) -> PyResult<usize> {
7274
let (mut lo, mut hi) = as_usize(lo, hi, a.length(vm)?, vm)?;
7375

7476
while lo < hi {
7577
// Handles issue 13496.
7678
let mid = (lo + hi) / 2;
77-
if a.get_item(&mid, vm)?
78-
.rich_compare_bool(&x, PyComparisonOp::Lt, vm)?
79-
{
79+
let a_mid = a.get_item(&mid, vm)?;
80+
let comp = if let Some(ref key) = key {
81+
vm.invoke(key, (a_mid,))?
82+
} else {
83+
a_mid
84+
};
85+
if comp.rich_compare_bool(&x, PyComparisonOp::Lt, vm)? {
8086
lo = mid + 1;
8187
} else {
8288
hi = mid;
@@ -96,15 +102,21 @@ mod _bisect {
96102
#[inline]
97103
#[pyfunction]
98104
fn bisect_right(
99-
BisectArgs { a, x, lo, hi }: BisectArgs,
105+
BisectArgs { a, x, lo, hi, key }: BisectArgs,
100106
vm: &VirtualMachine,
101107
) -> PyResult<usize> {
102108
let (mut lo, mut hi) = as_usize(lo, hi, a.length(vm)?, vm)?;
103109

104110
while lo < hi {
105111
// Handles issue 13496.
106112
let mid = (lo + hi) / 2;
107-
if x.rich_compare_bool(&*a.get_item(&mid, vm)?, PyComparisonOp::Lt, vm)? {
113+
let a_mid = a.get_item(&mid, vm)?;
114+
let comp = if let Some(ref key) = key {
115+
vm.invoke(key, (a_mid,))?
116+
} else {
117+
a_mid
118+
};
119+
if x.rich_compare_bool(&*comp, PyComparisonOp::Lt, vm)? {
108120
hi = mid;
109121
} else {
110122
lo = mid + 1;
@@ -120,13 +132,19 @@ mod _bisect {
120132
/// Optional args lo (default 0) and hi (default len(a)) bound the
121133
/// slice of a to be searched.
122134
#[pyfunction]
123-
fn insort_left(BisectArgs { a, x, lo, hi }: BisectArgs, vm: &VirtualMachine) -> PyResult {
135+
fn insort_left(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
136+
let x = if let Some(ref key) = key {
137+
vm.invoke(key, (x,))?
138+
} else {
139+
x
140+
};
124141
let index = bisect_left(
125142
BisectArgs {
126143
a: a.clone(),
127144
x: x.clone(),
128145
lo,
129146
hi,
147+
key,
130148
},
131149
vm,
132150
)?;
@@ -140,13 +158,19 @@ mod _bisect {
140158
/// Optional args lo (default 0) and hi (default len(a)) bound the
141159
/// slice of a to be searched
142160
#[pyfunction]
143-
fn insort_right(BisectArgs { a, x, lo, hi }: BisectArgs, vm: &VirtualMachine) -> PyResult {
161+
fn insort_right(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
162+
let x = if let Some(ref key) = key {
163+
vm.invoke(key, (x,))?
164+
} else {
165+
x
166+
};
144167
let index = bisect_right(
145168
BisectArgs {
146169
a: a.clone(),
147170
x: x.clone(),
148171
lo,
149172
hi,
173+
key,
150174
},
151175
vm,
152176
)?;

0 commit comments

Comments
 (0)