Source code for pygrog.interop._sigpy

"""sigpy Linop adapter for SparseFFT.

Wraps a pygrog operator (with ``forward`` / ``adjoint`` methods) as a
``sigpy.linop.Linop`` so it plugs into sigpy reconstruction algorithms
(conjugate gradient, primal-dual hybrid gradient, etc.) without modification.

sigpy ``Linop`` contract:
  - Subclass ``sigpy.linop.Linop``
  - Implement ``_apply(self, input)`` → applies the forward direction
  - Implement ``_adjoint_linop(self)`` → returns a ``Linop`` for the adjoint
  - Inputs/outputs are numpy (CPU) or cupy (GPU) arrays

Array conversion (numpy/cupy ↔ torch) is handled transparently by
:class:`~pygrog.operator.SparseFFT`, which is decorated with
:func:`mrinufft._array_compat.with_torch` on its ``forward`` and ``adjoint``
methods.  The sigpy ``_apply`` method therefore passes arrays directly to
the operator and receives numpy/cupy back.
"""

__all__ = [
    "GrogInterpolator",
    "GrogLinop",
    "GrogNormalLinop",
    "coil_compress",
    "nlinv_calib",
]

import numpy as np

from ..calib import GrogInterpolator as _GrogInterpolatorBase
from ..utils import coil_compress as _coil_compress
from ..utils import nlinv_calib as _nlinv_calib


