@@ -330,6 +330,10 @@ def py_flatten(arr):
330330 f .empty ,
331331 lambda data : [len (r ) == 0 for r in data ],
332332 ),
333+ (
334+ f .list_empty ,
335+ lambda data : [len (r ) == 0 for r in data ],
336+ ),
333337 (
334338 lambda col : f .array_extract (col , literal (1 )),
335339 lambda data : [r [0 ] for r in data ],
@@ -354,18 +358,42 @@ def py_flatten(arr):
354358 lambda col : f .array_has (col , literal (1.0 )),
355359 lambda data : [1.0 in r for r in data ],
356360 ),
361+ (
362+ lambda col : f .list_has (col , literal (1.0 )),
363+ lambda data : [1.0 in r for r in data ],
364+ ),
365+ (
366+ lambda col : f .array_contains (col , literal (1.0 )),
367+ lambda data : [1.0 in r for r in data ],
368+ ),
369+ (
370+ lambda col : f .list_contains (col , literal (1.0 )),
371+ lambda data : [1.0 in r for r in data ],
372+ ),
357373 (
358374 lambda col : f .array_has_all (
359375 col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
360376 ),
361377 lambda data : [np .all ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
362378 ),
379+ (
380+ lambda col : f .list_has_all (
381+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
382+ ),
383+ lambda data : [np .all ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
384+ ),
363385 (
364386 lambda col : f .array_has_any (
365387 col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
366388 ),
367389 lambda data : [np .any ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
368390 ),
391+ (
392+ lambda col : f .list_has_any (
393+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
394+ ),
395+ lambda data : [np .any ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
396+ ),
369397 (
370398 lambda col : f .array_position (col , literal (1.0 )),
371399 lambda data : [py_indexof (r , 1.0 ) for r in data ],
@@ -418,10 +446,18 @@ def py_flatten(arr):
418446 f .array_pop_back ,
419447 lambda data : [arr [:- 1 ] for arr in data ],
420448 ),
449+ (
450+ f .list_pop_back ,
451+ lambda data : [arr [:- 1 ] for arr in data ],
452+ ),
421453 (
422454 f .array_pop_front ,
423455 lambda data : [arr [1 :] for arr in data ],
424456 ),
457+ (
458+ f .list_pop_front ,
459+ lambda data : [arr [1 :] for arr in data ],
460+ ),
425461 (
426462 lambda col : f .array_remove (col , literal (3.0 )),
427463 lambda data : [py_arr_remove (arr , 3.0 , 1 ) for arr in data ],
@@ -1565,58 +1601,3 @@ def test_gen_series_with_step():
15651601 f .gen_series (literal (1 ), literal (10 ), literal (3 )).alias ("v" )
15661602 ).collect ()
15671603 assert result [0 ].column (0 )[0 ].as_py () == [1 , 4 , 7 , 10 ]
1568-
1569-
1570- @pytest .mark .parametrize (
1571- ("func" , "element" , "expected" ),
1572- [
1573- (f .array_contains , literal (2 ), True ),
1574- (f .list_contains , literal (99 ), False ),
1575- (f .list_has , literal (2 ), True ),
1576- ],
1577- )
1578- def test_element_containment (func , element , expected ):
1579- ctx = SessionContext ()
1580- df = ctx .from_pydict ({"a" : [[1 , 2 , 3 ]]})
1581- result = df .select (func (column ("a" ), element ).alias ("v" )).collect ()
1582- assert result [0 ].column (0 )[0 ].as_py () is expected
1583-
1584-
1585- def test_list_has_all ():
1586- ctx = SessionContext ()
1587- df = ctx .from_pydict ({"a" : [[1 , 2 , 3 ]]})
1588- result = df .select (
1589- f .list_has_all (column ("a" ), f .make_array (literal (1 ), literal (2 ))).alias ("v" )
1590- ).collect ()
1591- assert result [0 ].column (0 )[0 ].as_py () is True
1592-
1593-
1594- def test_list_has_any ():
1595- ctx = SessionContext ()
1596- df = ctx .from_pydict ({"a" : [[1 , 2 , 3 ]]})
1597- result = df .select (
1598- f .list_has_any (column ("a" ), f .make_array (literal (5 ), literal (2 ))).alias ("v" )
1599- ).collect ()
1600- assert result [0 ].column (0 )[0 ].as_py () is True
1601-
1602-
1603- def test_list_empty ():
1604- ctx = SessionContext ()
1605- df = ctx .from_pydict ({"a" : [[], [1 , 2 ]]})
1606- result = df .select (f .list_empty (column ("a" )).alias ("v" )).collect ()
1607- values = [row .as_py () for row in result [0 ].column (0 )]
1608- assert values == [True , False ]
1609-
1610-
1611- @pytest .mark .parametrize (
1612- ("func" , "expected" ),
1613- [
1614- (f .list_pop_back , [1 , 2 ]),
1615- (f .list_pop_front , [2 , 3 ]),
1616- ],
1617- )
1618- def test_list_pop (func , expected ):
1619- ctx = SessionContext ()
1620- df = ctx .from_pydict ({"a" : [[1 , 2 , 3 ]]})
1621- result = df .select (func (column ("a" )).alias ("v" )).collect ()
1622- assert result [0 ].column (0 )[0 ].as_py () == expected
0 commit comments