Toeplitz-Embedded Self-Adjoint A^H A for SparseFFT Gadgets#

Iterative reconstructions such as CG and FISTA evaluate the normal operator A^H A at every iteration. PyGROG ships a Toeplitz-embedded shortcut for this: op.normal(image) builds a small PSF once, then replaces every forward+adjoint NUFFT pair with a pad → FFT → multiply → IFFT → crop sequence.

The example has two parts:

  1. End-to-end pipeline: a realistic BrainWeb / spiral / GROG / SparseFFT acquisition, a CG reconstruction using op.normal, and a runtime comparison against the nested adjoint(forward(.)) baseline.

  2. Accuracy panel for the gadgets: small synthetic problems verifying that the Toeplitz versions of the off-resonance correction (ORC) and subspace operators match the nested baseline to within floating-point round-off.

The Toeplitz path is the default on CPU and is opt-in on CUDA via toeplitz=True.
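
The pad → FFT → multiply → IFFT → crop sequence can be illustrated with a minimal 1D NumPy sketch. This is not PyGROG's implementation: the sizes, the helper names (`pad`, `forward`, `adjoint`, `normal_toeplitz`) and the assumption that every sample lies exactly on an oversampled Cartesian grid are chosen purely for illustration. Under that assumption, the gather/scatter pair in the middle of A^H A collapses into a diagonal k-space kernel that can be built once.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 16, 32                          # image size and oversampled grid size (hypothetical)
idx = rng.integers(0, m, size=200)     # grid index of each sample, repeats allowed
w = rng.random(200) + 0.1              # per-sample density weights


def pad(x):
    """Zero-pad an image of length n onto the grid of length m."""
    out = np.zeros(m, dtype=complex)
    out[:n] = x
    return out


def forward(x):
    """A x: pad, FFT, gather the sampled grid points, apply sqrt weights."""
    return np.sqrt(w) * np.fft.fft(pad(x), norm="ortho")[idx]


def adjoint(y):
    """A^H y: apply sqrt weights, scatter-add onto the grid, IFFT, crop."""
    g = np.zeros(m, dtype=complex)
    np.add.at(g, idx, np.sqrt(w) * y)
    return np.fft.ifft(g, norm="ortho")[:n]


# Built once: summed weight landing on each grid point (the k-space kernel).
kernel = np.bincount(idx, weights=w, minlength=m)


def normal_toeplitz(x):
    """A^H A x as pad -> FFT -> multiply -> IFFT -> crop."""
    k = np.fft.fft(pad(x), norm="ortho")
    return np.fft.ifft(kernel * k, norm="ortho")[:n]


x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
y_nested = adjoint(forward(x))
y_toep = normal_toeplitz(x)
rel_err = np.linalg.norm(y_toep - y_nested) / np.linalg.norm(y_nested)
print(f"rel-err {rel_err:.2e}")
```

In this sketch the two paths agree to double-precision round-off, because the kernel is exactly the diagonal of the scatter-after-gather operation.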

import time
import types

import matplotlib.pyplot as plt
import numpy as np
import torch

from brainweb_dl import get_mri

from mrinufft import get_operator, initialize_2D_spiral
from mrinufft.density import voronoi

from pygrog.calib import GrogInterpolator
from pygrog.gadgets import SubspaceGadget
from pygrog.gadgets._off_resonance import OffResonanceSparseFFT
from pygrog.operator import SparseFFT

Part 1: End-to-End SparseFFT with CG Reconstruction#

Demonstrate the Toeplitz-embedded normal() operator for fast \(A^H A\) matrix-vector products, then use it in a CG solver.

image = np.flip(get_mri(0, "T1"), axis=(0, 1, 2))[90].astype(np.float32)

Benchmark: Normal Operator Performance#

Compare the runtime of op.normal(x) with Toeplitz acceleration against the nested evaluation of A^H A, adjoint(forward(x)), on the same random input.
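
The timing and error helpers used in the benchmark are not shown in this section. The following is a hedged sketch of what such helpers could look like, inferred only from how they are called (a callable plus an input, returning a time in seconds and the output). The names `bench` and `rel_err`, the `n_warmup` argument and the NumPy-based norm are assumptions.

```python
import time

import numpy as np


def bench(fn, x, n_iter=10, n_warmup=1):
    """Return (mean seconds per call, last output) for fn(x)."""
    y = None
    for _ in range(n_warmup):
        y = fn(x)  # warm-up call, so one-time setup is not included in the timing
    t0 = time.perf_counter()
    for _ in range(n_iter):
        y = fn(x)
    return (time.perf_counter() - t0) / n_iter, y


def rel_err(a, b):
    """Relative l2 error ||a - b|| / ||b|| over all elements."""
    a = np.asarray(a)
    b = np.asarray(b)
    return float(np.linalg.norm(a - b) / np.linalg.norm(b))
```

The warm-up call matters for a Toeplitz path that builds its kernel lazily on first use; without it, the one-time construction cost would be folded into the per-call average. On CUDA, a helper like this would also need to synchronize the device before reading the clock.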

# A^H A timing on a random image
torch.manual_seed(0)
x = torch.randn(*shape, dtype=torch.complex64)
t_toep, y_toep = _bench(op_t.normal, x)
t_nest, y_nest = _bench(op_n.normal, x)
err_sparse = _rel_err(y_toep, y_nest)

print(
    f"SparseFFT  Toeplitz {t_toep * 1e3:7.2f} ms  |  nested {t_nest * 1e3:7.2f} ms"
    f"  | speed-up x{t_nest / t_toep:5.2f}  | rel-err {err_sparse:.2e}"
)
SparseFFT  Toeplitz   14.95 ms  |  nested   17.98 ms  | speed-up x 1.20  | rel-err 2.01e-07

CG Reconstruction Using Toeplitz Normal Operator#
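
The `cg` solver called in this section is not shown. A minimal conjugate-gradient sketch with the same call shape, `cg(normal, b, n_iter=...)`, might look like the following. It is written in NumPy for brevity, whereas the example operates on torch tensors, and the `tol` argument is a hypothetical addition.

```python
import numpy as np


def cg(normal, b, n_iter=12, tol=1e-12):
    """Solve normal(x) = b for a Hermitian positive-definite `normal`."""
    x = np.zeros_like(b)
    r = b.copy()                      # residual b - normal(x) with x = 0
    p = r.copy()                      # first search direction
    rs = np.vdot(r, r).real
    b_norm = np.sqrt(np.vdot(b, b).real)
    for _ in range(n_iter):
        if np.sqrt(rs) <= tol * b_norm:
            break                     # converged (also covers b == 0)
        Ap = normal(p)                # the only operator call per iteration
        alpha = rs / np.vdot(p, Ap).real
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.vdot(r, r).real
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Usage on a small dense Hermitian positive-definite system.
rng = np.random.default_rng(0)
M = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
A = M.conj().T @ M + 8 * np.eye(8)
x_true = rng.standard_normal(8) + 1j * rng.standard_normal(8)
x_cg = cg(lambda v: A @ v, A @ x_true, n_iter=50)
```

Each iteration calls the normal operator exactly once, which is why a faster op.normal translates almost directly into a faster reconstruction.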

t0 = time.perf_counter()
recon_toep = cg(op_t.normal, b, n_iter=12)
t_cg_toep = time.perf_counter() - t0
t0 = time.perf_counter()
recon_nest = cg(op_n.normal, b, n_iter=12)
t_cg_nest = time.perf_counter() - t0

print(
    f"CG (12 iter):  Toeplitz {t_cg_toep:.2f} s  |  nested {t_cg_nest:.2f} s"
    f"  | speed-up x{t_cg_nest / t_cg_toep:.2f}"
)
CG (12 iter):  Toeplitz 0.19 s  |  nested 0.23 s  | speed-up x1.22
recon_t_np = np.abs(recon_toep.detach().cpu().numpy())
Figure panels: CG (Toeplitz), 0.19 s; CG (nested A^H A), 0.23 s; Difference.

Part 2: Accuracy of the Gadget Toeplitz Operators#

This part sets up small random problems that use OffResonanceSparseFFT and SubspaceGadget directly, and verifies that op.normal returns the same result, up to single-precision round-off, with toeplitz=True and toeplitz=False.
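
For the off-resonance case, one common formulation approximates the phase term with L temporal coefficient vectors and L spatial maps, which matches the shapes of `B` and `C` in the code below. That this is the model used by OffResonanceSparseFFT is an assumption here, not something confirmed from the PyGROG source. Writing \(A_0\) for the base SparseFFT operator, \(B_{:,l}\) for column \(l\) of `B` and \(C_l\) for the \(l\)-th map, the normal operator expands as:

```latex
A = \sum_{l=1}^{L} \operatorname{diag}(B_{:,l}) \, A_0 \, \operatorname{diag}(C_l),
\qquad
A^H A = \sum_{l=1}^{L} \sum_{l'=1}^{L}
  \operatorname{diag}(C_l)^H \,
  \Bigl[ A_0^H \operatorname{diag}\bigl(\overline{B_{:,l}} \odot B_{:,l'}\bigr) A_0 \Bigr] \,
  \operatorname{diag}(C_{l'})
```

Each bracketed factor has the same structure as the plain SparseFFT normal operator with modified (generally complex) sample weights, so under this model it can be applied with its own precomputed kernel, giving \(L^2\) kernels in total.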

# ---- ORC ----------------------------------------------------------------
grid = (32, 32)
n_samples = 800
L = 4
op_b_t = _make_small_sparse_fft(grid, n_samples, 4, toeplitz=True, seed=0)
op_b_n = _make_small_sparse_fft(grid, n_samples, 4, toeplitz=False, seed=0)
rng = np.random.default_rng(11)
B = (
    rng.standard_normal((n_samples, L)) + 1j * rng.standard_normal((n_samples, L))
).astype(np.complex64)
C = (rng.standard_normal((L, *grid)) + 1j * rng.standard_normal((L, *grid))).astype(
    np.complex64
)
orc_t = OffResonanceSparseFFT(op_b_t, B, C, toeplitz=True)
orc_n = OffResonanceSparseFFT(op_b_n, B, C, toeplitz=False)
x_orc = torch.randn(*grid, dtype=torch.complex64)
t_toep, y_toep = _bench(orc_t.normal, x_orc, n_iter=5)
t_nest, y_nest = _bench(orc_n.normal, x_orc, n_iter=5)
err_orc = _rel_err(y_toep, y_nest)
print(
    f"ORC (L={L})   Toeplitz {t_toep * 1e3:7.2f} ms  |  nested {t_nest * 1e3:7.2f} ms"
    f"  | speed-up x{t_nest / t_toep:5.2f}  | rel-err {err_orc:.2e}"
)

# ---- Subspace -----------------------------------------------------------
T = 12
K = 4
n_pts = 60
n_samples_sub = T * n_pts
rng = np.random.default_rng(13)
indices = rng.integers(0, int(np.prod(grid)), n_samples_sub).astype(np.int64)
weights = rng.random(n_samples_sub).astype(np.float32) + 0.1
smaps_sub = (
    rng.standard_normal((4, *grid)) + 1j * rng.standard_normal((4, *grid))
).astype(np.complex64) * 0.5
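indices_t = torch.as_tensor(indices)
weights_t = torch.as_tensor(weights)
# Sort samples by grid index; sort_perm maps sorted order -> acquisition order.
sort_perm = torch.argsort(indices_t)
# inv_perm is the inverse permutation (acquisition order -> sorted order).
inv_perm = torch.empty_like(sort_perm)
inv_perm[sort_perm] = torch.arange(n_samples_sub)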
plan = types.SimpleNamespace(
    grid_shape=grid,
    image_shape=grid,
    grid_size=int(np.prod(grid)),
    indices=indices_t[sort_perm],
    sqrt_weights=weights_t.sqrt()[sort_perm],
    sort_perm=sort_perm,
    inv_perm=inv_perm,
    natural_shape=(T, n_pts),
    n_samples=n_samples_sub,
)
base_sub_t = SparseFFT(plan=plan, smaps=smaps_sub, toeplitz=True)
base_sub_n = SparseFFT(plan=plan, smaps=smaps_sub, toeplitz=False)
basis = (rng.standard_normal((K, T)) + 1j * rng.standard_normal((K, T))).astype(
    np.complex64
)
sub_t = SubspaceGadget(base_sub_t, basis, encoding_axis=-2)
sub_n = SubspaceGadget(base_sub_n, basis, encoding_axis=-2)
x_sub = torch.randn(K, *grid, dtype=torch.complex64)
t_toep, y_toep = _bench(sub_t.normal, x_sub, n_iter=5)
t_nest, y_nest = _bench(sub_n.normal, x_sub, n_iter=5)
err_sub = _rel_err(y_toep, y_nest)
print(
    f"Subspace (K={K}) Toeplitz {t_toep * 1e3:7.2f} ms  |  nested {t_nest * 1e3:7.2f} ms"
    f"  | speed-up x{t_nest / t_toep:5.2f}  | rel-err {err_sub:.2e}"
)
ORC (L=4)   Toeplitz    0.57 ms  |  nested    0.95 ms  | speed-up x 1.66  | rel-err 1.42e-07
Subspace (K=4) Toeplitz    0.55 ms  |  nested    4.60 ms  | speed-up x 8.44  | rel-err 1.78e-07
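
A plausible reading of the larger subspace speed-up is that the nested path loops over all T frames, while a Toeplitz path only needs K × K precomputed kernels. The NumPy sketch below illustrates this under stated assumptions: each frame has its own gridded sampling pattern, frame images are formed as `sum_k phi[k, t] * x[k]`, and padding is omitted for brevity. This convention, the 1D sizes and all names are illustrative and not taken from SubspaceGadget.

```python
import numpy as np

rng = np.random.default_rng(0)
m, T, K, n_pts = 32, 12, 4, 60                  # hypothetical sizes
idx = rng.integers(0, m, size=(T, n_pts))       # grid index of each sample, per frame
w = rng.random((T, n_pts)) + 0.1                # density weights, per frame
phi = rng.standard_normal((K, T)) + 1j * rng.standard_normal((K, T))  # basis


def normal_nested(x):
    """A^H A x with an explicit loop over the T frames; x has shape (K, m)."""
    out = np.zeros_like(x)
    for t in range(T):
        frame = phi[:, t] @ x                                        # expand to frame t
        y = np.sqrt(w[t]) * np.fft.fft(frame, norm="ortho")[idx[t]]  # forward
        g = np.zeros(m, dtype=complex)
        np.add.at(g, idx[t], np.sqrt(w[t]) * y)                      # adjoint scatter
        back = np.fft.ifft(g, norm="ortho")
        out += np.conj(phi[:, t])[:, None] * back                    # project back
    return out


# Built once: per-frame diagonal kernels, then mixed into K x K kernels.
d = np.stack([np.bincount(idx[t], weights=w[t], minlength=m) for t in range(T)])
kernels = np.einsum("kt,jt,tm->kjm", np.conj(phi), phi, d)


def normal_toeplitz(x):
    """A^H A x as FFT -> K x K kernel multiply -> IFFT, with no frame loop."""
    X = np.fft.fft(x, axis=-1, norm="ortho")
    Y = np.einsum("kjm,jm->km", kernels, X)
    return np.fft.ifft(Y, axis=-1, norm="ortho")


x = rng.standard_normal((K, m)) + 1j * rng.standard_normal((K, m))
y_nested = normal_nested(x)
y_toep = normal_toeplitz(x)
rel_err = np.linalg.norm(y_toep - y_nested) / np.linalg.norm(y_nested)
print(f"rel-err {rel_err:.2e}")
```

In this sketch the per-call cost of the Toeplitz path depends on K but not on T, so the gain grows with the number of frames.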

Accuracy summary#

labels = ["SparseFFT", "OffResonance", "Subspace"]
errs = [err_sparse, err_orc, err_sub]
fig, ax = plt.subplots(figsize=(6, 4))
Figure: Toeplitz `op.normal` accuracy

Total running time of the script: (0 minutes 2.162 seconds)
