Skip to content

Commit edfb9a8

Browse files
committed
Implement a pyimport string macro
This helps usability by automatically converting Python import statements into their Julia equivalents.
1 parent 4adeb0a commit edfb9a8

File tree

6 files changed

+215
-0
lines changed

6 files changed

+215
-0
lines changed

docs/src/pythoncall-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pyhasitem
5555
pyhash
5656
pyhelp
5757
pyimport
58+
@pyimport_str
5859
pyin
5960
pyis
6061
pyisinstance

src/API/exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export pyilshift
5050
export pyimatmul
5151
export pyimod
5252
export pyimport
53+
export @pyimport_str
5354
export pyimul
5455
export pyin
5556
export pyindex

src/API/macros.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
macro pyconst end
33
macro pyeval end
44
macro pyexec end
5+
macro pyimport_str end
56

67
# Convert
78
macro pyconvert end

src/Core/Core.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import ..PythonCall:
3434
@pyconst,
3535
@pyeval,
3636
@pyexec,
37+
@pyimport_str,
3738
getptr,
3839
ispy,
3940
Py,

src/Core/builtins.jl

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,13 +1454,151 @@ end
14541454
Import a module `m`, or an attribute `k`, or a tuple of attributes.
14551455
14561456
If several arguments are given, return the results of importing each one in a tuple.
1457+
1458+
See also: [`@pyimport_str`](@ref).
14571459
"""
14581460
pyimport(m) = pynew(errcheck(@autopy m C.PyImport_Import(m_)))
14591461
pyimport((m, k)::Pair) = (m_ = pyimport(m); k_ = pygetattr(m_, k); pydel!(m_); k_)
14601462
pyimport((m, ks)::Pair{<:Any,<:Tuple}) =
14611463
(m_ = pyimport(m); ks_ = map(k -> pygetattr(m_, k), ks); pydel!(m_); ks_)
14621464
pyimport(m1, m2, ms...) = map(pyimport, (m1, m2, ms...))
14631465

1466+
"""
1467+
pyimport"import numpy"
1468+
pyimport"import numpy as np"
1469+
pyimport"from numpy import array"
1470+
pyimport"from numpy import array as arr"
1471+
pyimport"from numpy import array, zeros"
1472+
pyimport"from numpy import array as arr, zeros as z"
1473+
pyimport"import numpy, scipy"
1474+
1475+
String macro that parses Python import syntax and generates equivalent Julia
1476+
code using [`pyimport()`](@ref). Each form generates `const` bindings in the
1477+
caller's scope.
1478+
1479+
Multiple lines are supported:
1480+
```julia
1481+
pyimport\"\"\"
1482+
import numpy as np
1483+
from scipy import linalg, optimize
1484+
from os.path import join as pathjoin
1485+
\"\"\"
1486+
1487+
# Converted to:
1488+
const np = pyimport("numpy")
1489+
const linalg = pyimport("scipy" => "lingalg")
1490+
const optimize = pyimport("scipy" => "optimize")
1491+
const pathjoin = pyimport("os.path" => "join")
1492+
```
1493+
1494+
But multiline or grouped import statements are not supported:
1495+
```julia
1496+
# These will throw an error
1497+
pyimport\"\"\"
1498+
from sys import (path,
1499+
version)
1500+
from sys import path, \
1501+
version
1502+
\"\"\"
1503+
```
1504+
"""
1505+
macro pyimport_str(s)
1506+
esc(_pyimport_parse(s))
1507+
end
1508+
1509+
function _pyimport_parse(s::AbstractString)
1510+
lines = filter(!isempty, strip.(split(s, "\n")))
1511+
isempty(lines) && throw(ArgumentError("pyimport: empty import string"))
1512+
1513+
# Check for line continuations
1514+
for line in lines
1515+
if contains(line, '\\') || contains(line, '(')
1516+
throw(ArgumentError("pyimport: line continuation with '\\' or '(' is not supported: $line"))
1517+
end
1518+
end
1519+
1520+
if length(lines) == 1
1521+
_pyimport_parse_line(lines[1])
1522+
else
1523+
Expr(:block, [_pyimport_parse_line(line) for line in lines]...)
1524+
end
1525+
end
1526+
1527+
function _pyimport_parse_line(s::AbstractString)
1528+
if startswith(s, "from ")
1529+
_pyimport_parse_from(s)
1530+
elseif startswith(s, "import ")
1531+
_pyimport_parse_import(s)
1532+
else
1533+
throw(ArgumentError("pyimport: expected 'import ...' or 'from ... import ...', got: $s"))
1534+
end
1535+
end
1536+
1537+
function _pyimport_parse_import(s::AbstractString)
1538+
rest = strip(chopprefix(s, "import"))
1539+
if isempty(rest)
1540+
throw(ArgumentError("pyimport: missing module name after 'import'"))
1541+
end
1542+
1543+
parts = split(rest, ","; keepempty=false)
1544+
exprs = Expr[]
1545+
for part in parts
1546+
part = strip(part)
1547+
1548+
# Check if there's an `as` clause
1549+
m = match(r"^(\S+)\s+as\s+(\S+)$", part)
1550+
if !isnothing(m)
1551+
# `import numpy.linalg as la` binds la to the linalg submodule
1552+
modname = m[1]
1553+
alias = Symbol(m[2])
1554+
push!(exprs, :(const $alias = pyimport($modname)))
1555+
else
1556+
modname = part
1557+
# `import numpy.linalg` binds numpy (top-level package)
1558+
# but first imports the submodule to ensure it's loaded
1559+
dotparts = split(modname, ".")
1560+
alias = Symbol(dotparts[1])
1561+
if length(dotparts) == 1
1562+
push!(exprs, :(const $alias = pyimport($modname)))
1563+
else
1564+
toplevel = dotparts[1]
1565+
push!(exprs, :(const $alias = (pyimport($modname); pyimport($toplevel))))
1566+
end
1567+
end
1568+
end
1569+
1570+
length(exprs) == 1 ? exprs[1] : Expr(:block, exprs...)
1571+
end
1572+
1573+
function _pyimport_parse_from(s::AbstractString)
1574+
m = match(r"^from\s+(\S+)\s+import\s+(.+)$", s)
1575+
if isnothing(m)
1576+
throw(ArgumentError("pyimport: invalid from-import syntax: $s"))
1577+
end
1578+
1579+
modname = m[1]
1580+
rest = strip(m[2])
1581+
if rest == "*"
1582+
throw(ArgumentError("pyimport: wildcard import 'from $modname import *' is not supported"))
1583+
end
1584+
1585+
parts = split(rest, ","; keepempty=false)
1586+
exprs = Expr[]
1587+
for part in parts
1588+
part = strip(part)
1589+
m2 = match(r"^(\S+)\s+as\s+(\S+)$", part)
1590+
name, alias = if !isnothing(m2)
1591+
m2[1], Symbol(m2[2])
1592+
else
1593+
part, Symbol(part)
1594+
end
1595+
1596+
push!(exprs, :(const $alias = pyimport($modname => $name)))
1597+
end
1598+
1599+
length(exprs) == 1 ? exprs[1] : Expr(:block, exprs...)
1600+
end
1601+
14641602
### builtins not covered elsewhere
14651603

14661604
"""

