# Source code for pygrog.interop._deepinverse

"""deepinv LinearPhysics adapter for SparseFFT.

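Wraps :class:`~pygrog.operator.SparseFFT` as a
``deepinv.physics.LinearPhysics`` so it plugs into deepinv reconstruction
algorithms (FISTA, ADMM, PnP, RED, unrolled networks, …) without
modification.  The module also provides batched-``torch.Tensor`` wrappers
around :class:`pygrog.calib.GrogInterpolator`,
:func:`pygrog.utils.nlinv_calib` and :func:`pygrog.utils.coil_compress`
for use in deepinv pipelines.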

Shape conventions (matching deepinv's ``(B, C, *spatial)`` layout)
==================================================================

The adapter speaks deepinv's native batched shapes at its boundary so it
can be dropped directly into ``optim_builder``, ``Trainer`` and the rest of
the deepinv ecosystem:

* **image**:   ``(B, 1, *image_shape)`` complex.  ``C == 1`` because the
  wrapped operator already coil-combines via ``smaps``.
* **k-space**: ``(B, n_coils, n_samples)`` complex (flattened sparse layout —
  deepinv treats measurements as opaque tensors, so the flat layout is
  the most permissive choice and avoids forcing callers to reshape on
  every call).

Gradients are computed via :mod:`pygrog.interop._torch` — explicit
``torch.autograd.Function`` subclasses whose backward is the adjoint of
the measurement operator — rather than relying on automatic
differentiation through the GROG kernels.

deepinv ``LinearPhysics`` contract:
  - Subclass ``deepinv.physics.LinearPhysics``
  - Implement ``A(x)``         : image  → k-space  (measurement)
  - Implement ``A_adjoint(y)`` : k-space → image   (backprojection)
  - ``A_dagger`` (pseudoinverse) is then provided by the base class.
  - ``A_adjoint_A`` is additionally overridden here to route through the
    wrapped operator's ``normal()`` method.
"""

# Public API of this adapter module (kept alphabetical).
__all__ = [
    "GrogInterpolator",
    "GrogLinearPhysics",
    "coil_compress",
    "nlinv_calib",
]

import numpy as np
import torch

from ..calib import GrogInterpolator as _GrogInterpolatorBase
from ..utils import coil_compress as _coil_compress
from ..utils import nlinv_calib as _nlinv_calib
from ._torch import grog_backproject, grog_measure


class GrogLinearPhysics:
    """Wrap a pygrog operator as a ``deepinv.physics.LinearPhysics``.

    Because deepinv is an optional dependency, the concrete subclass is
    built lazily on first instantiation and cached on the class.  Note that
    the object returned by ``GrogLinearPhysics(op)`` is an instance of that
    lazily built deepinv subclass, not of :class:`GrogLinearPhysics`.

    Parameters
    ----------
    op : SparseFFT-like
        Any operator with ``forward(kspace) -> image`` and
        ``adjoint(image) -> kspace`` methods.  The adapter additionally
        reads ``op.indices`` (its first dimension is taken as the number of
        sparse k-space samples), calls ``op.normal(image)`` in
        ``A_adjoint_A``, and reads ``op.smaps`` when present.
    noise_model : deepinv.physics.NoiseModel or None, optional
        Noise model to attach. Defaults to ``deepinv.physics.ZeroNoise()``.

    Raises
    ------
    ImportError
        If ``deepinv`` is not installed.

    Examples
    --------
    ::

        from pygrog.operator import SparseFFT
        from pygrog.interop import GrogLinearPhysics

        op = SparseFFT(plan=grog.plan, smaps=smaps)
        physics = GrogLinearPhysics(op)

        # x: (B, 1, H, W) complex, y: (B, n_coils, n_samples) complex
        y = physics(x)
        x_hat = physics.A_dagger(y)
    """

    # Concrete deepinv subclass; built on first use and cached at class level.
    _deepinv_class = None

    def __new__(cls, op, noise_model=None):
        if cls._deepinv_class is None:
            cls._deepinv_class = cls._build_class()
        # Returning a non-``cls`` instance means ``GrogLinearPhysics.__init__``
        # is never run; all initialisation happens in the deepinv subclass.
        return cls._deepinv_class(op, noise_model=noise_model)

    @staticmethod
    def _build_class():
        """Import deepinv and build the concrete ``LinearPhysics`` subclass.

        Raises
        ------
        ImportError
            If ``deepinv`` cannot be imported.
        """
        try:
            from deepinv.physics import LinearPhysics, ZeroNoise
        except ImportError as exc:
            raise ImportError(
                "deepinv is required for GrogLinearPhysics. "
                "Install it with: pip install deepinv"
            ) from exc

        class _GrogLinearPhysicsImpl(LinearPhysics):
            """deepinv LinearPhysics wrapping a pygrog SparseFFT-like operator.

            Hides the (n_coils, *natural_shape) sparse-k-space layout behind
            deepinv's ``(B, C, *spatial)`` convention.  Coils live in the
            ``C`` axis on the measurement side; the image side has ``C=1``
            because the wrapped operator coil-combines via ``smaps``.
            """

            def __init__(self, op, noise_model=None):
                # Explicit ``is None`` check: noise models are ``nn.Module``
                # objects and must not be selected by truthiness.
                if noise_model is None:
                    noise_model = ZeroNoise()
                super().__init__(noise_model=noise_model)
                self._op = op
                smaps = getattr(op, "smaps", None)
                # Coil count from the sensitivity maps, if the op has any.
                self._n_coils = int(smaps.shape[0]) if smaps is not None else None
                # Length of the flattened sparse-k-space sample axis.
                self._n_samples = int(op.indices.shape[0])

            # ---- shape adapters ------------------------------------------

            def _strip_image(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
                """``(B, 1, *image)`` → ``(B, *image)`` for the wrapped op.

                Returns the stripped tensor and the batch size ``B``.

                Raises
                ------
                ValueError
                    If ``x`` has fewer than two dims or ``x.shape[1] != 1``.
                """
                if x.ndim < 2 or x.shape[1] != 1:
                    raise ValueError(
                        "GrogLinearPhysics expects images of shape "
                        f"(B, 1, *image_shape); got {tuple(x.shape)}."
                    )
                return x[:, 0], x.shape[0]

            def _wrap_image(self, x: torch.Tensor) -> torch.Tensor:
                """``(B, *image)`` → ``(B, 1, *image)``."""
                return x.unsqueeze(1)

            def _strip_kspace(self, y: torch.Tensor) -> torch.Tensor:
                """``(B, n_coils, n_samples)`` → ``(B, n_coils, n_samples)``.

                A no-op shape check; provided for symmetry with image side.

                Raises
                ------
                ValueError
                    If ``y`` is not 3-D or its last dim is not ``n_samples``.
                """
                if y.ndim != 3 or y.shape[2] != self._n_samples:
                    raise ValueError(
                        "GrogLinearPhysics expects k-space of shape "
                        f"(B, n_coils, {self._n_samples}); got {tuple(y.shape)}."
                    )
                return y

            # ---- LinearPhysics API ---------------------------------------

            def A(self, x: torch.Tensor, **kwargs) -> torch.Tensor:  # noqa: ARG002
                """Forward measurement: image → k-space.

                Parameters
                ----------
                x : torch.Tensor
                    Image of shape ``(B, 1, *image_shape)``.

                Returns
                -------
                torch.Tensor
                    K-space of shape ``(B, n_coils, n_samples)``.
                """
                img, _ = self._strip_image(x)
                ksp = grog_measure(img, self._op)
                # Flatten everything after (B, n_coils) into the sample axis.
                return ksp.reshape(ksp.shape[0], ksp.shape[1], -1)

            def A_adjoint(
                self,
                y: torch.Tensor,
                **kwargs,  # noqa: ARG002
            ) -> torch.Tensor:
                """Backprojection: k-space → image.

                Parameters
                ----------
                y : torch.Tensor
                    K-space of shape ``(B, n_coils, n_samples)``.

                Returns
                -------
                torch.Tensor
                    Image of shape ``(B, 1, *image_shape)``.
                """
                y = self._strip_kspace(y)
                img = grog_backproject(y, self._op)
                return self._wrap_image(img)

            def A_adjoint_A(
                self,
                x: torch.Tensor,
                **kwargs,  # noqa: ARG002
            ) -> torch.Tensor:
                """Self-adjoint operator ``A^H A``.

                Routes through the wrapped operator's ``.normal()`` method,
                one batch item at a time; that method is expected to use the
                Toeplitz embedding when the op was constructed with
                ``toeplitz=True`` (confirm against the wrapped operator).

                NOTE(review): unlike ``A`` / ``A_adjoint``, this path does not
                go through the ``autograd.Function`` wrappers in ``._torch``;
                gradient support depends on ``op.normal`` itself.

                Parameters
                ----------
                x : torch.Tensor
                    Image of shape ``(B, 1, *image_shape)``.

                Returns
                -------
                torch.Tensor
                    Same shape as input.
                """
                img, batch = self._strip_image(x)
                if batch == 0:
                    # torch.stack() rejects an empty list; A^H A preserves
                    # shape, so an empty batch maps to an empty batch.
                    return torch.zeros_like(x)
                out = torch.stack(
                    [self._op.normal(img[b]) for b in range(batch)],
                    dim=0,
                )
                return self._wrap_image(out)

        return _GrogLinearPhysicsImpl
# ===========================================================================
# Shared helpers
# ===========================================================================


def _as_numpy(value):
    """Return *value* as a NumPy array; torch tensors are detached and moved to CPU."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


# ===========================================================================
# GrogInterpolator adapter (deepinv batched-tensor I/O)
# ===========================================================================


class GrogInterpolator(_GrogInterpolatorBase):
    """GROG interpolator with deepinv-style batched torch I/O.

    deepinv has no native non-Cartesian data container; users typically
    pass batched tensors through their custom ``Physics`` object.  This
    adapter accepts and returns ``(B, ...)`` tensors so it composes
    naturally with deepinv pipelines.

    Parameters
    ----------
    coords : np.ndarray | torch.Tensor, shape ``(*spatial, ndim)``
        Trajectory coordinates in pygrog scale.  Torch tensors are
        detached and copied to a CPU NumPy array before being passed to
        the base class.
    shape : int | tuple[int, ...]
        Cartesian image shape spanned by the trajectory.
    image_shape, kernel_width, oversamp, kernel_shape, time_map
        Forwarded to :class:`pygrog.calib.GrogInterpolator`.
    """

    def __init__(
        self,
        coords,
        shape,
        *,
        image_shape=None,
        kernel_width: int = 2,
        oversamp: float | list | tuple | None = None,
        kernel_shape: str = "circle",
        time_map=None,
    ):
        super().__init__(
            shape=shape,
            coords=_as_numpy(coords),
            oversamp=oversamp,
            kernel_width=kernel_width,
            kernel_shape=kernel_shape,
            time_map=time_map,
            image_shape=image_shape,
        )

    def interpolate(self, kspace, *, return_plan: bool = True, **kwargs):
        """Interpolate batched non-Cartesian k-space.

        Parameters
        ----------
        kspace : torch.Tensor
            ``(B, n_coils, *spatial)`` complex k-space.
        return_plan : bool, optional
            Also return ``self.plan`` (default ``True``).
        **kwargs
            Forwarded to :meth:`pygrog.calib.GrogInterpolator.interpolate`.
            If ``grid=True`` is among them, the gridded result is returned
            (see below).

        Returns
        -------
        sparse : torch.Tensor
            ``(B, n_coils, *spatial, kw)`` interpolated k-space (in pygrog
            "natural" layout — the kw axis is exposed because deepinv has
            no notion of ``k0``).
        plan : pygrog.calib.GrogPlan, optional
            Returned when ``return_plan=True`` (default).

        With ``grid=True`` the result is instead ``(grid_kspace,
        masked_plan)`` as produced by the base class (``grid_kspace``
        converted to a tensor), followed by ``self.plan`` when
        ``return_plan=True``.

        Raises
        ------
        TypeError
            If *kspace* is not a ``torch.Tensor``.
        ValueError
            If *kspace* has fewer than three dimensions.
        """
        if not isinstance(kspace, torch.Tensor):
            raise TypeError("GrogInterpolator (deepinv) expects a torch.Tensor.")
        if kspace.ndim < 3:
            raise ValueError(
                "Expected (B, n_coils, *spatial) tensor; got shape "
                f"{tuple(kspace.shape)}."
            )

        out = super().interpolate(kspace, **kwargs)

        if kwargs.get("grid", False):
            grid_kspace, masked_plan = out
            grid_kspace = torch.as_tensor(grid_kspace)
            if return_plan:
                return grid_kspace, masked_plan, self.plan
            return grid_kspace, masked_plan

        out_t = torch.as_tensor(out)
        # Reshape from flat (*batch, C, n_samples) → (*batch, C, *natural_shape)
        out_t = out_t.reshape(*out_t.shape[:-1], *self.plan.natural_shape)
        if return_plan:
            return out_t, self.plan
        return out_t


# ===========================================================================
# NLINV calibration — pygrog backend (preferred over any framework NLINV)
# ===========================================================================


def nlinv_calib(
    kspace,
    coords,
    shape,
    *,
    cal_width: int = 24,
    ret_cal: bool = False,
    ret_image: bool = False,
    **kwargs,
):
    """Estimate coil sensitivities, returned in deepinv batched layout.

    Parameters
    ----------
    kspace : torch.Tensor | array-like
        ``(B, n_coils, n_samples)`` or ``(n_coils, n_samples)`` k-space.
        Torch tensors are passed to the backend unchanged; other inputs
        are converted with :func:`numpy.asarray`.
    coords : torch.Tensor | np.ndarray
        Trajectory ``(n_samples, ndim)`` in pygrog scale.
    shape : tuple[int, ...]
        Image shape.
    cal_width : int
        Calibration patch width.
    ret_cal, ret_image : bool
        Forwarded to :func:`pygrog.utils.nlinv_calib` to request the
        corresponding extra outputs.
    **kwargs
        Forwarded to :func:`pygrog.utils.nlinv_calib`.

    Returns
    -------
    smaps : torch.Tensor
        ``(B, n_coils, *shape)`` coil sensitivities (deepinv "B, C=coils"
        layout).  If the backend returns a single-frame
        ``(n_coils, *shape)`` result, a leading ``B=1`` axis is added.
    *extras : torch.Tensor, optional
        Additional outputs from ``nlinv_calib``, converted to tensors but
        otherwise left in the backend's layout.
    """
    # Batched or not, torch input goes to the backend as-is.
    ks_in = kspace if isinstance(kspace, torch.Tensor) else np.asarray(kspace)

    out = _nlinv_calib(
        ks_in,
        cal_width=cal_width,
        shape=tuple(shape),
        coords=_as_numpy(coords),
        ret_cal=ret_cal,
        ret_image=ret_image,
        **kwargs,
    )
    if not isinstance(out, tuple):
        out = (out,)

    smaps = torch.as_tensor(out[0])
    if smaps.ndim == 1 + len(shape):  # single-frame result, add B=1
        smaps = smaps.unsqueeze(0)

    extras = tuple(torch.as_tensor(e) for e in out[1:])
    if extras:
        return (smaps, *extras)
    return smaps


# ===========================================================================
# Coil compression — pygrog backend
# ===========================================================================


def coil_compress(
    kspace,
    n_coils,
    *,
    traj=None,
    krad_thresh: float | None = None,
):
    """Coil-compress batched non-Cartesian k-space.

    Parameters
    ----------
    kspace : torch.Tensor
        ``(B, n_coils, n_samples)`` or ``(n_coils, n_samples)`` k-space.
        A 3-D torch input must have ``B == 1``; the batch axis is removed
        for the backend call and restored on the result.
    n_coils : int | float
        Target virtual-coil count or energy threshold.
    traj : torch.Tensor | np.ndarray, optional
        Trajectory for radius-based calibration extraction.
    krad_thresh : float, optional
        Relative k-space radius threshold.

    Returns
    -------
    compressed : torch.Tensor
        Same leading shape as input, with reduced coil dimension.
    matrix : torch.Tensor
        Compression matrix ``(n_virtual, n_coils)``.

    Raises
    ------
    ValueError
        If *kspace* is a 3-D torch tensor whose batch size is not 1
        (including an empty batch).
    """
    has_batch = isinstance(kspace, torch.Tensor) and kspace.ndim == 3
    if has_batch:
        # Only a single-item batch is supported; an empty batch is rejected
        # here rather than being handed to the backend as a 3-D tensor.
        if kspace.shape[0] != 1:
            raise ValueError(
                f"coil_compress expects batch size 1; got shape {tuple(kspace.shape)}."
            )
        ks_in = kspace[0]
    else:
        ks_in = kspace

    compressed, matrix = _coil_compress(
        ks_in,
        n_coils,
        traj=None if traj is None else _as_numpy(traj),
        krad_thresh=krad_thresh,
    )
    compressed_t = torch.as_tensor(compressed)
    matrix_t = torch.as_tensor(matrix)
    if has_batch:
        compressed_t = compressed_t.unsqueeze(0)
    return compressed_t, matrix_t