Source code for pygrog.operator._sparse_fft

"""Sparse FFT operator for GROG-gridded data.

Provides forward (sparse-to-dense / adjoint NUFFT) and adjoint (dense-to-sparse
/ forward NUFFT) transforms using pre-computed GROG plan metadata (indices,
weights, grid/image shapes).

Coil combination is performed inside the operator:

- If sensitivity maps are provided, SENSE-style combination (forward) or
  expansion (adjoint) is used.
- Otherwise, root-sum-of-squares combination is applied (forward only;
  adjoint assumes single-channel input).

Both paths process data **coil-by-coil** to limit peak memory.

Weights follow the convention that GROG-interpolated data are pre-multiplied
by ``weights**0.5`` once right after interpolation, and the same ``weights**0.5``
is applied inside both forward and adjoint for orthonormality.

Optional dual-stream GPU pipelining: when ``device`` is set and differs from the
data device, each coil is transferred asynchronously on alternating CUDA streams
while the previous coil's FFT executes concurrently.
"""

__all__ = ["SparseFFT", "gather", "scatter_add"]

import pathlib

import numpy as np
import torch
from mrinufft._array_compat import with_torch

from .._base._fftc import fft, ifft
from .._utils import resize
from .._solve._mixin import SolveMixin

# ---------------------------------------------------------------------------
# Torch C++ extension (lazy, cached)
# ---------------------------------------------------------------------------
_torch_ext = None
_torch_ext_checked = False


def _get_torch_ext():
    """Return the ``_pygrog_torch`` extension module.

    Load priority:
      1. Pre-compiled wheel extension (``pygrog._pygrog_torch``)
      2. JIT compilation via ``torch.utils.cpp_extension.load()``

    Raises
    ------
    RuntimeError
        If neither the pre-built extension nor JIT compilation succeeds.
    """
    global _torch_ext, _torch_ext_checked
    if _torch_ext_checked:
        return _torch_ext

    # 1. try pre-built wheel extension
    try:
        import pygrog._pygrog_torch as _ext

        _torch_ext = _ext
        _torch_ext_checked = True
        return _torch_ext
    except ImportError:
        pass

    # 2. JIT compilation
    jit_error = None
    try:
        import torch.utils.cpp_extension as _cpp_ext

        _HERE = pathlib.Path(__file__).parent.parent.parent.parent.parent
        _CSRC = _HERE / "csrc" / "torch"

        sources = [
            str(_CSRC / "module.cpp"),
            str(_CSRC / "grog_interp.cpp"),
            str(_CSRC / "sparse_ops.cpp"),
            str(_CSRC / "sparse_ops_avx2.cpp"),
            str(_CSRC / "sparse_ops_avx512.cpp"),
        ]
        define_macros = []
        if torch.cuda.is_available() and (_CSRC / "sparse_ops_cuda.cu").exists():
            sources.append(str(_CSRC / "grog_interp_cuda.cu"))
            sources.append(str(_CSRC / "sparse_ops_cuda.cu"))
            define_macros.append(("COMPILE_WITH_CUDA", "1"))

        _torch_ext = _cpp_ext.load(
            name="_pygrog_torch_jit",
            sources=sources,
            extra_cflags=[
                "-O3",
                "-std=c++17",
                "-fopenmp",
                "-march=native",
                "-DPYGROG_MARCH_NATIVE",
            ],
            extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr"],
            extra_ldflags=["-fopenmp"],
            define_macros=define_macros,
            verbose=False,
        )
        _torch_ext_checked = True
        return _torch_ext
    except Exception as exc:
        jit_error = exc

    raise RuntimeError(
        "pygrog requires the _pygrog_torch C++ extension but it could not "
        "be loaded.  Install from a precompiled wheel (`pip install pygrog`) "
        "or build from source with a C++17 compiler "
        "(`pip install --no-build-isolation -e .`).\n"
        f"JIT compilation error: {jit_error}"
    )


# ---------------------------------------------------------------------------
# scatter_add / gather wrappers
# ---------------------------------------------------------------------------
def _scatter_add(grid, data, indices, weights, bin_starts=None, bin_size=0):
    """grid[indices[i]] += weights[i] * data[i], in-place."""
    ext = _get_torch_ext()
    if bin_starts is not None:
        ext.scatter_add_binned(grid, data, indices, weights, bin_starts, bin_size)
    else:
        ext.scatter_add(grid, data, indices, weights)


def _gather(grid, indices, weights):
    """out[i] = weights[i] * grid[indices[i]]."""
    ext = _get_torch_ext()
    return ext.gather(grid, indices, weights)


def scatter_add(
    grid: torch.Tensor,
    data: torch.Tensor,
    indices: torch.Tensor,
    weights: torch.Tensor,
) -> None:
    """Scatter-add: ``grid[indices[i]] += weights[i] * data[i]`` (in-place).

    Parameters
    ----------
    grid : torch.Tensor
        Flat output grid (complex), modified in-place.
    data : torch.Tensor
        Input data values (complex), 1-D.
    indices : torch.Tensor
        Target grid indices (int64), 1-D.
    weights : torch.Tensor
        Per-sample real weights (float), 1-D.
    """
    _scatter_add(grid, data, indices, weights)