test/Core.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,79 @@ end
428428
@test pyis(verpath[2], path)
429429
end
430430

431+
@testset "pyimport_str" begin
432+
parse = PythonCall.Core._pyimport_parse
433+
434+
# Import a module
435+
@test parse("import sys") == :(const sys = pyimport("sys"))
436+
437+
# Import module as an alias
438+
@test parse("import sys as system") == :(const system = pyimport("sys"))
439+
440+
# Importing a submodule binds the top-level package
441+
ex = parse("import os.path") |> Base.remove_linenums!
442+
@test ex == :(const os = begin
443+
pyimport("os.path")
444+
pyimport("os")
445+
end) |> Base.remove_linenums!
446+
447+
# Importing a submodule as an alias binds the submodule
448+
@test parse("import os.path as osp") == :(const osp = pyimport("os.path"))
449+
450+
# import multiple modules
451+
ex = parse("import sys, os") |> Base.remove_linenums!
452+
@test ex == quote
453+
const sys = pyimport("sys")
454+
const os = pyimport("os")
455+
end |> Base.remove_linenums!
456+
457+
# from module import name
458+
@test parse("from sys import path") == :(const path = pyimport("sys" => "path"))
459+
460+
# from module import name as alias
461+
@test parse("from sys import path as p") == :(const p = pyimport("sys" => "path"))
462+
463+
# from module import multiple names
464+
ex = parse("from sys import path, version") |> Base.remove_linenums!
465+
@test ex == quote
466+
const path = pyimport("sys" => "path")
467+
const version = pyimport("sys" => "version")
468+
end |> Base.remove_linenums!
469+
470+
# from module import multiple names with aliases
471+
ex = parse("from sys import path as p, version as v") |> Base.remove_linenums!
472+
@test ex == quote
473+
const p = pyimport("sys" => "path")
474+
const v = pyimport("sys" => "version")
475+
end |> Base.remove_linenums!
476+
477+
# from dotted module import name
478+
@test parse("from os.path import join") == :(const join = pyimport("os.path" => "join"))
479+
480+
# Multiple lines, with extra whitespace
481+
ex = parse("import sys \n from os import getcwd ") |> Base.remove_linenums!
482+
@test ex == quote
483+
const sys = pyimport("sys")
484+
const getcwd = pyimport("os" => "getcwd")
485+
end |> Base.remove_linenums!
486+
487+
# Error cases
488+
@test_throws ArgumentError parse("")
489+
@test_throws ArgumentError parse("not an import")
490+
@test_throws ArgumentError parse("from os import *")
491+
@test_throws ArgumentError parse("import")
492+
493+
# Line continuations are not supported
494+
@test_throws ArgumentError parse("from os import \\\n path, getcwd")
495+
@test_throws ArgumentError parse("from os import (\n path, getcwd\n)")
496+
497+
# smoke test: actually run the macro
498+
m = Module()
499+
@eval m using PythonCall
500+
@eval m pyimport"import sys"
501+
@test pyeq(Bool, m.sys.__name__, "sys")
502+
end
503+
431504
@testitem "consts" begin
432505
@test pybuiltins.None isa Py
433506
@test pystr(String, pybuiltins.None) == "None"

0 commit comments

Comments
 (0)