"""Masked FFT operator for gridded (Cartesian / undersampled) k-space.
Provides forward (gridded k-space → image) and adjoint (image → gridded
k-space) transforms for k-space data that have already been placed on a
Cartesian oversampled grid, i.e. the output of
:meth:`~pygrog.calib.GrogInterpolator.interpolate` with ``grid=True``.
Unlike :class:`~pygrog.operator.SparseFFT`, which works with *sparse*
non-Cartesian samples via scatter/gather, this operator applies a binary (or
density-weighted) mask directly on the oversampled k-space grid, making the
forward/adjoint transforms simple FFT + element-wise multiplication. This
is faster in low-dimensional (2D, 2D+t, multislice) settings where the full
oversampled grid fits comfortably in memory.
Coil combination is performed inside the operator:
- If sensitivity maps are provided, SENSE-style combination (forward) or
expansion (adjoint) is used.
- Otherwise, the output keeps the coil axis intact.
Both paths process data **coil-by-coil** to limit peak memory.
Weighting convention
--------------------
The ``density`` tensor is the sum of squared GROG weights at each grid cell:
``density[j] = Σ_i sqrt_weights[i]² · δ(indices[i] == j)``
The binary ``mask`` is derived from density by thresholding at zero:
``mask[j] = (density[j] > 0).to(real_dtype)``
For the forward operator (gridded k-space → image) the ``mask`` selects which
grid cells carry valid signal; the gridded k-space is expected to have been
density-compensated (multiplied by ``pre_weights`` before scattering) so no
additional weight multiplication is needed here.
For the Toeplitz normal operator the ``density`` grid **is** the PSF in the
oversampled k-space domain, identical to what :class:`~pygrog._toep.GrogToeplitzOp`
computes via scatter.
"""
__all__ = ["MaskedFFT", "MaskedFFTPlan"]
import numpy as np
import torch
from mrinufft._array_compat import with_torch
from .._base._fftc import fft, ifft
from .._utils import resize
from .._solve._mixin import SolveMixin
class MaskedFFTPlan:
"""Plan for :class:`MaskedFFT` — the Cartesian/masked counterpart of
:class:`~pygrog.calib.GrogPlan`.
Returned by :meth:`~pygrog.calib.GrogInterpolator.interpolate` with
``grid=True`` and accepted by :class:`MaskedFFT` via its *plan* argument,
giving the same one-liner workflow as the sparse path::
# Sparse path
sparse = grog.interpolate(kspace)
op = SparseFFT(plan=grog.plan, smaps=smaps)
image = op.adjoint(sparse * grog.plan.pre_weights)
# Dense/grid path — symmetric API
kgrid, plan = grog.interpolate(kspace, grid=True)
op = MaskedFFT(plan=plan, smaps=smaps)
image = op.adjoint(kgrid)
Parameters
----------
grid_shape : tuple[int, ...]
Oversampled Cartesian k-space grid shape (per stack element).
image_shape : tuple[int, ...]
Target image shape (center-crop target).
stack_shape : tuple[int, ...]
Leading stack axes; ``()`` for unstacked data.
mask : torch.Tensor
Real float binary sampling mask, shape ``(*stack_shape, *grid_shape)``.
density : torch.Tensor
Real float density grid (sum of squared GROG weights per cell),
same shape as *mask*. Used as the Toeplitz PSF by
:class:`~pygrog._toep.GrogToeplitzOp`.
"""
def __init__(
self,
grid_shape,
image_shape,
stack_shape,
mask,
density,
*,
coords=None,
dcf=None,
):
self.grid_shape = tuple(int(s) for s in grid_shape)
self.image_shape = tuple(int(s) for s in image_shape)
self.stack_shape = tuple(int(s) for s in stack_shape)
self.mask = torch.as_tensor(mask).float()
self.density = torch.as_tensor(density).float()
# Optional private metadata consumed by gadgets (e.g. off-resonance):
# ``_coords`` is the original non-Cartesian trajectory used to grid the
# data, and ``_dcf`` is the per-sample density compensation function.
# Together they let ORC re-cast the temporal basis on the Cartesian
# grid via an adjoint NUFFT of (readout_time * dcf).
self._coords = None if coords is None else torch.as_tensor(coords)
self._dcf = None if dcf is None else torch.as_tensor(dcf).float()
def __repr__(self):
return (
f"MaskedFFTPlan(grid_shape={self.grid_shape}, "
f"image_shape={self.image_shape}, "
f"stack_shape={self.stack_shape}, "
f"mask={tuple(self.mask.shape)}, "
f"density={tuple(self.density.shape)})"
)
[docs]
class MaskedFFT(SolveMixin):
"""Masked FFT / IFFT operator for gridded k-space data.
Accepts either a pre-built plan (from
:meth:`~pygrog.calib.GrogInterpolator.interpolate` with ``grid=True``)
together with explicit *mask* / *density* tensors, or raw shapes and masks
(for users who already have a Cartesian undersampling pattern and skip
GROG entirely).
Parameters
----------
grid_shape : tuple[int, ...]
Oversampled Cartesian k-space grid, e.g. ``(nz, ny, nx)``.
image_shape : tuple[int, ...]
Target image shape (center-crop), e.g. ``(nz, ny, nx)``.
mask : torch.Tensor
Binary sampling mask, real-valued, shape ``(*stack_shape, *grid_shape)``
or ``(*grid_shape,)`` for unstacked plans. 1 = sampled, 0 = not sampled.
density : torch.Tensor | None, optional
Density-compensation grid (sum of squared GROG weights), real-valued,
same shape as *mask*. Required for
:meth:`~pygrog._toep.GrogToeplitzOp` construction; if ``None`` the
Toeplitz normal operator is disabled.
smaps : torch.Tensor | None, optional
``(n_coils, *image_shape)`` sensitivity maps. *None* → no coil
combination.
stack_shape : tuple[int, ...], optional
Stack prefix, e.g. ``(n_slices,)`` or ``(T, n_slices)``. Inferred
from the leading axes of *mask* that are not part of *grid_shape*
when not provided.
device : str | torch.device | None, optional
Compute device.
toeplitz : bool | None, optional
Use Toeplitz embedding for :meth:`normal`. ``None`` → auto: enabled
on CPU, disabled on CUDA.
plan : object | None, optional
Pre-built plan (ignored except for logging; all required metadata
is passed via the other arguments).
"""
def __init__(
self,
grid_shape=None,
image_shape=None,
mask=None,
density=None,
smaps=None,
stack_shape=None,
device=None,
*,
toeplitz=None,
plan=None,
coords=None,
dcf=None,
):
# --- Accept MaskedFFTPlan or raw arguments -------------------------
_plan_coords = None if coords is None else torch.as_tensor(coords)
_plan_dcf = None if dcf is None else torch.as_tensor(dcf).float()
if plan is not None and isinstance(plan, MaskedFFTPlan):
grid_shape = plan.grid_shape
image_shape = plan.image_shape
mask = plan.mask
density = plan.density
if stack_shape is None:
stack_shape = plan.stack_shape
if _plan_coords is None:
_plan_coords = plan._coords
if _plan_dcf is None:
_plan_dcf = plan._dcf
elif plan is not None:
# Legacy: GrogPlan or other namespace passed as hint — ignore.
pass
if grid_shape is None or image_shape is None or mask is None:
raise ValueError(
"Either 'plan' (MaskedFFTPlan) or explicit 'grid_shape', "
"'image_shape', and 'mask' are required."
)
self.grid_shape = tuple(int(s) for s in grid_shape)
self.image_shape = tuple(int(s) for s in image_shape)
self.ndim = len(self.grid_shape)
self.grid_size = int(np.prod(self.grid_shape))
self.fft_axes = tuple(range(-self.ndim, 0))
# Mask: real float, shape (*stack_shape, *grid_shape) or (*grid_shape,)
mask_t = torch.as_tensor(mask).float()
# Infer stack_shape from leading dims if not supplied.
if stack_shape is None:
n_grid_dims = len(self.grid_shape)
if mask_t.ndim > n_grid_dims:
stack_shape = tuple(int(s) for s in mask_t.shape[:-n_grid_dims])
else:
stack_shape = ()
self.stack_shape = tuple(int(s) for s in stack_shape)
self.mask = mask_t # (*stack_shape, *grid_shape) or (*grid_shape,)
if density is not None:
self.density = torch.as_tensor(density).float()
else:
self.density = None
# Optional private metadata for gadgets (off-resonance correction).
# See :class:`MaskedFFTPlan` for semantics.
self._coords = _plan_coords
self._dcf = _plan_dcf
# natural_shape: for a MaskedFFT the "natural" k-space shape is
# (*grid_shape,) since input data is already gridded. This attribute
# is kept for API parity with SparseFFT so that decorators that read
# base.natural_shape work unchanged.
self.natural_shape = self.grid_shape
self.n_samples = self.grid_size # fictitious; kept for API parity
# Sensitivity maps
if smaps is not None:
self.smaps = torch.as_tensor(smaps)
self._conj_smaps = self.smaps.conj()
else:
self.smaps = None
self._conj_smaps = None
# Pre-compute center-crop slice (adjoint zero-pad bookkeeping)
self._pad_slices = tuple(
slice((gs - is_) // 2, (gs - is_) // 2 + is_)
for gs, is_ in zip(self.grid_shape, self.image_shape, strict=False)
)
self.device = torch.device(device) if device is not None else None
# Toeplitz: auto ON for CPU, OFF for CUDA (mirrors SparseFFT policy).
if toeplitz is None:
target = self.device if self.device is not None else torch.device("cpu")
toeplitz = target.type == "cpu"
# Toeplitz requires a density grid.
if toeplitz and self.density is None:
toeplitz = False
self.toeplitz = bool(toeplitz)
self._toep_op = None
# Stash plan reference (informational only).
self._plan = plan
# ------------------------------------------------------------------
# Stack helper
# ------------------------------------------------------------------
def _stack_mask(self, s_flat_idx: int):
"""Return ``(mask, density)`` for one flattened stack element.
For unstacked operators returns the full mask/density.
"""
if not self.stack_shape:
return self.mask, self.density
S_total = int(np.prod(self.stack_shape))
m = self.mask.reshape(S_total, *self.grid_shape)[s_flat_idx]
d = (
self.density.reshape(S_total, *self.grid_shape)[s_flat_idx]
if self.density is not None
else None
)
return m, d
# ------------------------------------------------------------------
# Forward: gridded k-space → image (adjoint NUFFT direction)
# ------------------------------------------------------------------
[docs]
@with_torch
def adjoint(self, kspace_grid: torch.Tensor) -> torch.Tensor:
"""Gridded k-space to image.
Parameters
----------
kspace_grid : torch.Tensor
Input gridded k-space, shape ``(*B, *S, n_coils, *grid_shape)``
(or without coil axis if smaps are set and caller pre-combined).
Returns
-------
torch.Tensor
``(*B, *S, *image_shape)`` if smaps are set (SENSE-combined),
else ``(*B, *S, n_coils, *image_shape)``.
"""
s_shape = self.stack_shape
s_ndim = len(s_shape)
grid_ndim = self.ndim
# Input trailing axes: (n_coils, *grid_shape)
expected_trailing = 1 + grid_ndim
prefix = tuple(int(s) for s in kspace_grid.shape[:-expected_trailing])
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"kspace_grid prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
n_coils = int(kspace_grid.shape[-grid_ndim - 1])
# Single-frame fast path (no batch, no stack).
if not prefix:
return self._forward_single(kspace_grid, 0)
S_total = int(np.prod(s_shape)) if s_shape else 1
B_total = int(np.prod(B_shape)) if B_shape else 1
flat = kspace_grid.reshape(B_total, S_total, n_coils, *self.grid_shape)
outs = []
for b in range(B_total):
frame_outs = []
for s in range(S_total):
frame_outs.append(self._forward_single(flat[b, s], s))
outs.append(torch.stack(frame_outs, dim=0))
stacked = torch.stack(outs, dim=0) # (B, S, ...)
single_out_shape = (
tuple(self.image_shape)
if self.smaps is not None
else (n_coils, *self.image_shape)
)
return stacked.reshape(*B_shape, *s_shape, *single_out_shape)
def _forward_single(self, kspace_grid: torch.Tensor, s_flat_idx: int = 0):
"""Single-frame forward. Input: ``(n_coils, *grid_shape)``."""
n_coils = int(kspace_grid.shape[0])
src_device = kspace_grid.device
comp_device = self.device if self.device is not None else src_device
dtype = kspace_grid.dtype
mask_s, _ = self._stack_mask(s_flat_idx)
mask_s = mask_s.to(
comp_device, dtype=dtype.to_real() if dtype.is_complex else dtype
)
kgrid = kspace_grid.to(comp_device)
if self.smaps is not None:
conj_smaps = self._conj_smaps.to(comp_device, dtype=dtype)
accum = torch.zeros(self.image_shape, dtype=dtype, device=comp_device)
else:
accum = torch.zeros(
n_coils, *self.image_shape, dtype=dtype, device=comp_device
)
for c in range(n_coils):
# Apply mask, IFFT, center-crop
masked = kgrid[c] * mask_s
full_img = ifft(masked, axes=self.fft_axes)
img_c = resize(full_img, self.image_shape)
if self.smaps is not None:
accum.addcmul_(img_c, conj_smaps[c])
else:
accum[c] = img_c
return accum.to(src_device)
# ------------------------------------------------------------------
# Adjoint: image → gridded k-space (forward NUFFT direction)
# ------------------------------------------------------------------
[docs]
@with_torch
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""Image to gridded k-space.
Parameters
----------
image : torch.Tensor
Input image, shape ``(*B, *S, *image_shape)`` if smaps set,
else ``(*B, *S, n_coils, *image_shape)``.
Returns
-------
torch.Tensor
``(*B, *S, n_coils, *grid_shape)``.
"""
s_shape = self.stack_shape
s_ndim = len(s_shape)
single_ndim = len(self.image_shape) + (0 if self.smaps is not None else 1)
prefix = tuple(int(s) for s in image.shape[: image.ndim - single_ndim])
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"image prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
if not prefix:
return self._adjoint_single(image, 0)
S_total = int(np.prod(s_shape)) if s_shape else 1
B_total = int(np.prod(B_shape)) if B_shape else 1
single_shape = tuple(image.shape[image.ndim - single_ndim :])
flat = image.reshape(B_total, S_total, *single_shape)
outs = []
for b in range(B_total):
frame_outs = []
for s in range(S_total):
frame_outs.append(self._adjoint_single(flat[b, s], s))
outs.append(torch.stack(frame_outs, dim=0))
stacked = torch.stack(outs, dim=0) # (B, S, n_coils, *grid_shape)
if self.smaps is not None:
n_coils = int(self.smaps.shape[0])
else:
n_coils = int(stacked.shape[2])
return stacked.reshape(*B_shape, *s_shape, n_coils, *self.grid_shape)
def _adjoint_single(self, image: torch.Tensor, s_flat_idx: int = 0):
"""Single-frame adjoint. Returns ``(n_coils, *grid_shape)``."""
src_device = image.device
comp_device = self.device if self.device is not None else src_device
dtype = image.dtype
mask_s, _ = self._stack_mask(s_flat_idx)
real_dtype = dtype.to_real() if dtype.is_complex else dtype
mask_s = mask_s.to(comp_device, dtype=real_dtype)
image_d = image.to(comp_device)
if self.smaps is not None:
smaps = self.smaps.to(comp_device, dtype=dtype)
n_coils = int(smaps.shape[0])
else:
n_coils = int(image_d.shape[0])
output = torch.empty(n_coils, *self.grid_shape, dtype=dtype, device=comp_device)
padded = torch.zeros(*self.grid_shape, dtype=dtype, device=comp_device)
for c in range(n_coils):
coil_img = image_d * smaps[c] if self.smaps is not None else image_d[c]
padded.zero_()
padded[self._pad_slices] = coil_img
kgrid = fft(padded, axes=self.fft_axes)
output[c] = kgrid * mask_s
return output.to(src_device)
def __call__(self, x, adjoint=False):
if adjoint:
return self.adjoint(x)
return self.forward(x)
# ------------------------------------------------------------------
# Normal operator: A^H A
# ------------------------------------------------------------------
[docs]
@with_torch
def normal(self, image: torch.Tensor) -> torch.Tensor:
"""Self-adjoint application: ``A^H A image``.
When ``self.toeplitz`` is True (requires ``density`` to be set),
uses a pre-computed PSF on ``grid_shape`` (built lazily on first
call). Otherwise falls back to ``forward(adjoint(image))`` (with
pygrog convention: ``adjoint = A`` (image→kspace),
``forward = A^H`` (kspace→image)).
"""
if self.toeplitz:
if self._toep_op is None:
from .._toep._grog_toep import GrogToeplitzOp
self._toep_op = GrogToeplitzOp(self, device=self.device)
return self._toep_op(image)
return self.adjoint(self.forward(image))
# ------------------------------------------------------------------
# Iterative solve
# ------------------------------------------------------------------
# Provided by SolveMixin (attached at module import time).
# ------------------------------------------------------------------
# Batch helpers for decorators
# ------------------------------------------------------------------
def _mask_ifft_crop_batch(
self,
batch_kspace_grid: torch.Tensor,
s_flat_idx: int = 0,
) -> torch.Tensor:
"""Apply mask, ONE batched IFFT, center-crop.
Parameters
----------
batch_kspace_grid : torch.Tensor
``(B, *grid_shape)`` complex gridded k-space (already density-
compensated).
s_flat_idx : int
Flattened stack-element index.
Returns
-------
torch.Tensor
``(B, *image_shape)`` complex.
"""
B = batch_kspace_grid.shape[0]
src_device = batch_kspace_grid.device
comp_device = self.device if self.device is not None else src_device
dtype = batch_kspace_grid.dtype
mask_s, _ = self._stack_mask(s_flat_idx)
real_dtype = dtype.to_real() if dtype.is_complex else dtype
mask_s = mask_s.to(comp_device, dtype=real_dtype)
kgrid = batch_kspace_grid.to(comp_device) * mask_s.unsqueeze(0)
full_imgs = ifft(kgrid, axes=self.fft_axes) # (B, *grid_shape)
imgs = resize(full_imgs, (B, *self.image_shape))
return imgs.to(src_device)
def _fft_pad_mask_batch(
self,
batch_imgs: torch.Tensor,
s_flat_idx: int = 0,
) -> torch.Tensor:
"""ONE batched FFT, zero-pad, apply mask.
Parameters
----------
batch_imgs : torch.Tensor
``(B, *image_shape)`` complex.
s_flat_idx : int
Flattened stack-element index.
Returns
-------
torch.Tensor
``(B, *grid_shape)`` complex, masked.
"""
B = batch_imgs.shape[0]
src_device = batch_imgs.device
comp_device = self.device if self.device is not None else src_device
dtype = batch_imgs.dtype
mask_s, _ = self._stack_mask(s_flat_idx)
real_dtype = dtype.to_real() if dtype.is_complex else dtype
mask_s = mask_s.to(comp_device, dtype=real_dtype)
imgs_d = batch_imgs.to(comp_device)
padded = torch.zeros(B, *self.grid_shape, dtype=dtype, device=comp_device)
padded[(slice(None), *self._pad_slices)] = imgs_d
kgrid = fft(padded, axes=self.fft_axes) # (B, *grid_shape)
return (kgrid * mask_s.unsqueeze(0)).to(src_device)