Source code for pygrog.interop._mrpro

"""mrpro LinearOperator adapter for SparseFFT.

Wraps :class:`~pygrog.operator.SparseFFT` (or any pygrog operator with
``forward`` / ``adjoint`` methods) as an ``mrpro.operators.LinearOperator``
so it plugs into mrpro reconstruction pipelines and algorithms (PGD, CG,
PDHG, …) without modification.

Shape conventions (matching mrpro's k-space layout
``(..., coils, k2, k1, k0)``)
==========================================================================

The adapter speaks mrpro's native shapes at its boundary and handles the
GROG kernel-width axis (``kw``) internally so callers never need to
introduce a manual rearrangement.  For an operator whose underlying GROG
plan has ``natural_shape == (*trajectory_shape, kw)``:

* **k-space**: ``(*other, n_coils, *trajectory_shape[:-1], trajectory_shape[-1] * kw)``
  i.e. the kernel-width axis is fused into the last trajectory axis (k0).
* **image**:   ``(*other, *image_shape)`` — no coil axis when the wrapped
  operator carries coil sensitivities (smaps), since coils are already
  combined.

Gradients are computed via :mod:`pygrog.interop._torch` — explicit
``torch.autograd.Function`` subclasses that are also compatible with
``torch.func.grad`` / ``vmap`` (used internally by ``mrpro.algorithms``).
``adjoint_as_backward`` is intentionally *not* used.

mrpro ``LinearOperator`` contract (this adapter's convention):
  - ``forward(kspace) -> (image,)``  — backprojection (A^H)
  - ``adjoint(image)  -> (kspace,)`` — measurement   (A)
"""

__all__ = ["GrogInterpolator", "GrogLinearOp", "coil_compress", "nlinv_calib"]

import numpy as np
import torch

from ..calib import GrogInterpolator as _GrogInterpolatorBase
from ..utils import nlinv_calib as _nlinv_calib
from ._torch import grog_backproject, grog_measure


# ---------------------------------------------------------------------------
# KData ↔ pygrog field extraction helpers
# ---------------------------------------------------------------------------
def _kdata_extract(kdata):
    """Pull coords / k-space / shapes out of a mrpro ``KData``.

    Returns
    -------
    coords : np.ndarray, shape ``(*spatial, ndim)``
        Trajectory coordinates in pygrog scale (``[-shape/2, shape/2]``,
        where ``shape == encoding_matrix``).  ``*spatial`` matches the
        trajectory's ``(k2, k1, k0)`` axes (broadcast to dense if needed).
    data : torch.Tensor, shape ``(*other, n_coils, *spatial)``
        K-space data, complex.  ``*spatial`` matches *coords*.
    enc_shape : tuple[int, ...]
        Encoding-matrix shape in pygrog ordering (``(y, x)`` for 2D,
        ``(z, y, x)`` for 3D).
    recon_shape : tuple[int, ...]
        Recon-matrix shape (image FFT crop target).
    """
    traj = kdata.traj
    # kx/ky/kz: (*other, 1, k2, k1, k0).  Drop the coil-singleton axis.
    kx = traj.kx.squeeze(-4)
    ky = traj.ky.squeeze(-4)
    kz = traj.kz.squeeze(-4)

    # 2D vs 3D detection: if kz is identically zero and k2==1, treat as 2D.
    k2 = kdata.data.shape[-3]
    is_2d = (k2 == 1) and bool(torch.all(kz == 0).item())

    # Broadcast traj components to the data's spatial shape (k2, k1, k0).
    spatial = tuple(int(s) for s in kdata.data.shape[-3:])

    # Each component may have leading "other"/singleton dims; collapse to spatial.
    def _to_spatial(t):
        # Broadcast against an all-ones tensor of shape (*spatial,) so any
        # leading repetition or singleton broadcasts cleanly.
        if t.ndim < 3:
            t = t.reshape((1,) * (3 - t.ndim) + tuple(t.shape))
        # Take the trailing 3 dims and broadcast to spatial.
        target = torch.broadcast_to(t, t.shape[:-3] + spatial)
        # If there are leading "other" dims, take the first slice — the
        # trajectory is assumed identical across "other".
        while target.ndim > 3:
            target = target[0]
        return target

    kx_s = _to_spatial(kx)
    ky_s = _to_spatial(ky)
    kz_s = _to_spatial(kz)

    if is_2d:
        # 2D: (k1, k0, 2) in (y, x) order; drop the singleton k2 axis.
        coords = torch.stack([ky_s[0], kx_s[0]], dim=-1)
        enc_shape = (
            int(kdata.header.encoding_matrix.y),
            int(kdata.header.encoding_matrix.x),
        )
        recon_shape = (
            int(kdata.header.recon_matrix.y),
            int(kdata.header.recon_matrix.x),
        )
    else:
        coords = torch.stack([kz_s, ky_s, kx_s], dim=-1)
        enc_shape = (
            int(kdata.header.encoding_matrix.z),
            int(kdata.header.encoding_matrix.y),
            int(kdata.header.encoding_matrix.x),
        )
        recon_shape = (
            int(kdata.header.recon_matrix.z),
            int(kdata.header.recon_matrix.y),
            int(kdata.header.recon_matrix.x),
        )

    return (
        coords.detach().cpu().numpy().astype(np.float32),
        kdata.data,
        enc_shape,
        recon_shape,
    )


