diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 8e075edb..827e7738 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -9266,6 +9266,600 @@ } } ], + "./arraycontext/linalg/solve.py": [ + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 46, + "endColumn": 53, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 46, + "endColumn": 58, + "lineCount": 1 + } + }, + { + "code": "reportCallIssue", + "range": { + "startColumn": 15, + "endColumn": 28, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 15, + "endColumn": 28, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 26, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportReturnType", + "range": { + "startColumn": 15, + "endColumn": 16, + "lineCount": 1 + } + }, + { + "code": "reportCallIssue", + "range": { + "startColumn": 15, + "endColumn": 38, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 15, + "endColumn": 38, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 23, + "endColumn": 37, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 27, + "endColumn": 36, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 9, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 31, + "endColumn": 34, + "lineCount": 1 + } + }, + { + "code": "reportAssignmentType", + "range": { + "startColumn": 31, + "endColumn": 34, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 16, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 22, + "endColumn": 23, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 22, + "endColumn": 23, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 21, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 21, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 19, + "endColumn": 20, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 19, + "endColumn": 20, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 19, + "endColumn": 20, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 10, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 13, + "endColumn": 14, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 24, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 16, + "endColumn": 17, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 20, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 24, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 16, + "endColumn": 18, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 21, + "endColumn": 34, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 26, + "endColumn": 34, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 23, + "endColumn": 24, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 16, + "endColumn": 19, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 14, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 17, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportCallIssue", + "range": { + "startColumn": 8, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 8, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportCallIssue", + "range": { + "startColumn": 8, + "endColumn": 12, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 8, + "endColumn": 12, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 23, + "endColumn": 24, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 13, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 16, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportPossiblyUnboundVariable", + "range": { + "startColumn": 16, + "endColumn": 17, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 20, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 9, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 12, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportAssignmentType", + "range": { + "startColumn": 12, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 16, + "endColumn": 22, + "lineCount": 1 + } + }, + { + "code": "reportAssignmentType", + "range": { + "startColumn": 28, + "endColumn": 43, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 29, + "endColumn": 42, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 16, + "lineCount": 1 + } + }, + { + "code": "reportCallIssue", + "range": { + "startColumn": 19, + "endColumn": 60, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 27, + "endColumn": 59, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 55, + "endColumn": 59, + "lineCount": 1 + } + }, + { + "code": "reportAssignmentType", + "range": { + "startColumn": 24, + "endColumn": 68, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 4, + "endColumn": 9, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 21, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 24, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 24, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 26, + "endColumn": 31, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 24, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 27, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 24, + "endColumn": 35, + "lineCount": 1 + } + }, + { + "code": "reportAttributeAccessIssue", + "range": { + "startColumn": 27, + "endColumn": 35, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 40, + "endColumn": 45, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 8, + "endColumn": 15, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 23, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 16, + "endColumn": 23, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 25, + "endColumn": 32, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 25, + "endColumn": 32, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 8, + "endColumn": 15, + "lineCount": 1 + } + } + ], "./arraycontext/loopy.py": [ { "code": "reportUnknownMemberType", @@ -11762,6 +12356,30 @@ "endColumn": 55, "lineCount": 1 } + }, + { + "code": "reportUnknownLambdaType", + "range": { + "startColumn": 20, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportFunctionMemberAccess", + "range": { + "startColumn": 11, + "endColumn": 16, + "lineCount": 1 + } + }, + { + "code": "reportFunctionMemberAccess", + "range": { + "startColumn": 11, + "endColumn": 16, + "lineCount": 1 + } } ], "./test/testlib.py": [ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c6e60e75..ca37a3e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,7 @@ jobs: - uses: actions/checkout@v6 - name: "Main Script" run: | - EXTRA_INSTALL="pytest types-colorama types-Pygments scipy-stubs" + EXTRA_INSTALL="pytest types-colorama types-Pygments scipy-stubs matplotlib" curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 diff --git a/arraycontext/linalg/__init__.py b/arraycontext/linalg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arraycontext/linalg/solve.py b/arraycontext/linalg/solve.py new file mode 100644 index 00000000..36e88b6a --- /dev/null +++ b/arraycontext/linalg/solve.py @@ -0,0 +1,451 @@ +""" +.. autofunction:: gmres + +.. autoclass:: GMRESResult +.. autoexception:: GMRESError +.. autoclass:: ResidualPrinter + +.. autoclass:: InnerProduct + :members: + :undoc-members: + :special-members: __call__ +.. autoclass:: CallableOperator + :members: + :undoc-members: + :special-members: __call__ +.. autoclass:: HasMatVec + :members: + :undoc-members: +""" + +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2012-2013 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, Generic, Protocol + +import numpy as np + +from pytools import T + +from arraycontext import ArrayContext, ArrayOrContainerT + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + import optype.numpy as onp + + +# {{{ gmres + +# Modified Python port of ./Apps/Acoustics/root/matlab/gmres_restart.m +# from hellskitchen. +# Necessary because SciPy gmres is not reentrant and thus does +# not allow recursive solves. + + +class InnerProduct(Protocol, Generic[T]): + """A :class:`~typing.Protocol` for the inner product used by :func:`gmres`.""" + + def __call__(self, a: T, b: T) -> T: ... + + +class CallableOperator(Protocol, Generic[T]): + """A :class:`~typing.Protocol` for the operator used by :func:`gmres`.""" + + @property + def shape(self) -> tuple[int, int]: ... + + def __call__(self, x: T) -> T: ... + + +class HasMatVec(Protocol, Generic[T]): + """A :class:`~typing.Protocol` for the operator used by :func:`gmres`.""" + + @property + def shape(self) -> tuple[int, int]: ... + + def matvec(self, x: T) -> T: ... + + +def structured_vdot(x: ArrayOrContainerT, y: ArrayOrContainerT, + array_context: ArrayContext | None = None) -> float: + """vdot() implementation that is aware of scalars and host or + PyOpenCL arrays. It also recurses down nested object arrays. + """ + + if type(x) is not type(y): + raise TypeError("'structured_vdot' entries have different types: " + f"{type(x).__name__} and {type(y).__name__}") + + from numbers import Number + if (isinstance(x, Number) + or (isinstance(x, np.ndarray) and x.dtype.char != "O")): + return np.vdot(x, y) + else: + if array_context is None: + raise ValueError("'array_context' is required for non-scalar inputs") + + # actx.np.vdot works on PyOpenCL arrays and arbitrarily nested + # array containers, so this should handle all remaining cases + r = array_context.to_numpy(array_context.np.vdot(x, y)) + if isinstance(r, np.ndarray) and r.shape == (): + r = r[()] + + return r + + +class GMRESError(RuntimeError): + pass + + +# {{{ main routine + +@dataclass(frozen=True) +class GMRESResult(Generic[T]): + """ + .. autoattribute:: solution + .. autoattribute:: residual_norms + .. autoattribute:: iteration_count + .. autoattribute:: success + .. autoattribute:: state + """ + + solution: T + residual_norms: Sequence[float] + iteration_count: int + success: bool + """A :class:`bool` indicating whether the iteration succeeded.""" + state: str + """A description of the outcome.""" + + +def _gmres( + a: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT], + b: ArrayOrContainerT, + restart: int | None = None, + tol: float | None = None, + x0: ArrayOrContainerT | None = None, + dot: InnerProduct[ArrayOrContainerT] | None = None, + maxiter: int | None = None, + hard_failure: bool | None = None, + require_monotonicity: bool = True, + no_progress_factor: float | None = None, + stall_iterations: int | None = None, + callback: Callable[[ArrayOrContainerT], None] | None = None + ) -> GMRESResult[ArrayOrContainerT]: + + # {{{ input processing + + n, _ = a.shape + a_call = a.matvec if not callable(a) else a + + if dot is None: + raise ValueError("'dot' not provided") + + if restart is None: + restart = min(n, 20) + + if tol is None: + tol = 1e-5 + + if maxiter is None: + maxiter = 2*n + + if hard_failure is None: + hard_failure = True + + if stall_iterations is None: + stall_iterations = 10 + + if no_progress_factor is None: + no_progress_factor = 1.25 + + # }}} + + def norm(x: ArrayOrContainerT) -> float: + return np.sqrt(abs(dot(x, x))) + + if x0 is None: + x: ArrayOrContainerT = 0*b + r = b + recalc_r = False + else: + x = x0 + del x0 + recalc_r = True + + ae: list[ArrayOrContainerT | None] = [None]*restart + e: list[ArrayOrContainerT | None] = [None]*restart + + k = 0 + + norm_b = norm(b) + last_resid_norm = None + residual_norms: list[float] = [] + + iteration = 0 + for iteration in range(maxiter): + # restart if required + if k == restart: + k = 0 + orth_count = restart + else: + orth_count = k + + # recalculate residual every 10 steps + if recalc_r: + r = b - a_call(x) + + norm_r = norm(r) + residual_norms.append(norm_r) + + if callback is not None: + callback(r) + + if norm_r < tol*norm_b or norm_r == 0: + return GMRESResult( + solution=x, + residual_norms=residual_norms, + iteration_count=iteration, + success=True, + state="success") + if last_resid_norm is not None: + if norm_r > 1.25*last_resid_norm: + state = "non-monotonic residuals" + if require_monotonicity: + if hard_failure: + raise GMRESError(state) + else: + return GMRESResult( + solution=x, + residual_norms=residual_norms, + iteration_count=iteration, + success=False, + state=state) + else: + print("*** WARNING: non-monotonic residuals in GMRES") + + if (stall_iterations + and len(residual_norms) > stall_iterations + and norm_r > ( + residual_norms[-stall_iterations] + / no_progress_factor)): + + state = "stalled" + if hard_failure: + raise GMRESError(state) + else: + return GMRESResult( + solution=x, + residual_norms=residual_norms, + iteration_count=iteration, + success=False, + state=state) + + last_resid_norm = norm_r + + # initial new direction guess + w = a_call(r) + + # {{{ double-orthogonalize the new direction against preceding ones + + rp = r + + for _orth_trips in range(2): + for j in range(orth_count): + d = dot(ae[j], w) + w = w - d * ae[j] + rp = rp - d * e[j] + + # normalize + d = 1/norm(w) + w = d*w + rp = d*rp + + # }}} + + ae[k] = w + e[k] = rp + + # update the residual and solution + d = dot(ae[k], r) + + recalc_r = (iteration+1) % 10 == 0 + if not recalc_r: + r = r - d*ae[k] + + x = x + d*e[k] + + k += 1 + + state = "max iterations" + if hard_failure: + raise GMRESError(state) + else: + return GMRESResult( + solution=x, + residual_norms=residual_norms, + iteration_count=iteration, + success=False, + state=state) + +# }}} + + +# {{{ progress reporting + +class ResidualPrinter(Generic[ArrayOrContainerT]): + count: int + inner_product: InnerProduct[ArrayOrContainerT] + + def __init__( + self, + inner_product: InnerProduct[ArrayOrContainerT] | None = None + ) -> None: + if inner_product is None: + inner_product = structured_vdot + + self.count = 0 + self.inner_product = inner_product + + def __call__(self, resid: ArrayOrContainerT | None) -> None: + import sys + if resid is not None: + norm = np.sqrt(self.inner_product(resid, resid)) + sys.stdout.write(f"IT {self.count:8d} {abs(norm):.8e}\n") + else: + sys.stdout.write(f"IT {self.count:8d}\n") + + self.count += 1 + sys.stdout.flush() + +# }}} + + +# {{{ entrypoint + +def gmres( + op: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT], + rhs: ArrayOrContainerT, + *, + restart: int | None = None, + tol: float | None = None, + x0: ArrayOrContainerT | None = None, + inner_product: InnerProduct[ArrayOrContainerT] | None = None, + maxiter: int | None = None, + hard_failure: bool | None = None, + no_progress_factor: float | None = None, + stall_iterations: int | None = None, + callback: Callable[[ArrayOrContainerT], None] | None = None, + progress: bool = False, + require_monotonicity: bool = True, + actx: ArrayContext | None = None, + ) -> GMRESResult[ArrayOrContainerT]: + """Solve a linear system :math:`Ax = b` using GMRES with restarts. + + :arg op: a callable to evaluate :math:`A(x)`. + :arg rhs: the right hand side :math:`b`. + :arg restart: the maximum number of iteration after which GMRES algorithm + needs to be restarted + :arg tol: the required decrease in residual norm (relative to the *rhs*). + :arg x0: an initial guess for the iteration (a zero array is used by default). + :arg inner_product: a callable with an interface compatible with + :func:`numpy.vdot` that returns a host scalar. + :arg maxiter: the maximum number of iterations permitted. + :arg hard_failure: if *True*, raise :exc:`GMRESError` in case of failure. + :arg stall_iterations: number of iterations with residual decrease + below *no_progress_factor* indicates stall. Set to ``0`` to disable + stall detection. + """ + if inner_product is None: + if actx is None: + raise TypeError("actx is required if inner_product is not supplied") + inner_product = partial(structured_vdot, array_context=actx) + + if callback is None: + callback = ResidualPrinter(inner_product) if progress else None + + return _gmres(op, rhs, restart=restart, tol=tol, x0=x0, + dot=inner_product, + maxiter=maxiter, hard_failure=hard_failure, + no_progress_factor=no_progress_factor, + stall_iterations=stall_iterations, callback=callback, + require_monotonicity=require_monotonicity) + + +# }}} + + +def build_matrix( + op: CallableOperator[onp.Array1D] | HasMatVec[onp.Array1D], + dtype: np.dtype[Any] | None = None, + shape: tuple[int, int] | None = None + ): + dtype = dtype or op.dtype + from pytools import ProgressBar + shape = shape or op.shape + _rows, cols = shape + pb = ProgressBar("matrix", cols) + mat = np.zeros(shape, dtype) + + try: + matvec_method = op.matvec + except AttributeError: + matvec_method = op.__call__ + + for i in range(cols): + unit_vec = np.zeros(cols, dtype=dtype) + unit_vec[i] = 1 + mat[:, i] = matvec_method(unit_vec) + pb.progress() + + pb.finished() + + return mat + + +# {{{ direct solve + +def lu( + op: CallableOperator[onp.Array1D], + rhs: onp.Array1D, + show_spectrum: bool = False): + import numpy.linalg as la + + mat = build_matrix(op) + + print(f"condition number: {la.cond(mat)}") + if show_spectrum: + ev = la.eigvals(mat) + import matplotlib.pyplot as pt + pt.plot(ev.real, ev.imag, "o") + pt.show() + + return la.solve(mat, rhs) + +# }}} diff --git a/doc/conf.py b/doc/conf.py index fe998eed..925128d2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -45,11 +45,20 @@ ["py:class", r"arraycontext.typing._UserDefinedArithArrayContainer"], ["py:class", r"np.integer"], ["py:class", r".*\|.*"], + ["py:data", r"types.EllipsisType"], ] sphinxconfig_missing_reference_aliases = { # pyopencl "cl.Device": "obj:pyopencl.Device", + "cl.Allocator": "obj:pyopencl.array.Allocator", + "np.ndarray": "obj:numpy.ndarray", + "ToTagSetConvertible": "obj:pytools.tag.ToTagSetConvertible", + "ArrayOrNames": "obj:pytato.ArrayOrNames", + "Integer": "obj:python.int", + "ScalarLike": "obj:arraycontext.ScalarLike", + "ArrayOrContainerOrScalar": "obj:arraycontext.ArrayOrContainerOrScalar", + "arraycontext.typing.ArrayOrContainerT": "obj:arraycontext.ArrayOrContainerT", } diff --git a/doc/other.rst b/doc/other.rst index 4b0cca75..6f6b3d66 100644 --- a/doc/other.rst +++ b/doc/other.rst @@ -10,6 +10,11 @@ Metadata ("tags") for Arrays and Array Axes .. automodule:: arraycontext.metadata +Linear system solving +--------------------- + +.. automodule:: arraycontext.linalg.solve + :class:`~arraycontext.ArrayContext`-generating fixture for :mod:`pytest` ------------------------------------------------------------------------ @@ -19,42 +24,3 @@ Program creation for :mod:`loopy` --------------------------------- .. automodule:: arraycontext.loopy - -References ----------- - -.. currentmodule:: cl_array - -.. class:: Allocator - - See :class:`pyopencl.array.Allocator`. - -.. currentmodule:: np - -.. class:: ndarray - - See :class:`numpy.ndarray`. - -.. currentmodule:: dummy_refs - -.. class:: ToTagSetConvertible - - See :mod:`pytools.tag`. - -.. class:: ArrayOrNames - - A type alias in :mod:`pytato` allowing - :class:`pytato.Array` and - :class:`pytato.AbstractResultWithNamedArrays`. - -.. class:: Integer - - A type alias allowing integers. - -.. class:: ScalarLike - - See :class:`arraycontext.ScalarLike`. - -.. class:: ArrayOrContainerOrScalar - - See :class:`arraycontext.ArrayOrContainerOrScalar`. diff --git a/test/test_utils.py b/test/test_utils.py index f58432f1..32e312e1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -39,6 +39,7 @@ ) import numpy as np +import numpy.linalg as la import pytest @@ -249,6 +250,35 @@ class SomeOtherContainer: # }}} +# {{{ test_gmres + +def test_gmres(): + rng = np.random.default_rng(seed=42) + + n = 200 + a = ( + n * (np.eye(n) + 2j * np.eye(n)) + + rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))) + + true_sol = rng.normal(size=n) + 1j * rng.normal(size=n) + b = np.dot(a, true_sol) + + A_func = lambda x: np.dot(a, x) # noqa + A_func.shape = a.shape + A_func.dtype = a.dtype + + from arraycontext.linalg.solve import ResidualPrinter, gmres + tol = 1e-6 + sol = gmres(A_func, b, callback=ResidualPrinter(), + maxiter=5*n, tol=tol, + inner_product=np.vdot, + ).solution + + assert la.norm(true_sol - sol) / la.norm(sol) < tol + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: