From ff69cfb08c9255903cdd2a920841feac197c7fcc Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 19 Dec 2025 12:58:52 -0500 Subject: [PATCH 1/2] translate and lower jvp --- .github/workflows/check-catalyst.yaml | 6 ++-- frontend/catalyst/from_plxpr/from_plxpr.py | 14 +++++++++ frontend/catalyst/jax_primitives.py | 35 ++++++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index b862198ab2..b4de683101 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -477,7 +477,7 @@ jobs: - name: Install PennyLane branch run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp - name: Get Cached LLVM Build @@ -576,7 +576,7 @@ jobs: - name: Install PennyLane branch run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp - name: Get Cached LLVM Build id: cache-llvm-build @@ -642,7 +642,7 @@ jobs: - name: Install PennyLane branch run: | - pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp - name: Get Cached LLVM Build id: cache-llvm-build diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 327a6d7133..529f0aa5d4 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -30,6 +30,7 @@ from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from pennylane.capture.primitives import transform_prim from pennylane.capture.primitives import vjp_prim as pl_vjp_prim +from pennylane.capture.primitives import jvp_prim as pl_jvp_prim from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose from pennylane.transforms import gridsynth as pl_gridsynth @@ -252,6 +253,19 @@ def handle_vjp(self, *args, jaxpr, **kwargs): return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs) +@WorkflowInterpreter.register_primitive(pl_jvp_prim) +def handle_jvp(self, *args, jaxpr, **kwargs): + """Translate a grad equation.""" + f = partial(copy(self).eval, jaxpr, []) + new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)]) + + new_args = (*new_jaxpr.consts, *args) + j = new_jaxpr.jaxpr + new_j = j.replace(constvars=(), invars=j.constvars + j.invars) + return pl_jvp_prim.bind(*new_args, jaxpr=new_j, **kwargs) + + + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 3cc9758e45..71a116ad8a 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -134,6 +134,7 @@ from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from pennylane.capture.primitives import vjp_prim as pl_vjp_prim +from pennylane.capture.primitives import jvp_prim as pl_jvp_prim from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -875,6 +876,39 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params): finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results +def _capture_jvp_lowering(ctx, *args, jaxpr, fn, method, argnums, h): + """ + Returns: + MLIR results + """ + args = list(args) + mlir_ctx = ctx.module_context.context + n_params = len(jaxpr.invars) + array_argnums = np.array(argnums) + + output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) + flat_output_types = util.flatten(output_types) + + func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)] + jvp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :] + + func_args = args[:n_params] + d_args = args[n_params:] + + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn) + + symbol_ref = get_symbolref(ctx, func_op) + return JVPOp( + func_result_types, + jvp_result_types, + ir.StringAttr.get(method), + symbol_ref, + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(d_args), + diffArgIndices=ir.DenseIntElementsAttr.get(array_argnums), + finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, + ).results + @vjp_p.def_impl def _vjp_def_impl(ctx, *args, jaxpr, fn, grad_params): # pragma: no cover @@ -2907,6 +2941,7 @@ def subroutine_lowering(*args, **kwargs): (grad_p, _grad_lowering), (pl_jac_prim, _capture_grad_lowering), (pl_vjp_prim, _capture_vjp_lowering), + (pl_jvp_prim, _capture_jvp_lowering), (func_p, _func_lowering), (jvp_p, _jvp_lowering), (vjp_p, _vjp_lowering), From daf3399419d2a9a76401ee5415555cccd1cbe17e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 19 Dec 2025 16:41:47 -0500 Subject: [PATCH 2/2] more validation --- .../api_extensions/differentiation.py | 22 +++--- frontend/catalyst/from_plxpr/from_plxpr.py | 2 +- frontend/test/pytest/test_jvpvjp.py | 67 ++++++++++++------- 3 files changed, 53 insertions(+), 38 deletions(-) diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 734de44c1c..f393c564ca 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -21,7 +21,7 @@ import copy import functools import numbers -from typing import Callable, Iterable, List, Optional, Union +from typing import Callable, Sequence, List, Optional, Union import jax import pennylane as qml @@ -442,12 +442,12 @@ def workflow(primals, tangents): (Array(0.78766064, dtype=float64), Array(-0.7011436, dtype=float64)) """ - def check_is_iterable(x, hint): - if not isinstance(x, Iterable): - raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}") + def check_is_Sequence(x, hint): + if not isinstance(x, Sequence): + raise ValueError(f"vjp '{hint}' argument must be a Sequence, not {type(x)}") - check_is_iterable(params, "params") - check_is_iterable(tangents, "tangents") + check_is_Sequence(params, "params") + check_is_Sequence(tangents, "tangents") if EvaluationContext.is_tracing(): scalar_out = False @@ -550,12 +550,12 @@ def f(x): if qml.capture.enabled(): return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums) - def check_is_iterable(x, hint): - if not isinstance(x, Iterable): - raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}") + def check_is_Sequence(x, hint): + if not isinstance(x, Sequence): + raise ValueError(f"vjp '{hint}' argument must be a Sequence, not {type(x)}") - check_is_iterable(params, "params") - check_is_iterable(cotangents, "cotangents") + check_is_Sequence(params, "params") + check_is_Sequence(cotangents, "cotangents") if EvaluationContext.is_tracing(): scalar_out = False diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 529f0aa5d4..25d247e02e 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -257,7 +257,7 @@ def handle_vjp(self, *args, jaxpr, **kwargs): def handle_jvp(self, *args, jaxpr, **kwargs): """Translate a grad equation.""" f = partial(copy(self).eval, jaxpr, []) - new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)]) + new_jaxpr = jax.make_jaxpr(f)(*args[: len(jaxpr.invars)]) new_args = (*new_jaxpr.consts, *args) j = new_jaxpr.jaxpr diff --git a/frontend/test/pytest/test_jvpvjp.py b/frontend/test/pytest/test_jvpvjp.py index 627b89dfc8..4abd60059d 100644 --- a/frontend/test/pytest/test_jvpvjp.py +++ b/frontend/test/pytest/test_jvpvjp.py @@ -62,7 +62,7 @@ def f(x): ct = jnp.array(1.0) res, f_vjp = jax.vjp(f, x) expected = tuple([res, f_vjp(ct)]) - result = C_vjp(f, x, ct) + result = C_vjp(f, (x,), ct) res_jax, tree_jax = jax.tree_util.tree_flatten(expected) res_cat, tree_cat = jax.tree_util.tree_flatten(result) @@ -159,6 +159,7 @@ def f(x, y): assert_allclose(expected, result) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_full_argnum_case_S_SS(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -171,7 +172,7 @@ def test_jvp_against_jax_full_argnum_case_S_SS(diff_method): @qjit def C_workflow(): f = qml.QNode(circuit_rx, device=qml.device("lightning.qubit", wires=1)) - return C_jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -183,9 +184,10 @@ def J_workflow(): res_jax, tree_jax = jax.tree_util.tree_flatten(r1) res_cat, tree_cat = jax.tree_util.tree_flatten(r2) assert tree_jax == tree_cat - assert_allclose(res_jax, res_cat) + assert_allclose(res_jax, res_cat, atol=5e-7) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_full_argnum_case_T_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -200,7 +202,7 @@ def f(x): @qjit def C_workflow(): - return C_jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -214,6 +216,7 @@ def J_workflow(): assert_allclose(res_jax, res_cat) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_full_argnum_case_TT_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -233,7 +236,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -247,6 +250,7 @@ def J_workflow(): assert_allclose(res_jax, res_cat) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_full_argnum_case_T_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -261,7 +265,7 @@ def f(x): @qjit def C_workflow(): - return C_jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -277,6 +281,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_full_argnum_case_TT_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -299,7 +304,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, t, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -315,6 +320,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_pytrees_return(diff_method): """Test that a JVP with pytrees as return.""" @@ -324,7 +330,7 @@ def f(x, y): @qjit def workflow(): - return C_jvp(f, [0.1, 0.2], [1.0, 1.0], method=diff_method, argnums=[0, 1]) + return qml.jvp(f, [0.1, 0.2], [1.0, 1.0], method=diff_method, argnums=[0, 1]) catalyst_res = workflow() jax_res = J_jvp(f, [0.1, 0.2], [1.0, 1.0]) @@ -335,6 +341,7 @@ def workflow(): assert_allclose(catalyst_res_flatten, jax_res_flatten) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_pytrees_args(diff_method): """Test that a JVP with pytrees as args.""" @@ -344,7 +351,7 @@ def f(x, y): @qjit def workflow(): - return C_jvp( + return qml.jvp( f, [{"res1": 0.1, "res2": 0.2}, 0.3], [{"res1": 1.0, "res2": 1.0}, 1.0], @@ -361,6 +368,7 @@ def workflow(): assert_allclose(catalyst_res_flatten, jax_res_flatten) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_pytrees_args_and_return(diff_method): """Test that a JVP with pytrees as args.""" @@ -370,7 +378,7 @@ def f(x, y): @qjit def workflow(): - return C_jvp( + return qml.jvp( f, [{"res1": 0.1, "res2": 0.2}, 0.3], [{"res1": 1.0, "res2": 1.0}, 1.0], @@ -387,6 +395,7 @@ def workflow(): assert_allclose(catalyst_res_flatten, jax_res_flatten) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_multi_returns(diff_method): """Test that a JVP with multiple arg as return.""" @@ -396,7 +405,7 @@ def f(x): @qjit def workflow(): - return C_jvp(f, [0.3], [1.1], method=diff_method, argnums=[0]) + return qml.jvp(f, [0.3], [1.1], method=diff_method, argnums=[0]) catalyst_res = workflow() jax_res = J_jvp(f, [0.3], [1.1]) @@ -404,7 +413,7 @@ def workflow(): catalyst_res_flatten, tree_cat = jax.tree_util.tree_flatten(catalyst_res) jax_res_flatten, tree_jax = jax.tree_util.tree_flatten(jax_res) assert tree_cat == tree_jax - assert_allclose(catalyst_res_flatten, jax_res_flatten, rtol=1e-6) + assert_allclose(catalyst_res_flatten, jax_res_flatten, rtol=1e-6, atol=5e-6) @pytest.mark.usefixtures("use_both_frontend") @@ -578,6 +587,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvpvjp_argument_checks(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -598,11 +608,11 @@ def f(x1, x2): @qjit def C_workflow1(): - return C_jvp(f, x, tuple(t), method=diff_method, argnums=list(range(len(x)))) + return qml.jvp(f, x, tuple(t), method=diff_method, argnums=list(range(len(x)))) @qjit def C_workflow2(): - return C_jvp(f, tuple(x), t, method=diff_method, argnums=tuple(range(len(x)))) + return qml.jvp(f, tuple(x), t, method=diff_method, argnums=tuple(range(len(x)))) r1 = C_workflow1() r2 = C_workflow2() @@ -612,13 +622,13 @@ def C_workflow2(): for r_j, r_c in zip(res_jax, res_cat): assert_allclose(r_j, r_c) - with pytest.raises(ValueError, match="argument must be an iterable"): + with pytest.raises(ValueError, match="must be a Sequence"): @qjit def C_workflow_bad1(): - return C_jvp(f, 33, tuple(t), argnums=list(range(len(x)))) + return qml.jvp(f, 33, tuple(t), argnums=list(range(len(x)))) - with pytest.raises(ValueError, match="argument must be an iterable"): + with pytest.raises(ValueError, match="must be a Sequence"): @qjit def C_workflow_bad2(): @@ -631,6 +641,7 @@ def C_workflow_bad3(): return qml.vjp(f, x, ct, argnums="invalid") +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_against_jax_argnum0_case_TT_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version, in case of empty or singular @@ -654,11 +665,11 @@ def f(x1, x2): @qjit def C_workflowA(): - return C_jvp(f, x, t[0:1], method=diff_method) + return qml.jvp(f, x, t[0:1], method=diff_method) @qjit def C_workflowB(): - return C_jvp(f, x, t[0:1], method=diff_method, argnums=[0]) + return qml.jvp(f, x, t[0:1], method=diff_method, argnums=[0]) @jax.jit def J_workflow(): @@ -856,6 +867,7 @@ def J_workflow(): assert_allclose(r_j, r_c, atol=2e-6) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_argument_type_checks_correct_inputs(diff_method): """Test that Catalyst's jvp can JIT compile when given the correct types.""" @@ -864,15 +876,16 @@ def test_jvp_argument_type_checks_correct_inputs(diff_method): def C_workflow_f(): x = (1.0,) tangents = (1.0,) - return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) + return qml.jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) @qjit def C_workflow_g(): x = jnp.array([2.0, 3.0, 4.0]) tangents = jnp.ones([3], dtype=float) - return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) + return qml.jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_argument_type_checks_incompatible_n_inputs(diff_method): """Tests error handling of Catalyst's jvp when the number of differentiable params @@ -889,9 +902,10 @@ def C_workflow(): # If `f` takes one differentiable param (argnum=[0]), then `tangents` must have length 1 x = (1.0,) tangents = (1.0, 1.0) - return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) + return qml.jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_argument_type_checks_incompatible_input_types(diff_method): """Tests error handling of Catalyst's jvp when the types of the differentiable @@ -899,7 +913,7 @@ def test_jvp_argument_type_checks_incompatible_input_types(diff_method): """ with pytest.raises( - TypeError, match="function params and tangents arguments to catalyst.jvp do not match" + TypeError, match="function params and tangents arguments to " ): @qjit @@ -907,9 +921,10 @@ def C_workflow(): # If `x` has type float, then `tangents` should also have type float x = (1.0,) tangents = (1,) - return C_jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) + return qml.jvp(f_R1_to_R2, x, tangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_jvp_argument_type_checks_incompatible_input_shapes(diff_method): """Tests error handling of Catalyst's jvp when the shapes of the differentiable @@ -917,7 +932,7 @@ def test_jvp_argument_type_checks_incompatible_input_shapes(diff_method): """ with pytest.raises( - ValueError, match="catalyst.jvp called with different function params and tangent shapes" + ValueError, match="jvp called with different function params and tangent shapes" ): @qjit @@ -926,7 +941,7 @@ def C_workflow(): # but it has shape (4,) x = jnp.array([2.0, 3.0, 4.0]) tangents = jnp.ones([4], dtype=float) - return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) + return qml.jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) @pytest.mark.usefixtures("use_both_frontend")