import contextlib
"""Subspace projection gadget and SparseFFT decorator.
Provides two complementary views of low-rank temporal/contrast subspace
compression:
* :class:`SubspaceProjection` — standalone projection via truncated SVD,
operates on dense (n_frames, *spatial) tensors.
* :func:`with_subspace` / :class:`SubspaceSparseFFT` — decorator that wraps
a :class:`~pygrog.operator.SparseFFT` and fuses the subspace projection
directly into the k-space ↔ image transform.
The subspace basis ``Phi`` has shape ``(K, T)`` where ``K`` is the subspace
rank and ``T`` is the number of temporal frames or contrasts.
Data conventions::
Sparse k-space: (*batch, n_coils, *natural_shape) — natural_shape comes
from the GROG plan, e.g. (T, k1, k0, kw) for 3D MRF.
Image space: (*batch_image, K, *image_shape) — K subspace coefficients.
The ``encoding_axis`` argument identifies which axis of the sparse tensor
carries the temporal/contrast dimension ``T``; the gadget broadcasts the
basis along that axis.
"""
__all__ = [
"SubspaceMaskedFFT",
"SubspaceProjection",
"SubspaceSparseFFT",
"with_subspace",
]
import torch
import numpy as np
from mrinufft._array_compat import with_torch
from .._solve._mixin import SolveMixin
# =====================================================================
# Standalone gadget
# =====================================================================
class SubspaceProjection:
"""Low-rank temporal subspace projection via truncated SVD.
Given multi-frame data ``(n_frames, *spatial)``, projects onto the
leading ``n_components`` left singular vectors.
Parameters
----------
n_components : int
Number of subspace components to retain.
"""
def __init__(self, n_components: int):
self.n_components = n_components
self._basis = None
def fit(self, calib_data: torch.Tensor) -> "SubspaceProjection":
U, _S, _Vh = torch.linalg.svd(calib_data, full_matrices=False)
self._basis = U[:, : self.n_components].T.conj()
return self
@property
def basis(self) -> torch.Tensor:
if self._basis is None:
raise RuntimeError("Call fit() first.")
return self._basis
@with_torch
def forward(self, data: torch.Tensor) -> torch.Tensor:
spatial_shape = data.shape[1:]
flat = data.reshape(data.shape[0], -1)
coeff = self.basis @ flat
return coeff.reshape(self.n_components, *spatial_shape)
@with_torch
def adjoint(self, coefficients: torch.Tensor) -> torch.Tensor:
spatial_shape = coefficients.shape[1:]
flat = coefficients.reshape(self.n_components, -1)
frames = self.basis.conj().T @ flat
return frames.reshape(-1, *spatial_shape)
# =====================================================================
# SparseFFT decorator
# =====================================================================
[docs]
def with_subspace(base_op, subspace_basis, encoding_axis: int = -4, *, toeplitz=None):
"""Wrap a SparseFFT or MaskedFFT operator with subspace projection.
Parameters
----------
base_op : SparseFFT | MaskedFFT
Underlying operator with a multi-dim ``natural_shape`` containing
the temporal axis.
subspace_basis : array-like, complex
``(K, T)`` subspace basis matrix.
encoding_axis : int
Axis (in the full sparse-tensor layout) carrying ``T``. Default
``-4`` matches ``(*batch, C, T, k1, k0, kw)``.
toeplitz : bool | None, optional
Use Toeplitz embedding for :meth:`normal`. ``None`` inherits
from ``base_op.toeplitz``.
"""
from ..operator._masked_fft import MaskedFFT
if isinstance(base_op, MaskedFFT):
return SubspaceMaskedFFT(
base_op,
subspace_basis,
encoding_axis=encoding_axis,
toeplitz=toeplitz,
)
return SubspaceSparseFFT(
base_op,
subspace_basis,
encoding_axis=encoding_axis,
toeplitz=toeplitz,
)
class SubspaceSparseFFT(SolveMixin):
"""SparseFFT with low-rank subspace projection.
Adjoint (sparse → image), per coil, per ``k`` (default ``k_chunk=1``):
1. weight by ``sqrt_w`` once on the input;
2. multiply by ``basis[k]`` along the T axis;
3. scatter into a single reused oversampled grid, IFFT, center-crop;
4. fused FMA with ``smaps[c].conj()`` into ``output[k]``.
Forward (image → sparse), per coil, per ``k``:
1. multiply ``coeffs[k]`` by ``smaps[c]``;
2. zero-pad into a single reused oversampled buffer, FFT, gather;
3. accumulate ``basis.conj()[k] * gathered`` into ``ksp_c``;
4. write ``ksp_c * sqrt_w`` into the output coil slot.
Setting ``k_chunk > 1`` routes the inner loop through batched
``_scatter_ifft_crop_batch`` / ``_fft_pad_gather_batch`` helpers (one
FFT call over ``k_chunk`` planes), trading VRAM for FFT throughput.
Parameters
----------
base_op : SparseFFT
Must have a multi-dim ``natural_shape`` covering the sparse layout
(e.g. ``(T, k1, k0, kw)``) and SENSE maps (``smaps``) attached.
subspace_basis : torch.Tensor
``(K, T)`` complex basis.
encoding_axis : int
Axis (in full sparse layout) of the temporal dimension ``T``.
Default ``-4`` (last four axes are natural ``(T, k1, k0, kw)``).
k_chunk : int, optional
Number of subspace components processed per batched FFT. Default
``1`` (lowest VRAM, one grid/padded buffer reused across all
``(coil, k)`` pairs).
"""
def __init__(
self,
base_op,
subspace_basis,
encoding_axis: int = -4,
*,
toeplitz=None,
k_chunk: int = 1,
):
self._base = base_op
self.basis = torch.as_tensor(subspace_basis) # (K, T)
self.K, self.T = self.basis.shape
self.encoding_axis = encoding_axis
# K-batching factor for the inner (FFT, scatter/gather) loop.
# Default 1 = one grid/padded buffer reused across the K subspace
# components (lowest VRAM, restores pre-regression behaviour).
# Larger values trade memory for batched FFT throughput by routing
# through `_scatter_ifft_crop_batch` / `_fft_pad_gather_batch`.
self.k_chunk = max(1, int(k_chunk))
self.grid_shape = base_op.grid_shape
self.image_shape = base_op.image_shape
self.smaps = getattr(base_op, "smaps", None)
# Position of T inside `natural_shape` (positive index).
nat_ndim = len(base_op.natural_shape)
# Full sparse layout: (*batch, C, *natural). encoding_axis is given
# relative to that layout; we need the position inside `natural`.
# E.g. encoding_axis=-4, nat_ndim=4 → axis_in_nat = -4 + nat_ndim = 0 ✓.
ax = encoding_axis if encoding_axis >= 0 else encoding_axis + (1 + nat_ndim)
# `ax` now indexes (C, *natural); subtract the leading C dim.
self._t_axis_in_nat = ax - 1
if not (0 <= self._t_axis_in_nat < nat_ndim):
raise ValueError(
f"encoding_axis={encoding_axis} does not land inside natural_shape "
f"{base_op.natural_shape} (computed nat-axis {self._t_axis_in_nat})"
)
if base_op.natural_shape[self._t_axis_in_nat] != self.T:
raise ValueError(
f"basis T={self.T} does not match natural_shape"
f"[{self._t_axis_in_nat}]={base_op.natural_shape[self._t_axis_in_nat]}"
)
# Toeplitz flag inherits from base unless overridden.
if toeplitz is None:
toeplitz = bool(getattr(base_op, "toeplitz", False))
self.toeplitz = bool(toeplitz)
self._toep_op = None # lazily built
# Cache for sort-hoisted T-index lookup (per stack element + device).
# ``t_idx[i] = T-coordinate of natural sample i``;
# ``t_idx_sorted = t_idx[sort_perm]`` indexes basis rows in the
# sorted sample order used by the scatter kernel. Avoids a
# 1.3 GB random-access gather per (coil, k) iteration.
self._t_idx_cache: dict = {}
def _get_t_idx_sorted(self, sort_perm: torch.Tensor, s_flat_idx: int):
"""Return cached sorted T-index for the current stack & device."""
key = (s_flat_idx, sort_perm.device, int(sort_perm.numel()))
cached = self._t_idx_cache.get(key)
if cached is not None:
return cached
nat = self._base.natural_shape
nat_ndim = len(nat)
T = self.T
# PyTorch indexing requires long/int32 (no int16) → pick int32 when
# T fits, else int64.
idx_dtype = torch.int32 if T <= 2_147_483_647 else torch.int64
view_shape = [1] * nat_ndim
view_shape[self._t_axis_in_nat] = T
t_axis = (
torch.arange(T, dtype=idx_dtype, device=sort_perm.device)
.view(view_shape)
.expand(nat)
.reshape(-1)
)
# Materialise then permute (small one-shot cost; cached afterwards).
t_idx_sorted = t_axis[sort_perm].contiguous()
self._t_idx_cache[key] = t_idx_sorted
return t_idx_sorted
# ------------------------------------------------------------------
# adjoint: sparse k-space → subspace coefficient images (A^H)
# ------------------------------------------------------------------
@with_torch
def adjoint(self, sparse_kspace: torch.Tensor) -> torch.Tensor:
"""Sparse k-space → subspace coefficient images (``A^H``)."""
return self._adjoint_impl(sparse_kspace)
# ------------------------------------------------------------------
# forward: subspace coefficient images → sparse k-space (A)
# ------------------------------------------------------------------
@with_torch
def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
return self._forward_impl(coeffs)
# ==================================================================
# implementation
# ==================================================================
def _adjoint_impl(self, sparse_kspace: torch.Tensor) -> torch.Tensor:
"""Sparse → coefficients.
Accepted layouts:
- ``(*B, *S, C, *natural)`` (general)
- ``(C, *natural)`` (single frame, no batch / stack)
Output: ``(*B, *S, K, *image_shape)`` (or ``(K, *image_shape)``
for the no-batch / no-stack case).
"""
base = self._base
nat = base.natural_shape
nat_ndim = len(nat)
s_shape = tuple(getattr(base, "stack_shape", ()) or ())
s_ndim = len(s_shape)
# Identify leading prefix (*B, *S) before (C, *natural).
prefix = tuple(
int(s) for s in sparse_kspace.shape[: sparse_kspace.ndim - (1 + nat_ndim)]
)
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"sparse_kspace prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
if sparse_kspace.ndim < 1 + nat_ndim:
raise ValueError(
f"Expected (...{(1 + nat_ndim)}D)=(C, *natural)={('C', *tuple(nat))}; "
f"got {tuple(sparse_kspace.shape)}"
)
# No batch, no stack → single-frame fast path.
if not prefix:
return self._adjoint_single(sparse_kspace, 0)
B_total = int(np.prod(B_shape)) if B_shape else 1
S_total = int(np.prod(s_shape)) if s_shape else 1
flat = sparse_kspace.reshape(
B_total, S_total, *sparse_kspace.shape[-(1 + nat_ndim) :]
)
outs = []
for b in range(B_total):
for s in range(S_total):
outs.append(self._adjoint_single(flat[b, s], s))
# outs[i]: (K, *image_shape)
stacked = torch.stack(outs, dim=0)
return stacked.reshape(*B_shape, *s_shape, self.K, *base.image_shape)
def _adjoint_single(self, sparse_kspace: torch.Tensor, s_flat_idx: int = 0):
"""Single-frame, single-stack-element adjoint. Input: ``(C, *natural)``."""
base = self._base
# Dispatch to dual-stream pipeline when input lives on CPU but the
# base operator computes on CUDA. Overlaps per-coil H2D with the
# per-K scatter+IFFT compute.
if self._use_dual_stream(sparse_kspace):
return self._adjoint_single_dual(sparse_kspace, s_flat_idx)
nat = base.natural_shape
nat_ndim = len(nat)
device = sparse_kspace.device
dtype = sparse_kspace.dtype
n_coils = int(sparse_kspace.shape[0])
if base.smaps is None:
raise NotImplementedError("SubspaceSparseFFT requires base_op.smaps")
smaps = base.smaps.to(device, dtype=dtype)
basis = self.basis.to(device, dtype=dtype) # (K, T)
T = self.T
K = self.K
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = T
output = torch.zeros(K, *base.image_shape, dtype=dtype, device=device)
# Per-stack arrays + natural-order pre-weights.
idx_s, sqw_s, sp_s, ip_s = base._stack_arrays(s_flat_idx)
indices = idx_s.to(device)
sqrt_w = sqw_s.to(device)
sort_perm = sp_s.to(device)
pre_w = sqw_s[ip_s].to(device=device, dtype=dtype).view(*nat)
base._ensure_bins(device)
k_chunk = self.k_chunk
if k_chunk <= 1:
# Lowest-VRAM path: one reused grid buffer across all (c, k).
# Sort hoisting: the heavy ``flat[sort_perm]`` (1.3 GB random
# gather) is done ONCE per coil; per-k work is a small T-indexed
# basis lookup + element-wise multiply on the already-sorted
# data. Brings adjoint cost in line with forward.
grid = torch.empty(base.grid_size, dtype=dtype, device=device)
t_idx_sorted = self._get_t_idx_sorted(sort_perm, s_flat_idx)
for c in range(n_coils):
# ``sw_c_sorted`` is in scatter-sorted order (no per-k gather).
sw_c_sorted = (sparse_kspace[c] * pre_w).reshape(-1)[sort_perm]
smap_conj_c = smaps[c].conj()
for k in range(K):
# T-indexed lookup → (n_samples,) basis values, then
# element-wise multiply with the sorted weighted samples.
weighted_sorted = basis[k][t_idx_sorted] * sw_c_sorted
img_k = base._scatter_ifft_crop(
weighted_sorted, indices, sqrt_w, grid, dtype
)
output[k].addcmul_(img_k, smap_conj_c)
else:
# K-batched path (high VRAM, batched FFT): chunk over k.
for c in range(n_coils):
sw_c = sparse_kspace[c] * pre_w
smap_conj_c = smaps[c].conj().unsqueeze(0)
for k0 in range(0, K, k_chunk):
k1 = min(k0 + k_chunk, K)
kb = k1 - k0
weighted = basis[k0:k1].view(kb, *phi_shape) * sw_c.unsqueeze(0)
weighted_flat = weighted.reshape(kb, -1)
imgs = base._scatter_ifft_crop_batch(
weighted_flat, s_flat_idx=s_flat_idx
)
output[k0:k1].addcmul_(imgs, smap_conj_c)
return output
def _forward_impl(self, coeffs: torch.Tensor) -> torch.Tensor:
"""Coefficients → sparse.
Accepted layouts:
- ``(*B, *S, K, *image_shape)`` (general)
- ``(K, *image_shape)`` (single frame, no batch / stack)
Output: ``(*B, *S, C, *natural)`` (or ``(C, *natural)`` for the
no-batch / no-stack case).
"""
base = self._base
nat = base.natural_shape
s_shape = tuple(getattr(base, "stack_shape", ()) or ())
s_ndim = len(s_shape)
img_ndim = len(base.image_shape)
single_ndim = 1 + img_ndim # K + *image_shape
prefix = tuple(int(s) for s in coeffs.shape[: coeffs.ndim - single_ndim])
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"coeffs prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
if not prefix:
return self._forward_single(coeffs, 0)
B_total = int(np.prod(B_shape)) if B_shape else 1
S_total = int(np.prod(s_shape)) if s_shape else 1
flat = coeffs.reshape(B_total, S_total, *coeffs.shape[-single_ndim:])
outs = []
for b in range(B_total):
for s in range(S_total):
outs.append(self._forward_single(flat[b, s], s))
# outs[i]: (C, *nat)
n_coils = outs[0].shape[0]
stacked = torch.stack(outs, dim=0)
return stacked.reshape(*B_shape, *s_shape, n_coils, *nat)
def _forward_single(
self, coeffs: torch.Tensor, s_flat_idx: int = 0
) -> torch.Tensor:
"""Single-frame, single-stack-element forward. Input: ``(K, *image_shape)``,
Output: ``(C, *natural)``."""
base = self._base
if self._use_dual_stream(coeffs):
return self._forward_single_dual(coeffs, s_flat_idx)
nat = base.natural_shape
nat_ndim = len(nat)
if coeffs.shape[0] != self.K:
raise ValueError(f"coeffs.shape[0]={coeffs.shape[0]} != K={self.K}")
if tuple(int(s) for s in coeffs.shape[1:]) != tuple(base.image_shape):
raise ValueError(
f"coeffs spatial {tuple(coeffs.shape[1:])} != image_shape {base.image_shape}"
)
device = coeffs.device
dtype = coeffs.dtype
if base.smaps is None:
raise NotImplementedError("SubspaceSparseFFT requires base_op.smaps")
smaps = base.smaps.to(device, dtype=dtype)
n_coils = int(smaps.shape[0])
basis_conj = self.basis.conj().to(device, dtype=dtype) # (K, T)
T = self.T
K = self.K
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = T
output = torch.empty(n_coils, *nat, dtype=dtype, device=device)
idx_s, sqw_s, _, ip_s = base._stack_arrays(s_flat_idx)
indices = idx_s.to(device)
sqrt_w = sqw_s.to(device)
inv_perm = ip_s.to(device)
pre_w = sqw_s[ip_s].to(device=device, dtype=dtype).view(*nat)
k_chunk = self.k_chunk
if k_chunk <= 1:
# Lowest-VRAM path: one reused padded buffer across all (c, k).
padded = torch.empty(*base.grid_shape, dtype=dtype, device=device)
for c in range(n_coils):
smap_c = smaps[c]
ksp_c = torch.zeros(*nat, dtype=dtype, device=device)
for k in range(K):
coil_img_k = coeffs[k] * smap_c # (*image)
gathered_k = base._fft_pad_gather(
coil_img_k, indices, sqrt_w, inv_perm, padded, dtype
)
ksp_c.add_(
basis_conj[k].view(*phi_shape) * gathered_k.reshape(*nat)
)
output[c] = ksp_c * pre_w
else:
# K-batched path (high VRAM, batched FFT): chunk over k.
for c in range(n_coils):
smap_c = smaps[c]
ksp_c = torch.zeros(*nat, dtype=dtype, device=device)
for k0 in range(0, K, k_chunk):
k1 = min(k0 + k_chunk, K)
kb = k1 - k0
coil_imgs = coeffs[k0:k1] * smap_c.unsqueeze(0)
gathered = base._fft_pad_gather_batch(
coil_imgs, s_flat_idx=s_flat_idx
)
gathered_nat = gathered.reshape(kb, *nat)
ksp_c.add_(
(basis_conj[k0:k1].view(kb, *phi_shape) * gathered_nat).sum(
dim=0
)
)
output[c] = ksp_c * pre_w
return output
# ------------------------------------------------------------------
# Dual-stream GPU pipeline (CPU input -> GPU compute, overlapped)
# ------------------------------------------------------------------
def _use_dual_stream(self, x: torch.Tensor) -> bool:
"""Return True when ``x`` lives on CPU and the base operator
computes on a CUDA device — the only configuration where coil-level
stream overlap pays off."""
base = self._base
comp = getattr(base, "device", None)
return (
comp is not None
and getattr(comp, "type", None) == "cuda"
and x.device.type == "cpu"
and torch.cuda.is_available()
)
def _adjoint_single_dual(
self, sparse_kspace: torch.Tensor, s_flat_idx: int = 0
) -> torch.Tensor:
"""Coil-pipelined adjoint: H2D of next coil overlapped with the
per-K scatter+IFFT+FMA of the current coil.
Input: ``(C, *natural)`` on CPU. Output: ``(K, *image_shape)`` on CPU
(mirrors the synchronous fast-path return device)."""
base = self._base
nat = base.natural_shape
nat_ndim = len(nat)
comp_device = base.device
n_coils = int(sparse_kspace.shape[0])
K, T = self.K, self.T
dtype = sparse_kspace.dtype
if base.smaps is None:
raise NotImplementedError("SubspaceSparseFFT requires base_op.smaps")
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = T
# Pre-stage constants once on the compute device.
basis_gpu = self.basis.to(comp_device, dtype=dtype) # (K, T)
idx_s, sqw_s, sp_s, ip_s = base._stack_arrays(s_flat_idx)
indices = idx_s.to(comp_device)
sqrt_w = sqw_s.to(comp_device)
sort_perm = sp_s.to(comp_device)
pre_w_gpu = sqw_s[ip_s].to(device=comp_device, dtype=dtype).view(*nat) # (*nat)
base._ensure_bins(comp_device)
smaps_cpu = base.smaps.to(dtype=dtype)
if not smaps_cpu.is_pinned() and smaps_cpu.device.type == "cpu":
with contextlib.suppress(RuntimeError):
smaps_cpu = smaps_cpu.pin_memory()
sparse_pin = sparse_kspace
if sparse_pin.device.type == "cpu" and not sparse_pin.is_pinned():
with contextlib.suppress(RuntimeError):
sparse_pin = sparse_pin.pin_memory()
output_gpu = torch.zeros(K, *base.image_shape, dtype=dtype, device=comp_device)
s_data = torch.cuda.Stream(device=comp_device)
s_comp = torch.cuda.Stream(device=comp_device)
# Single grid buffer reused across all (c, k) on the compute stream.
k_chunk = self.k_chunk
if k_chunk <= 1:
grid = torch.empty(base.grid_size, dtype=dtype, device=comp_device)
else:
grid = None
# Double-buffer the per-coil sparse + smaps slabs on the data stream.
buf_sparse: list[torch.Tensor | None] = [None, None]
buf_smaps: list[torch.Tensor | None] = [None, None]
with torch.cuda.stream(s_data):
buf_sparse[0] = sparse_pin[0].to(
comp_device, dtype=dtype, non_blocking=True
)
buf_smaps[0] = smaps_cpu[0].to(comp_device, dtype=dtype, non_blocking=True)
for c in range(n_coils):
cur = c % 2
nxt = 1 - cur
if c + 1 < n_coils:
with torch.cuda.stream(s_data):
buf_sparse[nxt] = sparse_pin[c + 1].to(
comp_device, dtype=dtype, non_blocking=True
)
buf_smaps[nxt] = smaps_cpu[c + 1].to(
comp_device, dtype=dtype, non_blocking=True
)
# Compute on s_comp; ensure cur transfer is visible there.
s_comp.wait_stream(s_data)
with torch.cuda.stream(s_comp):
smap_conj_c = buf_smaps[cur].conj()
if k_chunk <= 1:
# Sort hoisting: one big gather per coil; per-k is a
# T-indexed basis lookup + element-wise multiply.
t_idx_sorted = self._get_t_idx_sorted(sort_perm, s_flat_idx)
sw_c_sorted = (buf_sparse[cur] * pre_w_gpu).reshape(-1)[sort_perm]
for k in range(K):
weighted_sorted = basis_gpu[k][t_idx_sorted] * sw_c_sorted
img_k = base._scatter_ifft_crop(
weighted_sorted, indices, sqrt_w, grid, dtype
)
output_gpu[k].addcmul_(img_k, smap_conj_c)
else:
sw_c = buf_sparse[cur] * pre_w_gpu # (*nat)
smap_conj_b = smap_conj_c.unsqueeze(0)
for k0 in range(0, K, k_chunk):
k1 = min(k0 + k_chunk, K)
kb = k1 - k0
weighted = basis_gpu[k0:k1].view(
kb, *phi_shape
) * sw_c.unsqueeze(0)
weighted_flat = weighted.reshape(kb, -1)
imgs = base._scatter_ifft_crop_batch(
weighted_flat, s_flat_idx=s_flat_idx
)
output_gpu[k0:k1].addcmul_(imgs, smap_conj_b)
torch.cuda.synchronize(comp_device)
return output_gpu.to(sparse_kspace.device)
def _forward_single_dual(
self, coeffs: torch.Tensor, s_flat_idx: int = 0
) -> torch.Tensor:
"""Coil-pipelined forward: per-coil per-K FFT+gather on the
compute stream, async D2H of the result on the data stream.
Input: ``(K, *image_shape)`` on CPU. Output: ``(C, *natural)`` on CPU."""
base = self._base
nat = base.natural_shape
nat_ndim = len(nat)
comp_device = base.device
K, T = self.K, self.T
dtype = coeffs.dtype
if base.smaps is None:
raise NotImplementedError("SubspaceSparseFFT requires base_op.smaps")
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = T
basis_conj_gpu = self.basis.conj().to(comp_device, dtype=dtype)
idx_s, sqw_s, _, ip_s = base._stack_arrays(s_flat_idx)
indices = idx_s.to(comp_device)
sqrt_w = sqw_s.to(comp_device)
inv_perm = ip_s.to(comp_device)
pre_w_gpu = sqw_s[ip_s].to(device=comp_device, dtype=dtype).view(*nat)
# Coeffs are constant across coils — stage once.
coeffs_gpu = coeffs.to(comp_device, dtype=dtype, non_blocking=True)
smaps_cpu = base.smaps.to(dtype=dtype)
if smaps_cpu.device.type == "cpu" and not smaps_cpu.is_pinned():
with contextlib.suppress(RuntimeError):
smaps_cpu = smaps_cpu.pin_memory()
n_coils = int(smaps_cpu.shape[0])
# Pinned destination so D2H copy_(non_blocking=True) is truly async.
try:
output_cpu = torch.empty(n_coils, *nat, dtype=dtype, pin_memory=True)
except RuntimeError:
output_cpu = torch.empty(n_coils, *nat, dtype=dtype)
s_data = torch.cuda.Stream(device=comp_device)
s_comp = torch.cuda.Stream(device=comp_device)
# Single padded buffer reused across all (c, k) on the compute stream.
k_chunk = self.k_chunk
if k_chunk <= 1:
padded = torch.empty(*base.grid_shape, dtype=dtype, device=comp_device)
else:
padded = None
buf_smaps: list[torch.Tensor | None] = [None, None]
ksp_buf: list[torch.Tensor | None] = [None, None]
with torch.cuda.stream(s_data):
buf_smaps[0] = smaps_cpu[0].to(comp_device, dtype=dtype, non_blocking=True)
for c in range(n_coils):
cur = c % 2
nxt = 1 - cur
if c + 1 < n_coils:
with torch.cuda.stream(s_data):
buf_smaps[nxt] = smaps_cpu[c + 1].to(
comp_device, dtype=dtype, non_blocking=True
)
s_comp.wait_stream(s_data)
with torch.cuda.stream(s_comp):
smap_c = buf_smaps[cur]
ksp_c = torch.zeros(*nat, dtype=dtype, device=comp_device)
if k_chunk <= 1:
for k in range(K):
coil_img_k = coeffs_gpu[k] * smap_c
gathered_k = base._fft_pad_gather(
coil_img_k, indices, sqrt_w, inv_perm, padded, dtype
)
ksp_c.add_(
basis_conj_gpu[k].view(*phi_shape)
* gathered_k.reshape(*nat)
)
else:
for k0 in range(0, K, k_chunk):
k1 = min(k0 + k_chunk, K)
kb = k1 - k0
coil_imgs = coeffs_gpu[k0:k1] * smap_c.unsqueeze(0)
gathered = base._fft_pad_gather_batch(
coil_imgs, s_flat_idx=s_flat_idx
)
gathered_nat = gathered.reshape(kb, *nat)
ksp_c.add_(
(
basis_conj_gpu[k0:k1].view(kb, *phi_shape)
* gathered_nat
).sum(dim=0)
)
ksp_buf[cur] = ksp_c * pre_w_gpu
# Async D2H on the data stream while the next coil computes.
s_data.wait_stream(s_comp)
with torch.cuda.stream(s_data):
output_cpu[c].copy_(ksp_buf[cur], non_blocking=True)
torch.cuda.synchronize(comp_device)
return output_cpu.to(coeffs.device)
@with_torch
def normal(self, coeffs):
if self.toeplitz:
if self._toep_op is None:
from .._toep._sub_toep import SubspaceToeplitzOp
self._toep_op = SubspaceToeplitzOp(
self,
device=self._base.device,
)
return self._toep_op(coeffs)
return self._adjoint_impl(self._forward_impl(coeffs))
def __call__(self, x, adjoint=False):
if adjoint:
return self.adjoint(x)
return self.forward(x)
# =====================================================================
# MaskedFFT decorator
# =====================================================================
class SubspaceMaskedFFT(SolveMixin):
"""MaskedFFT with low-rank subspace projection (loop-fused).
Mirrors :class:`SubspaceSparseFFT` but operates on pre-gridded
k-space data via :class:`~pygrog.operator.MaskedFFT`.
Adjoint (gridded k-space → subspace coefficients), per coil:
1. for each ``k``: multiply by ``basis[k]`` along the T axis;
2. ONE batched K-IFFT + mask + center-crop;
3. fused FMA with ``smaps[c].conj()`` into the accumulator.
Forward (subspace coefficients → gridded k-space), per coil:
1. multiply coefficients by ``smaps[c]``;
2. ONE batched K-FFT + center-pad + mask;
3. for each ``k``: accumulate ``basis.conj()[k] * masked_grid``.
Parameters
----------
base_op : MaskedFFT
Must have sensitivity maps (``smaps``) attached and a multi-dim
``natural_shape`` covering the grid layout (e.g. ``(T, gy, gx)``
for a 2D+T acquisition).
subspace_basis : torch.Tensor
``(K, T)`` complex basis.
encoding_axis : int
Axis (in full grid layout) of the temporal dimension ``T``.
Default ``-3`` (last three axes are ``(T, gy, gx)`` for 2D).
"""
def __init__(
self, base_op, subspace_basis, encoding_axis: int = -3, *, toeplitz=None
):
self._base = base_op
self.basis = torch.as_tensor(subspace_basis) # (K, T)
self.K, self.T = self.basis.shape
self.encoding_axis = encoding_axis
self.grid_shape = base_op.grid_shape
self.image_shape = base_op.image_shape
self.smaps = getattr(base_op, "smaps", None)
# Position of T inside natural_shape (i.e. grid_shape for MaskedFFT).
nat_ndim = len(base_op.natural_shape)
ax = encoding_axis if encoding_axis >= 0 else encoding_axis + (1 + nat_ndim)
self._t_axis_in_nat = ax - 1
if not (0 <= self._t_axis_in_nat < nat_ndim):
raise ValueError(
f"encoding_axis={encoding_axis} does not land inside natural_shape "
f"{base_op.natural_shape} (computed nat-axis {self._t_axis_in_nat})"
)
if base_op.natural_shape[self._t_axis_in_nat] != self.T:
raise ValueError(
f"basis T={self.T} does not match natural_shape"
f"[{self._t_axis_in_nat}]={base_op.natural_shape[self._t_axis_in_nat]}"
)
if toeplitz is None:
toeplitz = bool(getattr(base_op, "toeplitz", False))
self.toeplitz = bool(toeplitz)
self._toep_op = None
# ------------------------------------------------------------------
# adjoint: gridded k-space → subspace coefficient images (A^H)
# ------------------------------------------------------------------
@with_torch
def adjoint(self, kspace_grid: torch.Tensor) -> torch.Tensor:
"""Gridded k-space → subspace coefficient images (``A^H``)."""
return self._adjoint_impl(kspace_grid)
@with_torch
def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
"""Subspace coefficient images → gridded k-space (``A``)."""
return self._forward_impl(coeffs)
# ==================================================================
# implementation
# ==================================================================
def _adjoint_impl(self, kspace_grid: torch.Tensor) -> torch.Tensor:
"""Gridded k-space → subspace coefficients.
Accepted layouts:
- ``(*B, *S, C, *grid_shape)``
- ``(C, *grid_shape)`` (single frame)
Output: ``(*B, *S, K, *image_shape)``.
"""
base = self._base
nat = base.natural_shape # == grid_shape for MaskedFFT
nat_ndim = len(nat)
s_shape = tuple(getattr(base, "stack_shape", ()) or ())
s_ndim = len(s_shape)
expected_trailing = 1 + nat_ndim # (C, *grid_shape)
prefix = tuple(int(s) for s in kspace_grid.shape[:-expected_trailing])
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"kspace_grid prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
if not prefix:
return self._adjoint_single(kspace_grid, 0)
B_total = int(np.prod(B_shape)) if B_shape else 1
S_total = int(np.prod(s_shape)) if s_shape else 1
flat = kspace_grid.reshape(
B_total, S_total, *kspace_grid.shape[-expected_trailing:]
)
outs = []
for b in range(B_total):
for s in range(S_total):
outs.append(self._adjoint_single(flat[b, s], s))
stacked = torch.stack(outs, dim=0)
return stacked.reshape(*B_shape, *s_shape, self.K, *base.image_shape)
def _adjoint_single(self, kspace_grid: torch.Tensor, s_flat_idx: int = 0):
"""Single-frame adjoint. Input: ``(C, *grid_shape)``."""
base = self._base
nat = base.natural_shape # == grid_shape
nat_ndim = len(nat)
device = kspace_grid.device
dtype = kspace_grid.dtype
n_coils = int(kspace_grid.shape[0])
if base.smaps is None:
raise NotImplementedError("SubspaceMaskedFFT requires base_op.smaps")
smaps = base.smaps.to(device, dtype=dtype)
basis = self.basis.to(device, dtype=dtype) # (K, T)
K = self.K
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = self.T
output = torch.zeros(K, *base.image_shape, dtype=dtype, device=device)
for c in range(n_coils):
# kspace_grid[c]: (*grid_shape), expand with K-basis along T axis
kg_c = kspace_grid[c] # (*grid_shape)
# (K, *grid_shape): multiply each basis vector against the grid
weighted = basis.view(K, *phi_shape) * kg_c.unsqueeze(0)
# weighted: (K, *grid_shape) — already on the grid
imgs = base._mask_ifft_crop_batch(weighted, s_flat_idx=s_flat_idx)
output.addcmul_(imgs, smaps[c].conj().unsqueeze(0))
return output
def _forward_impl(self, coeffs: torch.Tensor) -> torch.Tensor:
"""Subspace coefficients → gridded k-space.
Accepted layouts:
- ``(*B, *S, K, *image_shape)``
- ``(K, *image_shape)`` (single frame)
Output: ``(*B, *S, C, *grid_shape)``.
"""
base = self._base
nat = base.natural_shape
s_shape = tuple(getattr(base, "stack_shape", ()) or ())
s_ndim = len(s_shape)
img_ndim = len(base.image_shape)
single_ndim = 1 + img_ndim
prefix = tuple(int(s) for s in coeffs.shape[: coeffs.ndim - single_ndim])
if s_ndim:
if len(prefix) < s_ndim or tuple(prefix[-s_ndim:]) != s_shape:
raise ValueError(
f"coeffs prefix {prefix} must end with stack_shape {s_shape}"
)
B_shape = prefix[:-s_ndim]
else:
B_shape = prefix
if not prefix:
return self._forward_single(coeffs, 0)
B_total = int(np.prod(B_shape)) if B_shape else 1
S_total = int(np.prod(s_shape)) if s_shape else 1
flat = coeffs.reshape(B_total, S_total, *coeffs.shape[-single_ndim:])
outs = []
for b in range(B_total):
for s in range(S_total):
outs.append(self._forward_single(flat[b, s], s))
n_coils = outs[0].shape[0]
stacked = torch.stack(outs, dim=0)
return stacked.reshape(*B_shape, *s_shape, n_coils, *nat)
def _forward_single(self, coeffs: torch.Tensor, s_flat_idx: int = 0):
"""Single-frame forward. Input: ``(K, *image_shape)``, output: ``(C, *grid_shape)``."""
base = self._base
nat = base.natural_shape # == grid_shape
nat_ndim = len(nat)
if coeffs.shape[0] != self.K:
raise ValueError(f"coeffs.shape[0]={coeffs.shape[0]} != K={self.K}")
if tuple(int(s) for s in coeffs.shape[1:]) != tuple(base.image_shape):
raise ValueError(
f"coeffs spatial {tuple(coeffs.shape[1:])} != image_shape {base.image_shape}"
)
device = coeffs.device
dtype = coeffs.dtype
if base.smaps is None:
raise NotImplementedError("SubspaceMaskedFFT requires base_op.smaps")
smaps = base.smaps.to(device, dtype=dtype)
n_coils = int(smaps.shape[0])
basis_conj = self.basis.conj().to(device, dtype=dtype) # (K, T)
K = self.K
phi_shape = [1] * nat_ndim
phi_shape[self._t_axis_in_nat] = self.T
output = torch.empty(n_coils, *nat, dtype=dtype, device=device)
for c in range(n_coils):
coil_imgs = coeffs * smaps[c].unsqueeze(0) # (K, *image_shape)
# FFT + pad + mask → (K, *grid_shape)
kgrids = base._fft_pad_mask_batch(coil_imgs, s_flat_idx=s_flat_idx)
# Accumulate over K: sum_k basis_conj[k, T-dim] * kgrids[k]
# basis_conj: (K, T) reshaped to (K, *phi_shape)
ksp_c = (basis_conj.view(K, *phi_shape) * kgrids).sum(dim=0)
output[c] = ksp_c
return output
@with_torch
def normal(self, coeffs):
"""Normal operator: ``A^H A x``."""
if self.toeplitz:
if self._toep_op is None:
from .._toep._sub_toep import SubspaceToeplitzOp
self._toep_op = SubspaceToeplitzOp(self, device=self._base.device)
return self._toep_op(coeffs)
return self._adjoint_impl(self._forward_impl(coeffs))
def __call__(self, x, adjoint=False):
if adjoint:
return self.adjoint(x)
return self.forward(x)