Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ jobs:

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp


- name: Get Cached LLVM Build
Expand Down Expand Up @@ -576,7 +576,7 @@ jobs:

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp

- name: Get Cached LLVM Build
id: cache-llvm-build
Expand Down Expand Up @@ -642,7 +642,7 @@ jobs:

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-jvp

- name: Get Cached LLVM Build
id: cache-llvm-build
Expand Down
22 changes: 11 additions & 11 deletions frontend/catalyst/api_extensions/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import copy
import functools
import numbers
from typing import Callable, Iterable, List, Optional, Union
from typing import Callable, Sequence, List, Optional, Union

import jax
import pennylane as qml
Expand Down Expand Up @@ -442,12 +442,12 @@ def workflow(primals, tangents):
(Array(0.78766064, dtype=float64), Array(-0.7011436, dtype=float64))
"""

def check_is_iterable(x, hint):
if not isinstance(x, Iterable):
raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}")
def check_is_Sequence(x, hint):
if not isinstance(x, Sequence):
raise ValueError(f"vjp '{hint}' argument must be a Sequence, not {type(x)}")

check_is_iterable(params, "params")
check_is_iterable(tangents, "tangents")
check_is_Sequence(params, "params")
check_is_Sequence(tangents, "tangents")

if EvaluationContext.is_tracing():
scalar_out = False
Expand Down Expand Up @@ -550,12 +550,12 @@ def f(x):
if qml.capture.enabled():
return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums)

def check_is_iterable(x, hint):
if not isinstance(x, Iterable):
raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}")
def check_is_Sequence(x, hint):
if not isinstance(x, Sequence):
raise ValueError(f"vjp '{hint}' argument must be a Sequence, not {type(x)}")

check_is_iterable(params, "params")
check_is_iterable(cotangents, "cotangents")
check_is_Sequence(params, "params")
check_is_Sequence(cotangents, "cotangents")

if EvaluationContext.is_tracing():
scalar_out = False
Expand Down
14 changes: 14 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim
from pennylane.capture.primitives import jvp_prim as pl_jvp_prim
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import gridsynth as pl_gridsynth
Expand Down Expand Up @@ -252,6 +253,19 @@ def handle_vjp(self, *args, jaxpr, **kwargs):
return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs)


@WorkflowInterpreter.register_primitive(pl_jvp_prim)
def handle_jvp(self, *args, jaxpr, **kwargs):
"""Translate a grad equation."""
f = partial(copy(self).eval, jaxpr, [])
new_jaxpr = jax.make_jaxpr(f)(*args[: len(jaxpr.invars)])

new_args = (*new_jaxpr.consts, *args)
j = new_jaxpr.jaxpr
new_j = j.replace(constvars=(), invars=j.constvars + j.invars)
return pl_jvp_prim.bind(*new_args, jaxpr=new_j, **kwargs)



# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
def handle_qnode(
Expand Down
35 changes: 35 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@

from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim
from pennylane.capture.primitives import jvp_prim as pl_jvp_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
Expand Down Expand Up @@ -875,6 +876,39 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results

def _capture_jvp_lowering(ctx, *args, jaxpr, fn, method, argnums, h):
"""
Returns:
MLIR results
"""
args = list(args)
mlir_ctx = ctx.module_context.context
n_params = len(jaxpr.invars)
array_argnums = np.array(argnums)

output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)

func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)]
jvp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :]

func_args = args[:n_params]
d_args = args[n_params:]

func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn)

symbol_ref = get_symbolref(ctx, func_op)
return JVPOp(
func_result_types,
jvp_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(d_args),
diffArgIndices=ir.DenseIntElementsAttr.get(array_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results


@vjp_p.def_impl
def _vjp_def_impl(ctx, *args, jaxpr, fn, grad_params): # pragma: no cover
Expand Down Expand Up @@ -2907,6 +2941,7 @@ def subroutine_lowering(*args, **kwargs):
(grad_p, _grad_lowering),
(pl_jac_prim, _capture_grad_lowering),
(pl_vjp_prim, _capture_vjp_lowering),
(pl_jvp_prim, _capture_jvp_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
Loading
Loading