Batched and Stacked Reconstructions with PyGROG

This example focuses on stack/batch behavior and trajectory handling.

We demonstrate three cases:

  1. Same trajectory for all images in a stack (vectorized in one call).

  2. Different trajectory for each image in the stack (one stacked plan, applied in a single call).

  3. Two stack axes: one shared trajectory axis and one per-image trajectory axis.

import matplotlib.pyplot as plt
import numpy as np

from brainweb_dl import get_mri

from mrinufft import get_operator, initialize_2D_spiral
from mrinufft.density import voronoi

from pygrog.calib import GrogInterpolator
from pygrog.operator import SparseFFT
# Load the BrainWeb T1 volume that every case below slices its test images from
# (the first argument, 0, is presumably the subject/phantom index — TODO confirm).
vol = get_mri(0, "T1")

Case 1: Shared trajectory across one stack axis

When all images in a stack use the same trajectory, we can initialize a single GrogInterpolator and call interpolate() once on k-space data that carries a leading batch dimension.

# Case 1: B images, all acquired with one shared trajectory.
# NOTE(review): samples_base, shape, n_coils, smaps, n_shots, n_read and
# calib_cart are not defined in the visible part of this example; they must be
# set up before this point.
B = 3
# B consecutive slices of the volume, starting at index 88 along its first axis.
images_case1 = vol[88 : 88 + B]

# Density weights computed from the shared trajectory, passed to the NUFFT below.
density_base = voronoi(samples_base)
# A single NUFFT operator (finufft backend) reused to simulate k-space for
# every image of the stack.
nufft_shared = get_operator("finufft")(
    samples=samples_base,
    shape=shape,
    n_coils=n_coils,
    smaps=smaps,
    density=density_base,
    squeeze_dims=True,
)

# Simulate k-space one image at a time, then stack the results along a new
# leading batch axis.
ksp_case1 = np.stack(
    [nufft_shared.op(images_case1[b].astype(np.complex64)) for b in range(B)],
    axis=0,
)  # (B, C, n_samples)
# Split the flat sample axis into (n_shots, n_read).
ksp_case1 = ksp_case1.reshape(B, n_coils, n_shots, n_read)


# Scale the trajectory by the image shape before handing it to GROG
# (presumably converts normalized samples to grid units — TODO confirm).
coords_base = (samples_base * np.asarray(shape, dtype=np.float32)).astype(np.float32)
grog_shared = GrogInterpolator(
    shape=shape,
    coords=coords_base,
    kernel_width=2,
    oversamp=1.25,
    image_shape=shape,
)
# Build the interpolation table from calib_cart; the exact meaning of `lamda`
# and `precision` is defined by pygrog and not visible here.
grog_shared.calc_interp_table(calib_cart, lamda=0.01, precision=1)

# One interpolate() call handles the whole batch. With ret_image=False the
# output is presumably the interpolated k-space data rather than an image.
sparse_case1 = grog_shared.interpolate(ksp_case1, ret_image=False)
# Weight the interpolated data with the plan's pre-weights before the adjoint.
sqrt_w = np.asarray(grog_shared.plan.pre_weights)
sparse_case1_w = sparse_case1 * sqrt_w

# Reconstruct with the adjoint of the sparse FFT operator and keep the magnitude.
op_shared = SparseFFT(plan=grog_shared.plan, smaps=smaps)
recon_case1 = np.abs(op_shared.adjoint(sparse_case1_w))

print(f"Case 1 (shared trajectory): {recon_case1.shape}")
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
Case 1 (shared trajectory): (3, 217, 181)
# 2 x B grid of axes; judging by the figure caption, one row shows the ground
# truth and the other the reconstructions (the plotting calls are not visible here).
fig, axes = plt.subplots(2, B, figsize=(4 * B, 6), constrained_layout=True)
GT #1, GT #2, GT #3, Shared traj recon #1, Shared traj recon #2, Shared traj recon #3

Case 2: Different trajectory per image in one stack axis

A different trajectory for each image is still handled in a single stacked call. coords_case2 carries a non-singleton stack axis (B), so GROG builds one stacked plan and applies all trajectory variants in a single interpolate() call.

# Case 2: B images, each paired with its own trajectory.
# B consecutive slices of the volume, starting at index 92 along its first axis.
images_case2 = vol[92 : 92 + B]
# B evenly spaced angles from 0 to pi/6; presumably used to rotate the base
# trajectory for each image (the code that consumes them is not visible here).
angles = np.linspace(0.0, np.pi / 6.0, B, dtype=np.float32)