[docs] class GrogLinop: """Wrap a pygrog SparseFFT-like operator as a ``sigpy.linop.Linop``. The returned object is a real ``sigpy.linop.Linop`` with a working ``.H`` (adjoint) property, and therefore participates in all sigpy operator algebra (composition via ``*``, addition, scaling, etc.). Parameters ---------- op : SparseFFT-like Any pygrog operator with ``forward(kspace) -> image`` and ``adjoint(image) -> kspace`` methods. Raises ------ ImportError If ``sigpy`` is not installed. Examples -------- :: from pygrog.operator import SparseFFT from pygrog.interop import GrogLinop base = SparseFFT(plan=grog.plan, smaps=smaps) A = GrogLinop(base) # Use inside sigpy CG reconstruction: import sigpy.alg as alg AHA = A.H * A # ... set up CG solver using AHA """ # --- class factory (lazy) --------------------------------------------- _sigpy_class = None def __new__(cls, op): if cls._sigpy_class is None: cls._sigpy_class = cls._build_sigpy_class() return cls._sigpy_class(op) @staticmethod def _build_sigpy_class(): try: from sigpy.linop import Linop except ImportError as exc: raise ImportError( "sigpy is required for GrogLinop. Install it with: pip install sigpy" ) from exc class _GrogLinopImpl(Linop): """sigpy Linop wrapping a pygrog SparseFFT-like operator.""" def __init__(self, op, *, _adjoint=False): self._op = op self._is_adjoint = _adjoint # Shapes for sigpy: oshape, ishape # Convention: forward = kspace → image (adjoint NUFFT direction) # ishape = (n_coils, n_samples) # oshape = image_shape (or (n_coils, *image_shape) without smaps) n_samples = op.indices.shape[0] n_coils = op.smaps.shape[0] if op.smaps is not None else 1 if op.smaps is not None: # forward: (n_coils, n_samples) → (*image_shape) ksp_shape = [n_coils, n_samples] img_shape = list(op.image_shape) else: ksp_shape = [n_coils, n_samples] img_shape = [n_coils, *list(op.image_shape)] if not _adjoint: super().__init__(img_shape, ksp_shape) else: super().__init__(ksp_shape, img_shape) def _apply(self, input): # noqa: A002 # SparseFFT.forward / .adjoint are decorated with with_torch, # so they accept numpy/cupy arrays and return the same type. if self._is_adjoint: # adjoint: image → k-space. Return flat (n_coils, n_samples) # to match the declared ishape/oshape so sigpy doesn't complain. op = self._op n_coils = op.smaps.shape[0] if op.smaps is not None else 1 n_samples = op.indices.shape[0] return op._adjoint_flat(input).reshape(n_coils, n_samples) return self._op.adjoint(input) def _adjoint_linop(self): return _GrogLinopImpl(self._op, _adjoint=not self._is_adjoint) return _GrogLinopImpl
# =========================================================================== # Normal-operator Linop (Toeplitz short-circuit) # =========================================================================== class GrogNormalLinop: """sigpy ``Linop`` wrapping ``op.normal`` (i.e. ``A^H A``). Use this in CG / least-squares loops in place of ``A.H * A`` when the underlying pygrog operator's ``toeplitz`` flag is enabled (``A.H * A`` in sigpy is a generic composed Linop and does NOT short-circuit to ``op.normal``). Parameters ---------- op : SparseFFT-like Any pygrog operator with a ``.normal(image)`` method. Examples -------- :: from pygrog.interop import GrogLinop, GrogNormalLinop A = GrogLinop(base) AHA = GrogNormalLinop(base) # uses Toeplitz when base.toeplitz=True # Use AHA wherever you would use ``A.H * A``. """ _sigpy_class = None def __new__(cls, op): if cls._sigpy_class is None: cls._sigpy_class = cls._build_sigpy_class() return cls._sigpy_class(op) @staticmethod def _build_sigpy_class(): try: from sigpy.linop import Linop except ImportError as exc: raise ImportError( "sigpy is required for GrogNormalLinop. " "Install it with: pip install sigpy" ) from exc class _GrogNormalLinopImpl(Linop): """Self-adjoint Linop that applies ``op.normal``.""" def __init__(self, op): self._op = op if op.smaps is not None: img_shape = list(op.image_shape) else: n_coils = op.smaps.shape[0] if op.smaps is not None else 1 img_shape = [n_coils, *list(op.image_shape)] super().__init__(img_shape, img_shape) def _apply(self, input): # noqa: A002 return self._op.normal(input) def _adjoint_linop(self): return self # self-adjoint return _GrogNormalLinopImpl # =========================================================================== # GrogInterpolator adapter (numpy ndarray I/O) # =========================================================================== class GrogInterpolator(_GrogInterpolatorBase): """GROG interpolator with numpy-array I/O for sigpy users. sigpy has no native non-Cartesian data container — calibration, coordinates and k-space are passed as plain :class:`numpy.ndarray`. This adapter is a thin sigpy-flavoured wrapper that mirrors the API of :class:`pygrog.interop.mrpro.GrogInterpolator` but accepts and returns sigpy's idiomatic types. Parameters ---------- coords : np.ndarray, shape ``(*spatial, ndim)`` Trajectory coordinates in pygrog scale (``[-shape/2, shape/2]``). shape : int | tuple[int, ...] Cartesian image shape spanned by the trajectory. image_shape : tuple[int, ...] | None, optional FFT crop target. Defaults to *shape*. kernel_width, oversamp, kernel_shape, time_map Forwarded to :class:`pygrog.calib.GrogInterpolator`. Returns from :meth:`interpolate` -------------------------------- Tuple ``(sparse_kspace_ndarray, GrogPlan)``. The sparse k-space has shape ``(n_coils, *natural_shape)`` (i.e. ``(n_coils, *spatial, kw)``) matching the layout :class:`~pygrog.operator.SparseFFT` consumes. """ def __init__( self, coords, shape, *, image_shape=None, kernel_width: int = 2, oversamp: float | list | tuple | None = None, kernel_shape: str = "circle", time_map=None, ): super().__init__( shape=shape, coords=np.asarray(coords), oversamp=oversamp, kernel_width=kernel_width, kernel_shape=kernel_shape, time_map=time_map, image_shape=image_shape, ) def interpolate(self, kspace, *, return_plan: bool = True, **kwargs): """Interpolate ``kspace`` and (optionally) return the plan. Parameters ---------- kspace : np.ndarray ``(*batch, n_coils, *spatial)`` complex k-space. return_plan : bool, optional If ``True`` (default) return ``(ndarray, GrogPlan)``; if ``False`` return the ndarray only. grid : bool, optional If ``True``, return ``(gridded_kspace, mask, density[, plan])`` numpy arrays instead of the flat sparse output. """ grid = kwargs.get("grid", False) out = super().interpolate(np.asarray(kspace), **kwargs) if grid: grid_kspace, masked_plan = out if return_plan: return np.asarray(grid_kspace), masked_plan, self.plan return np.asarray(grid_kspace), masked_plan out = np.asarray(out) # Reshape from flat (*batch, C, n_samples) → (*batch, C, *natural_shape) out = out.reshape(*out.shape[:-1], *self.plan.natural_shape) if return_plan: return out, self.plan return out # =========================================================================== # NLINV calibration — passthrough to pygrog.utils.nlinv_calib # =========================================================================== def nlinv_calib( kspace, coords, shape, *, cal_width: int = 24, ret_cal: bool = False, ret_image: bool = False, **kwargs, ): """Estimate coil sensitivities from non-Cartesian sigpy-style inputs. Thin numpy-friendly wrapper around :func:`pygrog.utils.nlinv_calib` (preferred over sigpy's :class:`sigpy.mri.app.JsenseRecon`). Parameters ---------- kspace : np.ndarray Multi-coil non-Cartesian k-space, shape ``(n_coils, n_samples)``. coords : np.ndarray Trajectory coordinates, shape ``(n_samples, ndim)`` in pygrog scale. shape : tuple[int, ...] Image shape ``(y, x)`` or ``(z, y, x)``. cal_width : int Calibration patch width. **kwargs Forwarded to :func:`pygrog.utils.nlinv_calib`. Returns ------- smaps : np.ndarray Coil sensitivities of shape ``(n_coils, *shape)``. *extras Optional ``(grappa_train, image)`` if requested via ``ret_cal=True`` / ``ret_image=True``. """ out = _nlinv_calib( np.asarray(kspace), cal_width=cal_width, shape=tuple(shape), coords=np.asarray(coords), ret_cal=ret_cal, ret_image=ret_image, **kwargs, ) if isinstance(out, tuple): return tuple(np.asarray(o) for o in out) return np.asarray(out) # =========================================================================== # Coil compression — passthrough to pygrog.utils.coil_compress # =========================================================================== def coil_compress( kspace, n_coils, *, traj=None, krad_thresh: float | None = None, ): """Coil-compress non-Cartesian k-space (numpy I/O). Thin wrapper around :func:`pygrog.utils.coil_compress`. Parameters ---------- kspace : np.ndarray Multi-coil k-space, shape ``(n_coils, n_samples)``. n_coils : int | float Number of virtual coils (int) or energy threshold (float in ``(0, 1]``). traj : np.ndarray, optional Sampling trajectory ``(n_samples, ndim)`` for radius-based calibration extraction. krad_thresh : float, optional Relative k-space radius threshold for calibration selection. Returns ------- compressed : np.ndarray Compressed k-space ``(n_virtual, n_samples)``. matrix : np.ndarray Compression matrix ``(n_virtual, n_coils)``. """ return _coil_compress( np.asarray(kspace), n_coils, traj=None if traj is None else np.asarray(traj), krad_thresh=krad_thresh, )