def gather(
    grid: torch.Tensor,
    indices: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    """Gather: ``out[i] = weights[i] * grid[indices[i]]``.

    Parameters
    ----------
    grid : torch.Tensor
        Flat input grid (complex), 1-D.
    indices : torch.Tensor
        Source indices (int64), 1-D.
    weights : torch.Tensor
        Per-sample real weights (float), 1-D.

    Returns
    -------
    torch.Tensor
        Gathered values (complex), 1-D.
    """
    return _gather(grid, indices, weights)


# =====================================================================
# SparseFFT
# =====================================================================
[docs] class SparseFFT(SolveMixin): """Sparse FFT / IFFT operator with coil combination. Accepts either a pre-built plan (from :meth:`~pygrog.grog.GrogInterpolator.fft_plan`) or raw arrays. When a plan is provided, sorted indices, sqrt-weights, and permutation arrays are reused directly; otherwise they are computed from the raw ``indices`` / ``weights`` arguments. Parameters ---------- plan : SimpleNamespace | None Pre-built plan from ``GrogInterpolator.fft_plan()``. If given, *grid_shape*, *image_shape*, *indices*, and *weights* are ignored. grid_shape : tuple[int, ...] | None Oversampled Cartesian k-space grid, e.g. ``(nz, ny, nx)``. image_shape : tuple[int, ...] | None Target image shape (center-crop), e.g. ``(nz, ny, nx)``. indices : array-like | None Flat grid indices ``(n_samples,)`` int64. weights : array-like | None Density-compensation weights ``(n_samples,)`` float32. ``sqrt(weights)`` is applied in both directions. smaps : torch.Tensor | None ``(n_coils, *image_shape)`` sensitivity maps. *None* → RSS. device : str | torch.device | None Compute device. When ``'cuda'`` with CPU data, dual-stream pipelining is enabled. toeplitz : bool | None, optional Use Toeplitz embedding (PSF on `grid_shape`) for the self-adjoint operator :meth:`normal`. ``None`` → auto: enabled on CPU, disabled on CUDA (matches :func:`pygrog.utils.nlinv` policy). """ def __init__( self, grid_shape=None, image_shape=None, indices=None, weights=None, smaps=None, device=None, *, plan=None, toeplitz=None, ): # --- Accept plan or raw arguments ---------------------------------- if plan is not None: self.grid_shape = tuple(plan.grid_shape) self.image_shape = tuple(plan.image_shape) self.grid_size = int(plan.grid_size) self.indices = plan.indices self.sqrt_weights = plan.sqrt_weights self.sort_perm = plan.sort_perm self.inv_perm = plan.inv_perm self.natural_shape = tuple( int(s) for s in getattr(plan, "natural_shape", (int(plan.n_samples),)) ) self.stack_shape = tuple( int(s) for s in getattr(plan, "stack_shape", ()) or () ) else: if grid_shape is None or image_shape is None: raise ValueError( "Either 'plan' or both 'grid_shape'/'image_shape' required" ) self.grid_shape = tuple(grid_shape) self.image_shape = tuple(image_shape) self.grid_size = int(np.prod(grid_shape)) idx = torch.as_tensor(indices).ravel().to(torch.int64) w = torch.as_tensor(weights).ravel().to(torch.float32) sqrt_w = torch.sqrt(w) sort_perm = torch.argsort(idx) self.sort_perm = sort_perm self.indices = idx[sort_perm] self.sqrt_weights = sqrt_w[sort_perm] inv_perm = torch.empty_like(sort_perm) inv_perm[sort_perm] = torch.arange(len(sort_perm)) self.inv_perm = inv_perm self.natural_shape = (int(idx.numel()),) self.stack_shape = () # n_samples is the per-stack-element sample count. When stacked, # self.indices has shape (*stack_shape, n_samples); else (n_samples,). self.n_samples = int(self.indices.shape[-1]) self.ndim = len(self.grid_shape) self.fft_axes = tuple(range(-self.ndim, 0)) # GPU binning — computed lazily on first use self._bin_starts = None self._bin_size = 0 # Hot-cell scatter acceleration (CUDA). The k-space sampling # frequently has a few cells (e.g. DC for radial trajectories) # that absorb most samples; doing those via a sum-reduction avoids # severe atomicAdd contention. Computed lazily. self._hot_uidx = None # (n_hot,) int64 grid indices self._hot_starts = None # list[int] starts in sorted sample order self._hot_ends = None # list[int] ends in sorted sample order self._cold_seg_starts = None # list[int] cold range starts self._cold_seg_ends = None # list[int] cold range ends # Packed-stack arrays — built lazily on first stacked op (see # ``_packed_arrays``). Cached per compute device. self._packed_cache = None # Sensitivity maps — pre-compute conjugate (view, free) if smaps is not None: self.smaps = torch.as_tensor(smaps) self._conj_smaps = self.smaps.conj() else: self.smaps = None self._conj_smaps = None # Pre-compute center-slice for adjoint zero-pad self._pad_slices = tuple( slice((gs - is_) // 2, (gs - is_) // 2 + is_) for gs, is_ in zip(self.grid_shape, self.image_shape, strict=False) ) self.device = torch.device(device) if device is not None else None # --- Toeplitz acceleration ----------------------------------------- # Auto-toggle: ON when target compute is CPU, OFF on CUDA (mirrors # `pygrog.utils.nlinv` policy). User can force True/False. if toeplitz is None: target = self.device if self.device is not None else torch.device("cpu") toeplitz = target.type == "cpu" self.toeplitz = bool(toeplitz) self._toep_op = None # lazily built on first .normal() call # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _ensure_bins(self, device): """Compute CUDA scatter acceleration metadata on first use. Two strategies are prepared: * Hot-cell handling — identifies grid cells that receive a large number of samples (typically DC for radial / MRF trajectories, where >65% of samples can map to a single voxel). These cells would otherwise serialize ``atomicAdd`` and dominate the scatter runtime. We instead reduce them via per-segment sums. * Bin layout for the binned-shmem kernel (legacy fallback; only used when no hot cells are detected). """ if device.type != "cuda" or self._bin_starts is not None: return if self.stack_shape: # No bin caching for stacked plans (per-stack bins would need # rebuilding inside the forward loop). This only affects CUDA # micro-optimisation; correctness is unchanged. return sorted_idx = self.indices.to(device) # --- hot-cell detection ------------------------------------------- unique_idx, counts = torch.unique_consecutive(sorted_idx, return_counts=True) ends = torch.cumsum(counts, 0) starts = ends - counts # Threshold: cells with > sqrt(n_samples) entries are "hot". # In practice this catches the DC peak and a handful of high-density # cells that would otherwise serialise atomicAdd. threshold = max(1024, int(self.n_samples**0.5)) hot_mask = counts > threshold n_hot = int(hot_mask.sum().item()) if n_hot > 0: hot_starts = starts[hot_mask] hot_ends = ends[hot_mask] hot_uidx = unique_idx[hot_mask] # Sort hot ranges by start so the cold segments between them # are well-defined contiguous slices. order = torch.argsort(hot_starts) hot_starts = hot_starts[order] hot_ends = hot_ends[order] hot_uidx = hot_uidx[order].contiguous() hs = hot_starts.cpu().tolist() he = hot_ends.cpu().tolist() cold_s, cold_e = [], [] prev = 0 for s, e in zip(hs, he, strict=False): if s > prev: cold_s.append(prev) cold_e.append(s) prev = e if prev < self.n_samples: cold_s.append(prev) cold_e.append(self.n_samples) self._hot_uidx = hot_uidx self._hot_starts = hs self._hot_ends = he self._cold_seg_starts = cold_s self._cold_seg_ends = cold_e # Sentinel: mark scatter setup as done (keeps the # ``self._bin_starts is not None`` early-out working). self._bin_starts = hot_uidx self._bin_size = -1 # negative => hot-cell path active return # --- legacy binned shared-memory path (no hot cells) -------------- self._bin_size = 256 n_bins = (self.grid_size + self._bin_size - 1) // self._bin_size bin_edges = torch.arange(n_bins + 1, dtype=torch.int64) * self._bin_size bin_edges[-1] = self.grid_size self._bin_starts = torch.searchsorted(sorted_idx, bin_edges.to(device)) def _scatter(self, grid, data, indices, sqrt_w): """Weighted scatter-add ``grid[idx[i]] += sqrt_w[i] * data[i]``. Routes to: * hot-cell + atomic path (CUDA, when hot cells were detected and the caller is using the cached sorted plan ``self.indices``); * binned shared-memory kernel (CUDA, no hot cells); * plain atomic scatter (CPU and stacked / per-stack callers). """ ext = _get_torch_ext() # Fast path: hot-cell handling. Only valid when called with the # cached sorted plan indices (true for all non-stacked uses; stacked # plans skip ``_ensure_bins`` and therefore have ``_hot_uidx is None``). if ( self._hot_uidx is not None and indices.device == self._hot_uidx.device and indices.numel() == self.n_samples ): # Cold segments — contiguous slices, low-contention atomic. for s, e in zip(self._cold_seg_starts, self._cold_seg_ends, strict=False): ext.scatter_add(grid, data[s:e], indices[s:e], sqrt_w[s:e]) # Hot segments — sum-reduce each, then a single index_add. hot_sums = torch.empty( len(self._hot_starts), dtype=data.dtype, device=data.device ) for i, (s, e) in enumerate( zip(self._hot_starts, self._hot_ends, strict=False) ): hot_sums[i] = (data[s:e] * sqrt_w[s:e]).sum() grid.index_add_(0, self._hot_uidx, hot_sums) return _scatter_add( grid, data, indices, sqrt_w, bin_starts=self._bin_starts if self._bin_size > 0 else None, bin_size=self._bin_size if self._bin_size > 0 else 0, ) # ------------------------------------------------------------------ # Per-stack 1-D selectors (for stacked plans) # ------------------------------------------------------------------ def _stack_arrays(self, s_flat_idx: int): """Return ``(indices, sqrt_weights, sort_perm, inv_perm)`` for one flattened stack element. If unstacked, returns the global arrays. """ if not self.stack_shape: return self.indices, self.sqrt_weights, self.sort_perm, self.inv_perm s_total = int(np.prod(self.stack_shape)) idx = self.indices.reshape(s_total, self.n_samples)[s_flat_idx] sqw = self.sqrt_weights.reshape(s_total, self.n_samples)[s_flat_idx] sp = self.sort_perm.reshape(s_total, self.n_samples)[s_flat_idx] ip = self.inv_perm.reshape(s_total, self.n_samples)[s_flat_idx] return idx, sqw, sp, ip # ------------------------------------------------------------------ # Packed-stack arrays — fuse all stack elements into a single sorted # 1-D problem on a flat super-grid of shape ``(S_total * grid_size,)``. # Per-stack indices are offset by ``s * grid_size`` so the concatenated # array is globally sorted (each stack lives in its own disjoint slab), # which lets the single-trajectory C++ kernels handle the whole stack # in one call without naive Python looping. # ------------------------------------------------------------------ def _packed_arrays(self, comp_device: torch.device): """Return packed ``(indices, sqrt_w, sort_perm, inv_perm, S, n_per)``. Each output is a 1-D tensor of length ``S_total * n_per`` on ``comp_device``; ``indices`` are offset so writes to the flat super- grid land in disjoint per-stack slabs. Cached per device. """ if not self.stack_shape: raise RuntimeError("_packed_arrays called on an unstacked plan") cache = getattr(self, "_packed_cache", None) if cache is not None and cache[0] == comp_device: return cache[1] S_total = int(np.prod(self.stack_shape)) n_per = int(self.n_samples) idx2d = self.indices.reshape(S_total, n_per).to(comp_device) sqw2d = self.sqrt_weights.reshape(S_total, n_per).to(comp_device) sp2d = self.sort_perm.reshape(S_total, n_per).to(comp_device) ip2d = self.inv_perm.reshape(S_total, n_per).to(comp_device) grid_off = ( torch.arange(S_total, device=comp_device, dtype=idx2d.dtype) * self.grid_size ).unsqueeze(-1) data_off = ( torch.arange(S_total, device=comp_device, dtype=sp2d.dtype) * n_per ).unsqueeze(-1) packed = ( (idx2d + grid_off).reshape(-1).contiguous(), sqw2d.reshape(-1).contiguous(), (sp2d + data_off).reshape(-1).contiguous(), (ip2d + data_off).reshape(-1).contiguous(), S_total, n_per, ) self._packed_cache = (comp_device, packed) return packed # ------------------------------------------------------------------ # Forward: sparse k-space -> image (adjoint NUFFT direction) # ------------------------------------------------------------------
[docs] @with_torch def adjoint(self, sparse_kspace: torch.Tensor) -> torch.Tensor: """Sparse k-space to image. Accepted input layouts (with optional leading ``*B`` batch and, for stacked plans, leading ``*S`` stack axes inserted between batch and single-frame dims): - ``(*B, *S, n_coils, *natural_shape)`` - ``(*B, *S, n_coils, n_samples)`` (legacy / flat form) Output: ``(*B, *S, *image_shape)`` if smaps are set (SENSE-combined), else ``(*B, *S, n_coils, *image_shape)``. """ # Natural-shape (multi-dim) input → fold trailing dims into n_samples. nat = self.natural_shape if ( len(nat) > 1 and tuple(int(s) for s in sparse_kspace.shape[-len(nat) :]) == nat ): flat_shape = ( *tuple(int(s) for s in sparse_kspace.shape[: -len(nat)]), self.n_samples, ) sparse_kspace = sparse_kspace.reshape(flat_shape) # x now has trailing (n_coils, n_samples). Split prefix into (*B, *S). s_shape = self.stack_shape s_ndim = len(s_shape) prefix = tuple(int(s) for s in sparse_kspace.shape[:-2]) if s_ndim: if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape: raise ValueError( f"sparse_kspace prefix {prefix} must end with stack_shape {s_shape}" ) B_shape = prefix[:-s_ndim] else: B_shape = prefix # Single-frame fast path (no batch, no stack). if not prefix: return self._forward_single(sparse_kspace, 0) # Loop over flattened (*B, *S). S_total = int(np.prod(s_shape)) if s_shape else 1 B_total = int(np.prod(B_shape)) if B_shape else 1 # Flatten to (B_total, S_total, n_coils, n_samples) n_coils = int(sparse_kspace.shape[-2]) flat = sparse_kspace.reshape(B_total, S_total, n_coils, self.n_samples) # Stacked path: fuse all S stack elements into a single packed # scatter call (one C++ kernel invocation handles every stack # element at once via offset indices into a flat super-grid). if s_ndim: outs = [] for b in range(B_total): outs.append(self._forward_packed(flat[b])) else: outs = [] for b in range(B_total): outs.append(self._forward_single(flat[b, 0], 0).unsqueeze(0)) # Single-frame output shape: single_out_shape = ( tuple(self.image_shape) if self.smaps is not None else (n_coils, *self.image_shape) ) stacked = torch.stack(outs, dim=0) # (B_total, S_total, *single_out) return stacked.reshape(*B_shape, *s_shape, *single_out_shape)
def _forward_single(self, sparse_kspace: torch.Tensor, s_flat_idx: int = 0): """Single-frame forward (one stack element). ``sparse_kspace`` shape: ``(n_coils, n_samples)``.""" n_coils = sparse_kspace.shape[0] src_device = sparse_kspace.device comp_device = self.device if self.device is not None else src_device use_pipeline = comp_device.type == "cuda" and src_device.type == "cpu" dtype = sparse_kspace.dtype idx_s, sqw_s, sp_s, _ = self._stack_arrays(s_flat_idx) indices = idx_s.to(comp_device) sqrt_w = sqw_s.to(comp_device) sort_perm = sp_s.to(comp_device) self._ensure_bins(comp_device) # Pre-allocate reusable grid buffer (one alloc, reused per coil) grid = torch.empty(self.grid_size, dtype=dtype, device=comp_device) if self.smaps is not None: conj_smaps = self._conj_smaps.to(comp_device, dtype=dtype) accum = torch.zeros(self.image_shape, dtype=dtype, device=comp_device) else: accum = torch.zeros( (n_coils, *self.image_shape), dtype=dtype, device=comp_device ) if use_pipeline: self._forward_pipeline( sparse_kspace, indices, sqrt_w, sort_perm, grid, conj_smaps if self.smaps is not None else None, accum, dtype, ) else: for c in range(n_coils): coil_data = sparse_kspace[c].to(comp_device)[sort_perm] img_c = self._scatter_ifft_crop(coil_data, indices, sqrt_w, grid, dtype) if self.smaps is not None: # Fused multiply-accumulate: accum += img_c * conj_smaps[c] accum.addcmul_(img_c, conj_smaps[c]) else: accum[c] = img_c return accum.to(src_device) # ------------------------------------------------------------------ # Packed-stack forward — one C++ scatter call covers all S elements # ------------------------------------------------------------------ def _forward_packed(self, sparse_kspace: torch.Tensor) -> torch.Tensor: """Stacked forward over all stack elements at once. Parameters ---------- sparse_kspace : torch.Tensor ``(S_total, n_coils, n_per)`` complex. Returns ------- torch.Tensor ``(S_total, *image_shape)`` if smaps are set (SENSE-combined), else ``(S_total, n_coils, *image_shape)``. """ src_device = sparse_kspace.device comp_device = self.device if self.device is not None else src_device dtype = sparse_kspace.dtype idx_p, sqw_p, sp_p, _, S_total, _n_per = self._packed_arrays(comp_device) n_coils = int(sparse_kspace.shape[-2]) # Bring data on compute device, fold (S, n_coils, n_per) -> (n_coils, S*n_per) ksp_d = sparse_kspace.to(comp_device).permute(1, 0, 2).reshape(n_coils, -1) # Globally sort once via packed_sort_perm (one indexing op per coil). sorted_all = ksp_d[:, sp_p] # (n_coils, S*n_per) super_grid = torch.empty( S_total * self.grid_size, dtype=dtype, device=comp_device ) if self.smaps is not None: conj_smaps = self._conj_smaps.to(comp_device, dtype=dtype) accum = torch.zeros( S_total, *self.image_shape, dtype=dtype, device=comp_device ) else: accum = torch.zeros( S_total, n_coils, *self.image_shape, dtype=dtype, device=comp_device ) for c in range(n_coils): super_grid.zero_() # Single C++ kernel call covers every stack element. _scatter_add(super_grid, sorted_all[c], idx_p, sqw_p) full_imgs = ifft( super_grid.reshape(S_total, *self.grid_shape), axes=self.fft_axes, ) imgs = resize(full_imgs, (S_total, *self.image_shape)) if self.smaps is not None: accum.addcmul_(imgs, conj_smaps[c].unsqueeze(0)) else: accum[:, c] = imgs return accum.to(src_device) # ------------------------------------------------------------------ # Adjoint: image -> sparse k-space (forward NUFFT direction) # ------------------------------------------------------------------
[docs] @with_torch def forward(self, image: torch.Tensor) -> torch.Tensor: """Image to sparse k-space. Accepted input layouts (with optional leading ``*B`` batch and, for stacked plans, leading ``*S`` stack axes): - ``(*B, *S, *image_shape)`` if smaps are set - ``(*B, *S, n_coils, *image_shape)`` otherwise Output: ``(*B, *S, n_coils, *natural_shape)``. """ out = self._adjoint_flat(image) nat = self.natural_shape if len(nat) > 1: out = out.reshape(*out.shape[:-1], *nat) return out
@with_torch def _adjoint_flat(self, image: torch.Tensor) -> torch.Tensor: """Flat-output adjoint: returns ``(*B, *S, n_coils, n_samples)``.""" s_shape = self.stack_shape s_ndim = len(s_shape) # Single-frame ndim (no batch, no stack): single_ndim = len(self.image_shape) + (0 if self.smaps is not None else 1) prefix = tuple(int(s) for s in image.shape[: image.ndim - single_ndim]) if s_ndim: if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape: raise ValueError( f"image prefix {prefix} must end with stack_shape {s_shape}" ) B_shape = prefix[:-s_ndim] else: B_shape = prefix # Single-frame fast path. if not prefix: return self._adjoint_single(image, 0) S_total = int(np.prod(s_shape)) if s_shape else 1 B_total = int(np.prod(B_shape)) if B_shape else 1 # Reshape image to (B_total, S_total, *single_shape) single_shape = tuple(image.shape[image.ndim - single_ndim :]) flat = image.reshape(B_total, S_total, *single_shape) if s_ndim: outs = [self._adjoint_packed(flat[b]) for b in range(B_total)] else: outs = [ self._adjoint_single(flat[b, 0], 0).unsqueeze(0) for b in range(B_total) ] # Each entry has shape (S_total, n_coils, n_samples) n_coils = outs[0].shape[1] stacked = torch.stack(outs, dim=0) # (B_total, S_total, n_coils, n_samples) return stacked.reshape(*B_shape, *s_shape, n_coils, self.n_samples) def _adjoint_single(self, image: torch.Tensor, s_flat_idx: int = 0): """Single-frame adjoint (one stack element).""" src_device = image.device comp_device = self.device if self.device is not None else src_device use_pipeline = comp_device.type == "cuda" and src_device.type == "cpu" dtype = image.dtype idx_s, sqw_s, _, ip_s = self._stack_arrays(s_flat_idx) indices = idx_s.to(comp_device) sqrt_w = sqw_s.to(comp_device) inv_perm = ip_s.to(comp_device) image_d = image.to(comp_device) if self.smaps is not None: smaps = self.smaps.to(comp_device, dtype=dtype) n_coils = smaps.shape[0] else: n_coils = image_d.shape[0] n_samples = indices.shape[0] output = torch.zeros(n_coils, n_samples, dtype=dtype, device=comp_device) # Pre-allocate reusable padded grid buffer padded = torch.empty(*self.grid_shape, dtype=dtype, device=comp_device) if use_pipeline and self.smaps is not None: self._adjoint_pipeline( image_d, indices, sqrt_w, inv_perm, smaps, padded, output, dtype, ) else: for c in range(n_coils): coil_img = image_d * smaps[c] if self.smaps is not None else image_d[c] output[c] = self._fft_pad_gather( coil_img, indices, sqrt_w, inv_perm, padded, dtype, ) return output.to(src_device) # ------------------------------------------------------------------ # Packed-stack adjoint — one C++ gather call covers all S elements # ------------------------------------------------------------------ def _adjoint_packed(self, image: torch.Tensor) -> torch.Tensor: """Stacked adjoint over all stack elements at once. Parameters ---------- image : torch.Tensor ``(S_total, *image_shape)`` if smaps are set, otherwise ``(S_total, n_coils, *image_shape)``. Returns ------- torch.Tensor ``(S_total, n_coils, n_per)`` complex. """ src_device = image.device comp_device = self.device if self.device is not None else src_device dtype = image.dtype idx_p, sqw_p, _, ip_p, S_total, n_per = self._packed_arrays(comp_device) image_d = image.to(comp_device) if self.smaps is not None: smaps = self.smaps.to(comp_device, dtype=dtype) n_coils = smaps.shape[0] else: n_coils = image_d.shape[1] # Pre-allocate a packed (S, *grid_shape) zero-padded buffer reused # per coil; one batched FFT per coil over the full stack. padded = torch.zeros(S_total, *self.grid_shape, dtype=dtype, device=comp_device) slc = (slice(None), *self._pad_slices) # Output is built sorted in packed layout, then unsorted via inv_perm # at the end (single indexing op per coil). output_unsorted = torch.empty( n_coils, S_total * n_per, dtype=dtype, device=comp_device ) for c in range(n_coils): coil_imgs = image_d * smaps[c] if self.smaps is not None else image_d[:, c] padded.zero_() padded[slc] = coil_imgs kgrid = fft(padded, axes=self.fft_axes) # (S, *grid_shape) super_flat = kgrid.reshape(-1) # (S*grid_size,) # Single C++ kernel call across the whole stack. sorted_packed = _gather(super_flat, idx_p, sqw_p) # (S*n_per,) # Undo per-stack sort with one indexing op. output_unsorted[c] = sorted_packed[ip_p] out = ( output_unsorted.reshape(n_coils, S_total, n_per) .permute(1, 0, 2) .contiguous() ) return out.to(src_device) def __call__(self, x, adjoint=False): if adjoint: return self.adjoint(x) return self.forward(x) # ------------------------------------------------------------------ # Normal operator: A^H A # ------------------------------------------------------------------
[docs] @with_torch def normal(self, image: torch.Tensor) -> torch.Tensor: """Self-adjoint application: ``A^H A image``. When ``self.toeplitz`` is True, uses a pre-computed PSF on ``grid_shape`` (built lazily on first call) and applies pad / FFT / PSF / IFFT / crop per coil. Otherwise dispatches to a fused ``forward(adjoint(.))`` helper that *skips* the ``inv_perm`` (in adjoint) and ``sort_perm`` (in forward) round- trip — the intermediate sparse k-space is kept in sorted order end-to-end ("sort-once" optimisation). Parameters ---------- image : torch.Tensor Same shape as the output of :meth:`forward`. Returns ------- torch.Tensor Same shape as input. """ if self.toeplitz: if self._toep_op is None: # Local import to avoid circular import at module load. from .._toep._grog_toep import GrogToeplitzOp self._toep_op = GrogToeplitzOp(self, device=self.device) return self._toep_op(image) return self._normal_no_toep(image)
# ------------------------------------------------------------------ # Sort-once non-Toeplitz normal # ------------------------------------------------------------------ @with_torch def _normal_no_toep(self, image: torch.Tensor) -> torch.Tensor: """Batched/stacked dispatcher for :meth:`_normal_*_no_toep`. Mirrors the dispatch logic of :meth:`forward` / :meth:`adjoint` but routes per-frame work to fused helpers that omit the ``inv_perm`` / ``sort_perm`` round-trip in the intermediate sparse-k-space buffer. """ s_shape = self.stack_shape s_ndim = len(s_shape) single_ndim = ( len(self.image_shape) if self.smaps is not None else len(self.image_shape) + 1 ) prefix = tuple(int(s) for s in image.shape[: image.ndim - single_ndim]) if s_ndim: if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape: raise ValueError( f"image prefix {prefix} must end with stack_shape {s_shape}" ) B_shape = prefix[:-s_ndim] else: B_shape = prefix if not prefix: return self._normal_single_no_toep(image, 0) S_total = int(np.prod(s_shape)) if s_shape else 1 B_total = int(np.prod(B_shape)) if B_shape else 1 single_shape = tuple(image.shape[image.ndim - single_ndim :]) flat = image.reshape(B_total, S_total, *single_shape) outs = [] if s_ndim: for b in range(B_total): outs.append(self._normal_packed_no_toep(flat[b])) else: for b in range(B_total): outs.append(self._normal_single_no_toep(flat[b, 0], 0).unsqueeze(0)) stacked = torch.stack(outs, dim=0) return stacked.reshape(*B_shape, *s_shape, *single_shape) def _normal_single_no_toep(self, image: torch.Tensor, s_flat_idx: int = 0): """Sort-once single-frame ``A^H A``; intermediate stays sorted.""" src_device = image.device comp_device = self.device if self.device is not None else src_device dtype = image.dtype idx_s, sqw_s, _, _ = self._stack_arrays(s_flat_idx) indices = idx_s.to(comp_device) sqrt_w = sqw_s.to(comp_device) self._ensure_bins(comp_device) image_d = image.to(comp_device) if self.smaps is not None: smaps = self.smaps.to(comp_device, dtype=dtype) n_coils = smaps.shape[0] accum = torch.zeros(self.image_shape, dtype=dtype, device=comp_device) else: n_coils = image_d.shape[0] accum = torch.zeros( (n_coils, *self.image_shape), dtype=dtype, device=comp_device, ) padded = torch.empty(*self.grid_shape, dtype=dtype, device=comp_device) grid = torch.empty(self.grid_size, dtype=dtype, device=comp_device) for c in range(n_coils): coil_img = image_d * smaps[c] if self.smaps is not None else image_d[c] # adjoint: pad -> FFT -> gather (sorted; skip [inv_perm]) padded.zero_() padded[self._pad_slices] = coil_img kgrid = fft(padded, axes=self.fft_axes).reshape(-1) sorted_kspace = _gather(kgrid, indices, sqrt_w) # forward: scatter (presorted; skip [sort_perm]) -> IFFT -> crop img_c = self._scatter_ifft_crop( sorted_kspace, indices, sqrt_w, grid, dtype, ) if self.smaps is not None: accum.addcmul_(img_c, smaps[c].conj()) else: accum[c] = img_c return accum.to(src_device) def _normal_packed_no_toep(self, image: torch.Tensor) -> torch.Tensor: """Sort-once packed-stack ``A^H A`` over all S elements at once.""" src_device = image.device comp_device = self.device if self.device is not None else src_device dtype = image.dtype idx_p, sqw_p, _, _, S_total, _n_per = self._packed_arrays(comp_device) image_d = image.to(comp_device) if self.smaps is not None: smaps = self.smaps.to(comp_device, dtype=dtype) n_coils = smaps.shape[0] accum = torch.zeros( S_total, *self.image_shape, dtype=dtype, device=comp_device, ) else: n_coils = image_d.shape[1] accum = torch.zeros( S_total, n_coils, *self.image_shape, dtype=dtype, device=comp_device, ) padded = torch.zeros(S_total, *self.grid_shape, dtype=dtype, device=comp_device) super_grid = torch.empty( S_total * self.grid_size, dtype=dtype, device=comp_device, ) slc = (slice(None), *self._pad_slices) for c in range(n_coils): coil_imgs = image_d * smaps[c] if self.smaps is not None else image_d[:, c] # adj packed: pad/FFT/gather (sorted-packed; skip [ip_p]) padded.zero_() padded[slc] = coil_imgs kgrid_super = fft(padded, axes=self.fft_axes).reshape(-1) sorted_packed = _gather(kgrid_super, idx_p, sqw_p) # (S*n_per,) # fwd packed presorted: scatter (skip [:, sp_p]) / IFFT / crop super_grid.zero_() _scatter_add(super_grid, sorted_packed, idx_p, sqw_p) full_imgs = ifft( super_grid.reshape(S_total, *self.grid_shape), axes=self.fft_axes, ) imgs = resize(full_imgs, (S_total, *self.image_shape)) if self.smaps is not None: accum.addcmul_(imgs, smaps[c].conj().unsqueeze(0)) else: accum[:, c] = imgs return accum.to(src_device) # ------------------------------------------------------------------ # Iterative solve # ------------------------------------------------------------------ # Provided by SolveMixin (attached at module import time). # ------------------------------------------------------------------ # Batch helpers for decorators (Tier-1 FFT fusion) # ------------------------------------------------------------------ def _scatter_ifft_crop_batch( self, batch_kspace: torch.Tensor, s_flat_idx: int = 0, ) -> torch.Tensor: """Scatter (per-component loop) → **one batched IFFT** → crop. Reduces ``B x n_coils`` IFFT calls to a single ``torch.fft.ifftn((B, *grid_shape))`` call. The scatter loop is unchanged (needs a batched C++ kernel for full fusion). Parameters ---------- batch_kspace : torch.Tensor ``(B, n_samples)`` complex, **unsorted**. Each row is an independently-weighted k-space vector (e.g. one ORC component x one coil, or one subspace frame x one coil). s_flat_idx : int Flattened stack-element index (ignored for unstacked plans). Returns ------- torch.Tensor ``(B, *image_shape)`` complex, on the same device as input. """ B = batch_kspace.shape[0] src_device = batch_kspace.device comp_device = self.device if self.device is not None else src_device dtype = batch_kspace.dtype idx_s, sqw_s, sp_s, _ = self._stack_arrays(s_flat_idx) indices = idx_s.to(comp_device) sqrt_w = sqw_s.to(comp_device) sort_perm = sp_s.to(comp_device) self._ensure_bins(comp_device) # Sort all B inputs at once: (B, n_samples) — one indexing op sorted_ksp = batch_kspace.to(comp_device)[:, sort_perm] # Scatter B times — C++ kernel, bottleneck until batched kernel exists grids = torch.zeros(B, self.grid_size, dtype=dtype, device=comp_device) for b in range(B): self._scatter(grids[b], sorted_ksp[b], indices, sqrt_w) # ONE batched IFFT + center-crop over last ``ndim`` dims # axes=(-k,...,-1) work correctly on any batch prefix grids_nd = grids.reshape(B, *self.grid_shape) # Full-size IFFT first, then center-crop image; cropping in k-space # (via oshape) would be wrong when the grid is oversampled. full_imgs = ifft(grids_nd, axes=self.fft_axes) # (B, *grid_shape) imgs = resize(full_imgs, (B, *self.image_shape)) return imgs.to(src_device) def _fft_pad_gather_batch( self, batch_imgs: torch.Tensor, s_flat_idx: int = 0, ) -> torch.Tensor: """ONE batched FFT -> zero-pad -> gather (per-component loop). Reduces ``B x n_coils`` FFT calls to a single ``torch.fft.fftn((B, *image_shape))`` call. The gather loop is unchanged (needs a batched C++ kernel for full fusion). Parameters ---------- batch_imgs : torch.Tensor ``(B, *image_shape)`` complex. Each slice ``[b]`` is an independently-weighted image (e.g. one ORC component x one coil). s_flat_idx : int Flattened stack-element index (ignored for unstacked plans). Returns ------- torch.Tensor ``(B, n_samples)`` complex, in original (unsorted) k-space order. """ B = batch_imgs.shape[0] src_device = batch_imgs.device comp_device = self.device if self.device is not None else src_device dtype = batch_imgs.dtype idx_s, sqw_s, _, ip_s = self._stack_arrays(s_flat_idx) indices = idx_s.to(comp_device) sqrt_w = sqw_s.to(comp_device) inv_perm = ip_s.to(comp_device) imgs_d = batch_imgs.to(comp_device) # Zero-pad images in image space (adjoint of center-crop in image space) padded = torch.zeros(B, *self.grid_shape, dtype=dtype, device=comp_device) padded[(slice(None), *self._pad_slices)] = imgs_d # ONE batched FFT at full grid_shape fft_results = fft(padded, axes=self.fft_axes) # (B, *grid_shape) padded_flat = fft_results.reshape(B, -1) # Gather B times — C++ kernel, bottleneck until batched kernel exists output = torch.stack( [_gather(padded_flat[b], indices, sqrt_w)[inv_perm] for b in range(B)] ) return output.to(src_device) # (B, n_samples) # ------------------------------------------------------------------ # Core single-coil ops (fused) # ------------------------------------------------------------------ def _scatter_ifft_crop(self, coil_data, indices, sqrt_w, grid, _dtype): """Scatter -> full-size IFFT -> image-space center-crop for one coil. The IFFT is performed at the full (oversampled) ``grid_shape``; the result is then center-cropped in **image space** to ``image_shape``. Cropping k-space first (the old behaviour) is wrong when ``grid_shape != image_shape`` (oversampling > 1). """ grid.zero_() self._scatter(grid, coil_data, indices, sqrt_w) full_img = ifft(grid.reshape(self.grid_shape), axes=self.fft_axes) return resize(full_img, self.image_shape) def _fft_pad_gather(self, coil_img, indices, sqrt_w, inv_perm, padded, _dtype): """Zero-pad image -> FFT at grid_shape -> gather -> unpermute for one coil. Adjoint of: scatter -> IFFT at grid_shape -> center-crop image. Reuses the *padded* buffer to avoid a fresh allocation per coil. """ # Zero-pad image in image space (adjoint of center-crop in image space) padded.zero_() padded[self._pad_slices] = coil_img fft_result = fft(padded, axes=self.fft_axes) # FFT at full grid_shape return _gather(fft_result.reshape(-1), indices, sqrt_w)[inv_perm] # ------------------------------------------------------------------ # Dual-stream GPU pipelining (CPU input -> GPU compute) # ------------------------------------------------------------------ def _forward_pipeline( self, sparse_kspace, indices, sqrt_w, sort_perm, grid, conj_smaps, accum, dtype ): """Dual-stream forward: overlap H2D transfer with FFT.""" n_coils = sparse_kspace.shape[0] device = accum.device s1 = torch.cuda.Stream(device=device) s2 = torch.cuda.Stream(device=device) buf = [ sparse_kspace[0].pin_memory().to(device, non_blocking=True), None, ] for c in range(n_coils): cur = c % 2 nxt = 1 - cur stream = s1 if cur == 0 else s2 if c + 1 < n_coils: other = s1 if nxt == 0 else s2 with torch.cuda.stream(other): buf[nxt] = ( sparse_kspace[c + 1].pin_memory().to(device, non_blocking=True) ) with torch.cuda.stream(stream): coil_data = buf[cur][sort_perm] img_c = self._scatter_ifft_crop(coil_data, indices, sqrt_w, grid, dtype) if conj_smaps is not None: accum.addcmul_(img_c, conj_smaps[c]) else: accum += img_c.abs().square() torch.cuda.synchronize(device) def _adjoint_pipeline( self, image_d, indices, sqrt_w, inv_perm, smaps, padded, output, dtype ): """Dual-stream adjoint: overlap FFT with D2H transfer.""" n_coils = smaps.shape[0] device = image_d.device s1 = torch.cuda.Stream(device=device) s2 = torch.cuda.Stream(device=device) for c in range(n_coils): stream = s1 if c % 2 == 0 else s2 with torch.cuda.stream(stream): coil_img = image_d * smaps[c] output[c] = self._fft_pad_gather( coil_img, indices, sqrt_w, inv_perm, padded, dtype, ) torch.cuda.synchronize(device)