# NOTE(review): coords_case2 and ksp_case2 are not defined in the visible part
# of this example; they must be built before this point.
grog_case2 = GrogInterpolator(
    shape=shape,
    coords=coords_case2,
    kernel_width=2,
    oversamp=1.25,
    image_shape=shape,
)
# Same calibration settings as used for the other cases.
grog_case2.calc_interp_table(calib_cart, lamda=0.01, precision=1)

# One interpolate() call covers every trajectory variant in the stack.
sparse_case2 = grog_case2.interpolate(ksp_case2, ret_image=False)
# Insert a length-1 axis at position 1 so the pre-weights broadcast against the
# interpolated data (presumably across the coil axis — TODO confirm).
pre_w_case2 = np.asarray(grog_case2.plan.pre_weights)[:, np.newaxis, :]
sparse_case2_w = sparse_case2 * pre_w_case2

# Reconstruct with the adjoint of the sparse FFT operator and keep the magnitude.
op_case2 = SparseFFT(plan=grog_case2.plan, smaps=smaps)
recon_case2 = np.abs(op_case2.adjoint(sparse_case2_w))
print(f"Case 2 (per-image trajectory): {recon_case2.shape}")
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
Case 2 (per-image trajectory): (3, 217, 181)
# 2 x B grid of axes; judging by the figure caption, one row shows the ground
# truth and the other the reconstructions (the plotting calls are not visible here).
fig, axes = plt.subplots(2, B, figsize=(4 * B, 6), constrained_layout=True)
GT #1, GT #2, GT #3, Per-image traj recon #1, Per-image traj recon #2, Per-image traj recon #3

Case 3: Two stack axes (T, B)

This case demonstrates 2D stacking over T (time) and B (trajectory variants). coords_case3 is stacked along B only, while the data carries both the T and B axes. T is treated as a batch axis and B as the stacked-trajectory axis, all in a single call.

# Case 3: a (T, B2) grid of images.
T = 2
B2 = 3
# T * B2 consecutive slices starting at index 96, arranged as a (T, B2) grid of
# images (assumes each slice has shape `shape`, which is defined elsewhere).
images_case3 = vol[96 : 96 + T * B2].reshape(T, B2, *shape)
# B2 evenly spaced angles from 0 to pi/5; presumably one rotation per
# trajectory variant (the code that consumes them is not visible here).
angles_b = np.linspace(0.0, np.pi / 5.0, B2, dtype=np.float32)


# NOTE(review): coords_case3 and ksp_case3 are not defined in the visible part
# of this example; they must be built before this point.
grog_case3 = GrogInterpolator(
    shape=shape,
    coords=coords_case3,
    kernel_width=2,
    oversamp=1.25,
    image_shape=shape,
)
# Same calibration settings as used for the other cases.
grog_case3.calc_interp_table(calib_cart, lamda=0.01, precision=1)

# One interpolate() call covers both stack axes.
sparse_case3 = grog_case3.interpolate(ksp_case3, ret_image=False)
# Add length-1 axes at positions 0 and 2 so the pre-weights broadcast against
# the interpolated data (presumably over the T and coil axes — TODO confirm).
pre_w_case3 = np.asarray(grog_case3.plan.pre_weights)[np.newaxis, :, np.newaxis, :]
sparse_case3_w = sparse_case3 * pre_w_case3

# Reconstruct with the adjoint of the sparse FFT operator and keep the magnitude.
op_case3 = SparseFFT(plan=grog_case3.plan, smaps=smaps)
recon_case3 = np.abs(op_case3.adjoint(sparse_case3_w))

print(f"Case 3 (2D stack TxB): {recon_case3.shape}")
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/pygrog/envs/latest/lib/python3.12/site-packages/mrinufft/_utils.py:67: UserWarning: Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)
  warnings.warn(
Case 3 (2D stack TxB): (2, 3, 217, 181)
# T x B2 grid of axes, one panel per (T, B) combination according to the figure
# caption (the plotting calls are not visible here).
fig, axes = plt.subplots(T, B2, figsize=(4 * B2, 3.5 * T), constrained_layout=True)
T=0, B=0, T=0, B=1, T=0, B=2, T=1, B=0, T=1, B=1, T=1, B=2

Total running time of the script: (0 minutes 4.690 seconds)

Gallery generated by Sphinx-Gallery