def _data_to_spatial(data, ndim):
    """Reshape KData ``(..., coils, k2, k1, k0)`` for pygrog interpolate.

    For 2D (``ndim==2``), drops the singleton k2 axis →
    ``(..., coils, k1, k0)``.  For 3D, leaves shape unchanged.
    """
    if ndim == 2:
        # k2 (axis -3) must be 1.
        assert data.shape[-3] == 1, (
            f"2D interpolation expects k2 == 1; got {tuple(data.shape)}"
        )
        new_shape = data.shape[:-3] + data.shape[-2:]
        return data.reshape(new_shape)
    return data


def _spatial_to_kdata(arr, ndim):
    """Inverse of :func:`_data_to_spatial`: re-insert the singleton k2 axis."""
    if ndim == 2:
        new_shape = (*arr.shape[:-2], 1, *arr.shape[-2:])
        return arr.reshape(new_shape)
    return arr


[docs] class GrogLinearOp: """Wrap a pygrog operator as an ``mrpro.operators.LinearOperator``. Because mrpro is an optional dependency, the class is built lazily on first instantiation so importing this module does not fail when mrpro is absent. Parameters ---------- op : SparseFFT-like Any operator with ``forward(kspace) -> image`` and ``adjoint(image) -> kspace`` methods, plus a ``natural_shape`` attribute whose last axis is the GROG kernel width. Raises ------ ImportError If ``mrpro`` is not installed. Examples -------- :: from pygrog.operator import SparseFFT from pygrog.interop import GrogLinearOp from mrpro.algorithms.optimizers import pgd base = SparseFFT(plan=grog.plan, smaps=smaps) mrpro_op = GrogLinearOp(base) # mrpro_op now consumes/produces standard # ``(*other, n_coils, k2, k1, k0)`` k-space and ``(*other, z, y, x)`` # images, ready for any mrpro algorithm. """ _mrpro_class = None # cached at class level def __new__(cls, op): if cls._mrpro_class is None: cls._mrpro_class = cls._build_mrpro_class() return cls._mrpro_class(op) @staticmethod def _build_mrpro_class(): try: from mrpro.operators import LinearOperator except ImportError as exc: raise ImportError( "mrpro is required for GrogLinearOp. " "Install it with: pip install mrpro" ) from exc class _GrogLinearOpImpl(LinearOperator): """mrpro LinearOperator wrapping a pygrog SparseFFT-like operator. Hides the GROG kernel-width axis by fusing it into the last trajectory axis (k0), so the public shape contract matches mrpro's standard ``(..., coils, k2, k1, k0)`` k-space layout. """ def __init__(self, op): super().__init__() self._op = op # natural_shape == (*trajectory_shape, kw). self._nat_shape = tuple(int(s) for s in op.natural_shape) if len(self._nat_shape) >= 2: *traj, kw = self._nat_shape self._mrpro_kshape = (*traj[:-1], traj[-1] * kw) else: # 1-D natural shape: kw is the only axis and is left as-is. self._mrpro_kshape = self._nat_shape def _to_natural(self, x: torch.Tensor) -> torch.Tensor: """Public mrpro k-shape → internal natural shape.""" lead = x.shape[: -len(self._mrpro_kshape)] return x.reshape(*lead, *self._nat_shape) def _to_mrpro(self, x: torch.Tensor) -> torch.Tensor: """Internal natural shape → public mrpro k-shape.""" lead = x.shape[: -len(self._nat_shape)] return x.reshape(*lead, *self._mrpro_kshape) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Backprojection: k-space → image (A^H). Parameters ---------- x : torch.Tensor K-space tensor with shape ``(*other, n_coils, *trajectory_shape[:-1], trajectory_shape[-1] * kw)``. Returns ------- tuple of torch.Tensor Image tensor with shape ``(*other, *image_shape)``. """ x = self._to_natural(x) return (grog_backproject(x, self._op),) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Measurement: image → k-space (A). Parameters ---------- x : torch.Tensor Image tensor with shape ``(*other, *image_shape)``. Returns ------- tuple of torch.Tensor K-space tensor in mrpro layout ``(*other, n_coils, *trajectory_shape[:-1], trajectory_shape[-1] * kw)``. """ ksp = grog_measure(x, self._op) return (self._to_mrpro(ksp),) @property def H(self): """Adjoint LinearOperator with Toeplitz-aware ``.gram``. In mrpro convention, ``self.forward`` is the backprojection (k→image) so ``self.H`` represents the acquisition (image→k). We override the adjoint's ``gram`` (which mrpro defines as ``self.H @ self``) to short-circuit ``self.H @ self.H.H = adjoint @ forward = pygrog op.normal`` (image → image), enabling Toeplitz acceleration when the underlying op has ``toeplitz=True``. """ outer = self return _make_acq_op(outer) return _GrogLinearOpImpl
def _make_acq_op(outer): """Build a Toeplitz-aware adjoint LinearOperator wrapping ``outer.H``.""" from mrpro.operators import LinearOperator class _GrogAcqOp(LinearOperator): """Acquisition op (image→k) with Toeplitz-accelerated ``gram``.""" def __init__(self): super().__init__() self._outer = outer def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: # forward of acquisition = adjoint of original = pygrog adjoint return self._outer.adjoint(x) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: return self._outer.forward(x) @property def gram(self): """``self.H @ self`` short-circuit using ``op.normal``.""" outer_ref = self._outer return _make_image_normal_op(outer_ref) return _GrogAcqOp() def _make_image_normal_op(outer): """Self-adjoint LinearOperator computing ``op.normal`` on images.""" from mrpro.operators import LinearOperator class _GrogImageNormalOp(LinearOperator): """Image-domain ``A^H A`` via ``pygrog op.normal``.""" def __init__(self): super().__init__() self._outer = outer def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: return (self._outer._op.normal(x),) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: # Self-adjoint. return (self._outer._op.normal(x),) @property def gram(self): return self # already a normal operator @property def H(self): return self # self-adjoint return _GrogImageNormalOp() # =========================================================================== # GrogInterpolator adapter (KData I/O) # =========================================================================== class GrogInterpolator(_GrogInterpolatorBase): """GROG interpolator with native :class:`mrpro.data.KData` I/O. Extracts the k-space trajectory and encoding/recon matrix sizes from ``kdata`` and configures the underlying pygrog :class:`pygrog.calib.GrogInterpolator` accordingly. Calibration may be supplied as a calibration ``KData`` (for instance, a Cartesian centre block) or as a raw numpy/tensor patch. Parameters ---------- kdata : mrpro.data.KData Source k-space whose trajectory drives the GROG plan. Header ``encoding_matrix`` defines the Cartesian grid the trajectory spans; ``recon_matrix`` is the post-FFT crop. kernel_width, oversamp, kernel_shape, time_map Forwarded to :class:`pygrog.calib.GrogInterpolator`. Notes ----- The result of :meth:`interpolate` is a fresh ``KData`` whose ``data`` carries the sparse Cartesian k-space (the GROG kernel-width axis is fused into ``k0``, matching :class:`GrogLinearOp`'s convention) and whose ``traj`` is updated to the gridded sample locations. The companion :class:`pygrog.calib.GrogPlan` is returned alongside so callers can build a downstream :class:`~pygrog.operator.SparseFFT` / :class:`GrogLinearOp` directly. """ def __init__( self, kdata, *, kernel_width: int = 2, oversamp: float | list | tuple | None = None, kernel_shape: str = "circle", time_map=None, ): coords_np, _, enc_shape, recon_shape = _kdata_extract(kdata) super().__init__( shape=enc_shape, coords=coords_np, oversamp=oversamp, kernel_width=kernel_width, kernel_shape=kernel_shape, time_map=time_map, image_shape=recon_shape, ) self._enc_shape = enc_shape self._recon_shape = recon_shape self._ndim = len(enc_shape) # ------------------------------------------------------------------ def calc_interp_table(self, calib, *, lamda: float = 0.01, precision: int = 1): """Fit the GRAPPA kernel table. Parameters ---------- calib : KData | np.ndarray | torch.Tensor Fully-sampled calibration region. When a ``KData`` is given, its ``data`` tensor is used; the leading ``other`` dims are squeezed and any singleton k2 axis is dropped for 2D plans. """ try: from mrpro.data import KData as _KData except ImportError: _KData = None if _KData is not None and isinstance(calib, _KData): arr = calib.data # Squeeze leading "other" dims (assumed all size 1 for calibration). while arr.ndim > 3 + 1: # coils + spatial(3) if arr.shape[0] != 1: raise ValueError( f"Calibration KData has non-singleton 'other' dims; " f"got data shape {tuple(arr.shape)}." ) arr = arr[0] if self._ndim == 2: # (coils, 1, k1, k0) → (coils, k1, k0) arr = arr.squeeze(-3) calib_arr = arr.detach().cpu().numpy() else: calib_arr = calib super().calc_interp_table(calib_arr, lamda=lamda, precision=precision) # ------------------------------------------------------------------ def interpolate(self, kdata, *, return_plan: bool = True, grid: bool = False): """GROG-interpolate ``kdata`` onto the sparse Cartesian grid. Parameters ---------- kdata : mrpro.data.KData K-space data sharing the trajectory/encoding used at construction. return_plan : bool, optional If ``True`` (default) return the plan alongside the output. grid : bool, optional If ``True``, scatter the interpolated samples onto a dense oversampled Cartesian grid and return ``(KData, mask, density[, plan])``. The returned ``KData`` carries a regular Cartesian ``KTrajectory`` and has shape ``(*other, n_coils, k2, k1, k0)`` matching ``grid_shape`` (``k2=1`` for 2D plans). ``mask`` and ``density`` have shape ``(*stack, *grid_shape)`` and can be passed directly to :class:`~pygrog.operator.MaskedFFT`. Returns ------- KData (or (KData, GrogPlan)) When ``grid=False`` (default): ``data`` shape ``(*other, n_coils, k2', k1', k0')`` with the GROG kernel width fused into ``k0'``; ``traj`` updated to the gridded sample positions; ``header`` carried over. (KData, mask, density[, plan]) When ``grid=True``: dense Cartesian KData with Cartesian trajectory, plus real-valued ``mask`` and ``density`` tensors. """ from mrpro.data import KData, KTrajectory _coords_np, data_t, enc_shape, _ = _kdata_extract(kdata) if enc_shape != self._enc_shape: raise ValueError( f"KData encoding_matrix {enc_shape} does not match " f"interpolator {self._enc_shape}." ) # Drop singleton k2 for 2D plans before calling the base interpolate. data_for_interp = _data_to_spatial(data_t, self._ndim) if grid: out = super().interpolate(data_for_interp, grid=True) grid_kspace, mask, density = (torch.as_tensor(t) for t in out) # grid_kspace: (*batch, *stack, C, *grid_shape) # Re-insert singleton k2 for 2D → (*batch, *stack, C, 1, gy, gx) grid_data = _spatial_to_kdata(grid_kspace, self._ndim) # Build a regular Cartesian KTrajectory covering the oversampled grid. plan = self.plan grid_shape = tuple(int(s) for s in plan.grid_shape) traj_1d = [] for d, gs in enumerate(grid_shape): origin = float(-(enc_shape[d] // 2)) step = (enc_shape[d] - 1) / (gs - 1) if gs > 1 else 1.0 traj_1d.append(origin + torch.arange(gs, dtype=torch.float32) * step) # meshgrid → each tensor has shape (*grid_shape) grids = torch.meshgrid(*traj_1d, indexing="ij") if self._ndim == 2: # 2D: grids[0]=ky, grids[1]=kx, each (gy, gx) # mrpro KTrajectory layout: (*other, coil_singleton=1, k2, k1, k0) kz_new = torch.zeros(1, 1, 1, *grid_shape, dtype=torch.float32) ky_new = ( grids[0].unsqueeze(0).unsqueeze(0).unsqueeze(0) ) # (1,1,1,gy,gx) kx_new = grids[1].unsqueeze(0).unsqueeze(0).unsqueeze(0) else: # 3D: grids[0]=kz, grids[1]=ky, grids[2]=kx, each (gz,gy,gx) kz_new = grids[0].unsqueeze(0).unsqueeze(0) # (1,1,gz,gy,gx) ky_new = grids[1].unsqueeze(0).unsqueeze(0) kx_new = grids[2].unsqueeze(0).unsqueeze(0) new_traj = KTrajectory(kz=kz_new, ky=ky_new, kx=kx_new) new_kdata = KData(header=kdata.header, data=grid_data, traj=new_traj) if return_plan: return new_kdata, mask, density, plan return new_kdata, mask, density sparse = super().interpolate(data_for_interp) # Reshape from flat (*other, coils, n_samples) → (*other, coils, *spatial, kw) sparse = torch.as_tensor(sparse) sparse = sparse.reshape(*sparse.shape[:-1], *self.plan.natural_shape) # Fuse kw into the last spatial axis to match mrpro layout. *_lead, n_coils = ( *sparse.shape[: -(self._ndim + 1)], sparse.shape[-(self._ndim + 1)], ) spatial_kw = sparse.shape[-(self._ndim + 1) + 1 :] # (*spatial, kw) spatial = spatial_kw[:-1] kw = int(spatial_kw[-1]) # Flatten kw into the last spatial axis (k0 in mrpro layout). new_spatial = (*spatial[:-1], spatial[-1] * kw) new_data_shape = (*sparse.shape[: -(self._ndim + 1)], n_coils, *new_spatial) new_data = sparse.reshape(new_data_shape) # Re-insert singleton k2 for 2D so layout becomes (*other, coils, 1, k1, k0*kw). new_data = _spatial_to_kdata(new_data, self._ndim) # Build a new trajectory whose kx/ky/kz reflect the (oversampled) # Cartesian grid points each replicated GROG sample lands on. plan = self.plan target_idx = torch.as_tensor(plan.target_idx).to(torch.int64) # (*spatial, kw) grid_shape = tuple(int(s) for s in plan.grid_shape) # Decode flat indices into per-axis grid coordinates. idx = target_idx.clamp_min(0) coords_grid = [] rem = idx for d, gs in enumerate(grid_shape): stride = int(np.prod(grid_shape[d + 1 :])) if d + 1 < len(grid_shape) else 1 coords_grid.append((rem // stride) % gs) rem = rem - (rem // stride) * stride if stride > 1 else rem # Convert grid index → physical coords in pygrog scale [-shape/2, shape/2). # grid_steps_d = (shape_d - 1) / (grid_shape_d - 1) for each axis. phys = [] for d, gs in enumerate(grid_shape): origin = -(enc_shape[d] // 2) step = (enc_shape[d] - 1) / (gs - 1) if gs > 1 else 1.0 phys.append(origin + coords_grid[d].float() * step) # phys is list of tensors shape (*spatial, kw); stack & flatten kw into last axis. phys_stack = [ _spatial_to_kdata( p.reshape((*p.shape[:-2], p.shape[-2] * p.shape[-1])), self._ndim ) for p in phys ] # Each phys_stack[d] now has shape matching mrpro layout (k2', k1', k0'). # Reshape into KTrajectory's expected (*other=1, coil=1, k2', k1', k0'). def _to_traj(t): t = t.unsqueeze(0).unsqueeze(0) # (1, 1, k2', k1', k0') return t if self._ndim == 2: ky_new = _to_traj(phys_stack[0]) kx_new = _to_traj(phys_stack[1]) kz_new = torch.zeros_like(kx_new) else: kz_new = _to_traj(phys_stack[0]) ky_new = _to_traj(phys_stack[1]) kx_new = _to_traj(phys_stack[2]) new_traj = KTrajectory(kz=kz_new, ky=ky_new, kx=kx_new) new_kdata = KData(header=kdata.header, data=new_data, traj=new_traj) if return_plan: return new_kdata, plan return new_kdata # =========================================================================== # NLINV calibration (always pygrog backend — preferred over framework NLINV) # =========================================================================== def nlinv_calib( kdata, *, cal_width: int = 24, ret_cal: bool = False, ret_image: bool = False, **kwargs, ): """Estimate coil sensitivities for a ``KData`` via pygrog NLINV. Uses :func:`pygrog.utils.nlinv_calib` (preferred over any framework's own NLINV) and reshapes the result into mrpro's smap layout ``(n_coils, z, y, x)`` (broadcasting trivially against the ``(*other, n_coils, z, y, x)`` :class:`mrpro.data.IData`/ :class:`mrpro.operators.SensitivityOp` convention). Parameters ---------- kdata : mrpro.data.KData Source k-space. cal_width : int Calibration patch width. ret_cal, ret_image Forwarded to :func:`pygrog.utils.nlinv_calib`. **kwargs Additional keyword arguments forwarded to :func:`pygrog.utils.nlinv_calib`. Returns ------- smaps : torch.Tensor Coil sensitivities, shape ``(n_coils, z, y, x)`` (z=1 for 2D). *extras : optional Additional outputs from ``nlinv_calib`` (calibration k-space and/or reconstructed image), if requested. """ coords_np, data_t, enc_shape, recon_shape = _kdata_extract(kdata) ndim = len(enc_shape) # Allow leading *other dims (batch). Squeeze leading singleton other # axes so single-frame KData (the historic case) returns a result with # no leading batch dim. Trajectory is assumed shared across *other. arr = data_t # (*other, n_coils, k2, k1, k0) while arr.ndim > 4 and arr.shape[0] == 1: arr = arr[0] other_shape = tuple(int(s) for s in arr.shape[:-4]) if ndim == 2: # Drop singleton k2 → (*other, n_coils, k1, k0). arr_sf = arr.squeeze(-3) coords_np = coords_np.reshape(-1, 2) y_in = arr_sf.reshape(*other_shape, arr_sf.shape[-3], -1) else: coords_np = coords_np.reshape(-1, 3) y_in = arr.reshape(*other_shape, arr.shape[-4], -1) out = _nlinv_calib( y_in, cal_width=cal_width, shape=recon_shape, coords=coords_np, ret_cal=ret_cal, ret_image=ret_image, **kwargs, ) if not isinstance(out, tuple): out = (out,) smaps = out[0] smaps_t = torch.as_tensor(smaps) if ndim == 2: # (..., n_coils, y, x) → (..., n_coils, 1, y, x) smaps_t = smaps_t.unsqueeze(-3) extras = out[1:] if extras: return (smaps_t, *extras) return smaps_t # =========================================================================== # Coil compression — dispatch to mrpro's native KData.compress_coils # =========================================================================== def coil_compress(kdata, n_coils: int, *, batch_dims=None, joint_dims=...): """Coil-compress a ``KData`` via mrpro's native PCA compression. Thin dispatch to :meth:`mrpro.data.KData.compress_coils`; provided so callers using pygrog adapters across all three frameworks have a uniform API surface. Parameters ---------- kdata : mrpro.data.KData Source k-space. n_coils : int Number of virtual coils to retain. batch_dims, joint_dims Forwarded to :meth:`KData.compress_coils`. Returns ------- KData K-space with reduced coil dimension. """ return kdata.compress_coils(n_coils, batch_dims=batch_dims, joint_dims=joint_dims)