TorchKbNufft Documentation


About

torchkbnufft implements a non-uniform Fast Fourier Transform [1, 2] with Kaiser-Bessel gridding in PyTorch. The implementation is completely in Python, facilitating flexible deployment in readable code with no compilation. NUFFT functions are each wrapped as a torch.autograd.Function, allowing backpropagation through NUFFT operators for training neural networks.

This package was inspired in large part by the NUFFT implementation in the Michigan Image Reconstruction Toolbox (Matlab).

Installation

Simple installation can be done via PyPI:

pip install torchkbnufft

torchkbnufft only requires numpy, scipy, and torch as dependencies.

Operation Modes and Stages

The package has three major NUFFT operation modes: table-based NUFFT interpolation, sparse matrix-based NUFFT interpolation, and forward/backward operators with Toeplitz-embedded FFTs [3]. Table interpolation is the standard operation mode, whereas the Toeplitz method is always the fastest for forward/backward NUFFTs. For some problems, sparse matrices may be fast. It is generally best to start with table interpolation and then experiment with the other modes for your problem.
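
As a rough illustration of the idea behind table interpolation (not the package's implementation), the NumPy sketch below tabulates a 1D Kaiser-Bessel kernel once on a finely spaced grid and then replaces each kernel evaluation with a nearest-entry lookup. The kernel shape parameter follows the formula of Beatty et al. [2]; the function names and the table oversampling of 1024 entries per grid unit are assumptions chosen for the example.

```python
import numpy as np

def kb_beta(width=6, alpha=2.0):
    # Shape parameter suggested by Beatty et al. [2] for a kernel `width`
    # grid units wide and a grid oversampling ratio `alpha`.
    return np.pi * np.sqrt((width / alpha) ** 2 * (alpha - 0.5) ** 2 - 0.8)

def kb_kernel(u, width=6, beta=None):
    # Kaiser-Bessel kernel at offsets u (grid units); zero outside |u| <= width/2.
    beta = kb_beta(width) if beta is None else beta
    arg = 1.0 - (2.0 * np.asarray(u, dtype=np.float64) / width) ** 2
    return np.where(arg >= 0, np.i0(beta * np.sqrt(np.clip(arg, 0.0, None))), 0.0)

def make_table(width=6, table_oversamp=1024, beta=None):
    # The kernel is even, so tabulate it only on [0, width/2],
    # with `table_oversamp` entries per grid unit.
    u = np.arange(int(np.ceil(width / 2 * table_oversamp)) + 1) / table_oversamp
    return kb_kernel(u, width, beta)

def table_lookup(u, table, table_oversamp=1024):
    # Nearest-entry lookup; u must lie inside the kernel support |u| <= width/2.
    idx = np.rint(np.abs(u) * table_oversamp).astype(np.int64)
    return table[idx]

table = make_table()
u = np.linspace(-2.9, 2.9, 101)
rel_err = np.max(np.abs(table_lookup(u, table) - kb_kernel(u))) / kb_kernel(u).max()
```

A finer table costs memory but reduces the lookup error; either way, each kernel evaluation becomes an indexing operation, which is consistent with advanced indexing being a bottleneck in table mode.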

Sensitivity maps can be incorporated by passing them when calling a KbNufft or KbNufftAdjoint object. Auxiliary functions for calculating sparse interpolation matrices, density compensation functions, and Toeplitz filter kernels are also included.

For examples, see Basic Usage.

References

  1. Fessler, J. A., & Sutton, B. P. (2003). Nonuniform fast Fourier transforms using min-max interpolation. IEEE Transactions on Signal Processing, 51(2), 560-574.

  2. Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005). Rapid gridding reconstruction with a minimal oversampling ratio. IEEE Transactions on Medical Imaging, 24(6), 799-808.

  3. Feichtinger, H. G., Gröchenig, K., & Strohmer, T. (1995). Efficient numerical methods in non-uniform sampling theory. Numerische Mathematik, 69(4), 423-440.

Basic Usage

torchkbnufft works primarily via PyTorch modules. You create a module with the properties of your imaging setup. The module will calculate a Kaiser-Bessel kernel and some interpolation parameters based on your inputs. Then, you apply the module to your data stored as PyTorch tensors. NUFFT operations are wrapped in torch.autograd.Function classes for backpropagation and training neural networks.

The following code loads a Shepp-Logan phantom and computes a single radial spoke of k-space data:

import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom

x = shepp_logan_phantom().astype(np.complex128)
im_size = x.shape
# convert to tensor, unsqueeze batch and coil dimension
# output size: (1, 1, ny, nx)
x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)

klength = 64
ktraj = np.stack(
    (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
)
# convert to tensor (a single, unbatched trajectory)
# output size: (2, klength)
ktraj = torch.tensor(ktraj).to(torch.float)

nufft_ob = tkbn.KbNufft(im_size=im_size)
# outputs a (1, 1, klength) vector of k-space data
kdata = nufft_ob(x, ktraj)
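
For small problems, a direct (slow) non-uniform DFT is a handy reference for checking NUFFT output. The NumPy sketch below assumes the convention X(omega) = sum_n x[n] * exp(-1j * omega . (n - n_shift)) with n_shift = im_size // 2 and no normalization; whether the package matches this convention exactly (sign, shift, scaling) is an assumption worth verifying. The name `ndft` is hypothetical.

```python
import numpy as np

def ndft(image, omega):
    """Direct non-uniform DFT, O(numel * klength) operations.

    image: complex array with shape im_size, e.g. (ny, nx)
    omega: real array with shape (len(im_size), klength), in radians/voxel
    """
    im_size = image.shape
    # Integer pixel coordinates, shifted by n // 2 so the image center sits at zero.
    axes = [np.arange(n) - n // 2 for n in im_size]
    grid = np.meshgrid(*axes, indexing="ij")
    coords = np.stack([g.ravel() for g in grid])  # (ndim, numel)
    phase = omega.T @ coords                      # (klength, numel)
    return np.exp(-1j * phase) @ image.ravel()    # (klength,)

# Sanity check against the FFT: on the uniform grid omega_k = 2*pi*k/N the
# two agree up to a (-1)**k factor caused by the centered coordinates.
N = 8
x1d = np.random.default_rng(0).standard_normal(N) + 0j
omega_uniform = (2 * np.pi * np.arange(N) / N)[None, :]
fft_match = np.allclose(ndft(x1d, omega_uniform), (-1.0) ** np.arange(N) * np.fft.fft(x1d))
```

With the phantom and trajectory from the example above converted to NumPy (for instance with `.numpy()`), the `ndft` output can be compared against `kdata.squeeze()`: small differences are expected from the Kaiser-Bessel approximation, while a large mismatch would point to a convention difference. The direct sum builds a (klength, numel) matrix, so it is only practical for small images.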

The package also includes utilities for working with SENSE-NUFFT operators. The above code can be modified to include sensitivity maps.

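# random stand-in for 8 coil sensitivity maps, size: (1, 8, ny, nx)
smaps = torch.rand(1, 8, *im_size) + 1j * torch.rand(1, 8, *im_size)
# outputs a (1, 8, klength) tensor of k-space data, one spoke per coil
sense_data = nufft_ob(x, ktraj, smaps=smaps.to(x))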

This code first multiplies the image by the sensitivity maps in smaps, then computes a 64-length radial spoke for each coil. All operations are broadcast across coils, which minimizes interaction with the Python interpreter and helps computation speed.
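
The coil expansion is plain broadcasting over the coil dimension. The NumPy sketch below mimics it with the (batch, coil, ny, nx) layout used in these examples, using random arrays as stand-ins for an image and 8 sensitivity maps. It also shows the usual SENSE adjoint combination (multiply by the conjugate maps and sum over coils); that the package's adjoint does exactly this is an assumption of the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx, ncoil = 32, 32, 8

# Single-coil image: (batch=1, coil=1, ny, nx)
image = rng.standard_normal((1, 1, ny, nx)) + 1j * rng.standard_normal((1, 1, ny, nx))
# Sensitivity maps: (batch=1, coil=ncoil, ny, nx)
smaps = rng.standard_normal((1, ncoil, ny, nx)) + 1j * rng.standard_normal((1, ncoil, ny, nx))

# Forward SENSE expansion: broadcasting over the coil dimension, no Python loop.
coil_images = image * smaps  # (1, ncoil, ny, nx)

# Adjoint SENSE combination: multiply by the conjugate maps and sum over coils.
combined = np.sum(np.conj(smaps) * coil_images, axis=1, keepdims=True)  # (1, 1, ny, nx)
```

In the NUFFT setting, each of the `ncoil` coil images would then be transformed to its own k-space spoke, which is where the (1, 8, klength) output shape comes from.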

Sparse matrices are an alternative to table interpolation. Their speed can vary, but they are a bit more accurate than standard table mode. The following code calculates sparse interpolation matrices and uses them to apply the adjoint NUFFT to the radial spoke of k-space data computed above:

adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)

# precompute the sparse interpolation matrices
interp_mats = tkbn.calc_tensor_spmatrix(
   ktraj,
   im_size=im_size
)

# use sparse matrices in adjoint
image = adjnufft_ob(kdata, ktraj, interp_mats)

Sparse matrix multiplication is only implemented for real numbers in PyTorch, so the interpolation matrices are stored as separate real and imaginary parts, which can limit speed.
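
Because the sparse products are real-only, a complex interpolation matrix A = A_r + i*A_i is carried as the real pair (A_r, A_i), matching the 2-tuple `interp_mats` described in the API section. The sketch below uses small dense NumPy arrays as stand-ins for sparse matrices and shows how one complex matrix-vector product expands into four real ones. The function name `split_complex_matvec` is hypothetical.

```python
import numpy as np

def split_complex_matvec(interp_mats, x_real, x_imag):
    """Apply A = A_r + i*A_i to x = x_r + i*x_i using only real products."""
    a_real, a_imag = interp_mats
    y_real = a_real @ x_real - a_imag @ x_imag
    y_imag = a_real @ x_imag + a_imag @ x_real
    return y_real, y_imag

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 12)) + 1j * rng.standard_normal((5, 12))
x = rng.standard_normal(12) + 1j * rng.standard_normal(12)
y_r, y_i = split_complex_matvec((a.real, a.imag), x.real, x.imag)
```

Doubling the number of real products in this way is one concrete reason the real-only restriction can limit the speed of sparse mode.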

The package includes routines for calculating embedded Toeplitz kernels and using them as FFT filters for the forward/backward NUFFT operations. This is very useful for gradient descent algorithms that must use the forward/backward ops in calculating the gradient. The following code shows an example:

toep_ob = tkbn.ToepNufft()

# precompute the embedded Toeplitz FFT kernel
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size)

# use FFT kernel from embedded Toeplitz matrix
image = toep_ob(image, kernel)
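
To see why a precomputed kernel suffices, consider 1D with unit density weights. For the non-uniform DFT matrix A with entries exp(-1j * omega_k * (n - n_shift)), the normal operator A^H A has entries that depend only on n - n', so it is Toeplitz; embedding it in a circulant of twice the size lets it be applied with FFTs. The NumPy sketch below builds such a kernel directly from the trajectory and checks it against the explicit matrix product. It illustrates the idea from [3] and is not the package's `calc_toeplitz_kernel`, whose normalization and layout may differ; `toeplitz_kernel_1d` and `apply_toeplitz_1d` are hypothetical names.

```python
import numpy as np

def toeplitz_kernel_1d(omega, n):
    """FFT of the 2n-point circulant embedding of A^H A for a 1D trajectory."""
    lags = np.arange(-(n - 1), n)                       # n - n' in [-(n-1), n-1]
    t = np.exp(1j * np.outer(lags, omega)).sum(axis=1)  # t[d] = sum_k exp(1j*omega_k*d)
    c = np.zeros(2 * n, dtype=complex)
    c[:n] = t[n - 1:]        # lags 0 .. n-1
    c[n + 1:] = t[:n - 1]    # lags -(n-1) .. -1 wrap to the end; c[n] stays 0
    return np.fft.fft(c)

def apply_toeplitz_1d(x, kernel):
    """Compute A^H A x: zero-pad to 2n, filter with FFTs, crop back to n."""
    n = x.shape[0]
    padded = np.zeros(2 * n, dtype=complex)
    padded[:n] = x
    return np.fft.ifft(kernel * np.fft.fft(padded))[:n]

rng = np.random.default_rng(0)
n = 16
omega = rng.uniform(-np.pi, np.pi, size=40)
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
a = np.exp(-1j * np.outer(omega, np.arange(n) - n // 2))  # explicit NDFT matrix
fast = apply_toeplitz_1d(x, toeplitz_kernel_1d(omega, n))
slow = a.conj().T @ (a @ x)
```

The kernel depends only on the trajectory, so it is computed once; after that, each A^H A application costs two FFTs on the doubled grid and no interpolation, which is why this mode pays off inside iterative solvers.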

All of the examples included in this repository can be run on the GPU by sending the NUFFT object and data to the GPU prior to the function call, e.g.,

adjnufft_ob = adjnufft_ob.to(torch.device('cuda'))
kdata = kdata.to(torch.device('cuda'))
ktraj = ktraj.to(torch.device('cuda'))

image = adjnufft_ob(kdata, ktraj)

As in low-level code, PyTorch will raise errors if the dtypes and devices of all objects do not match. Make sure your data and NUFFT objects are on the same device and in the right format to avoid these errors.

For more details, please examine the API in torchkbnufft or the notebooks in the GitHub repository.

Performance Tips

torchkbnufft is primarily written to scale parallelism within the PyTorch framework. The package's performance bottlenecks come from two sources: 1) advanced indexing and 2) multiplications. Multiplications are handled in a way that scales well, but advanced indexing does not, due to limitations in PyTorch. As a result, the package handles growth in problem size very well when that growth does not add to the indexing bottleneck, such as:

  1. Scaling the batch dimension.

  2. Scaling the coil dimension.

Generally, you can just add to these dimensions and the package will perform well without adding much compute time. If you’re chasing more speed, some strategies that might be helpful are listed below.

Using Batched K-space Trajectories

As of version 1.1.0, torchkbnufft can use batched k-space trajectories. If you pass in a variable for omega with dimensions (N, length(im_size), klength), the package will parallelize the execution of all trajectories in the N dimension. This is useful when N is very large, as might occur in dynamic imaging settings. The following shows an example:

import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom

batch_size = 12

x = shepp_logan_phantom().astype(np.complex128)
im_size = x.shape
# convert to tensor, unsqueeze batch and coil dimension
# output size: (batch_size, 1, ny, nx)
x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
x = x.repeat(batch_size, 1, 1, 1)

klength = 64
ktraj = np.stack(
    (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
)
# convert to tensor, unsqueeze batch dimension
# output size: (batch_size, 2, klength)
ktraj = torch.tensor(ktraj).to(torch.float)
ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

nufft_ob = tkbn.KbNufft(im_size=im_size)
# outputs a (batch_size, 1, klength) vector of k-space data
kdata = nufft_ob(x, ktraj)

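This code computes 12 radial spokes, one per batch element, while parallelizing over the batch dimension as much as possible. Here all 12 trajectories are identical for simplicity; in practice each batch element can have its own.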
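
To give each batch element a distinct trajectory, one option (an illustrative choice, not something the package prescribes) is to rotate the spoke by a different angle per element. The NumPy sketch below builds a (batch_size, 2, klength) array in the layout described above, which could then be converted with `torch.tensor(...).to(torch.float)` as before. The function name `rotated_spokes` and the uniformly spaced angles are assumptions.

```python
import numpy as np

def rotated_spokes(batch_size, klength):
    """One radial spoke per batch element, rotated by uniformly spaced angles.

    Returns an array of shape (batch_size, 2, klength) in radians/voxel.
    Angle 0 reproduces the spoke above (zeros in row 0, linspace in row 1).
    """
    radius = np.linspace(-np.pi, np.pi, klength)
    angles = np.arange(batch_size) * np.pi / batch_size  # angles in [0, pi)
    # Rows follow the dimension order of im_size: (ny-direction, nx-direction).
    ky = np.outer(np.sin(angles), radius)
    kx = np.outer(np.cos(angles), radius)
    return np.stack((ky, kx), axis=1)

traj = rotated_spokes(12, 64)
```

Calling `traj.transpose(1, 0, 2).reshape(2, -1)` would instead give one long (2, batch_size * klength) trajectory for a single, unbatched NUFFT.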

Lowering the Precision

A simple way to save both memory and compute time is to decrease the precision. PyTorch normally operates at a default 32-bit floating point precision, but if you're converting data from NumPy, you might have some data at 64-bit floating point precision. To use 32-bit precision, simply do the following:

forw_ob = tkbn.KbNufft(im_size=im_size)

image = image.to(dtype=torch.complex64)
ktraj = ktraj.to(dtype=torch.float32)
forw_ob = forw_ob.to(image)

data = forw_ob(image, ktraj)

The forw_ob.to(image) call automatically infers the matching precision for both the real and complex tensors registered as buffers under forw_ob, so you can safely do this in your code.

In many cases, the accuracy loss from going from 64-bit to 32-bit is not severe, so 32-bit precision is usually safe to use.
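
A quick way to gauge the cost of dropping to 32-bit is to cast some data and measure the round-off the cast introduces. The NumPy sketch below does this for a random complex array; the error it reports reflects only the cast itself (on the order of 1e-7), not the accumulated error of a full NUFFT, which is worth checking on your own data.

```python
import numpy as np

rng = np.random.default_rng(0)
# NumPy builds complex128 (two 64-bit floats) by default.
x64 = rng.standard_normal(100_000) + 1j * rng.standard_normal(100_000)
x32 = x64.astype(np.complex64)

# Memory: complex64 uses 8 bytes per element versus 16 for complex128.
memory_ratio = x64.nbytes / x32.nbytes

# Relative error introduced by the cast alone.
rel_err = np.linalg.norm(x32.astype(np.complex128) - x64) / np.linalg.norm(x64)
```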

Lowering the Oversampling Ratio

If you create a KbNufft object using the following code:

forw_ob = tkbn.KbNufft(im_size=im_size)

then by default it will use a grid oversampled by a factor of 2. For some applications, this can be overkill. If you can sacrifice some accuracy, you can use a smaller grid with 1.25-factor oversampling by altering how you initialize NUFFT objects like KbNufft:

grid_size = tuple([int(el * 1.25) for el in im_size])
forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)
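
The savings come from the size of the grid that the oversampled FFT runs on, prod(grid_size) points. The arithmetic below uses the 400 x 400 phantom size from the earlier examples to compare the two settings; the helper `grid_points` is hypothetical.

```python
import math

im_size = (400, 400)

def grid_points(im_size, factor):
    """Oversampled grid size and total number of grid points."""
    grid_size = tuple(int(el * factor) for el in im_size)
    return grid_size, math.prod(grid_size)

grid_2, points_2 = grid_points(im_size, 2.0)        # (800, 800), 640000 points
grid_125, points_125 = grid_points(im_size, 1.25)   # (500, 500), 250000 points
ratio = points_2 / points_125                       # 2.56
```

In 3D the same change shrinks the grid by a factor of (2 / 1.25) ** 3, roughly 4.1.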

Using Fewer Interpolation Neighbors

Another major speed factor is how many neighbors you use for interpolation. By default, torchkbnufft uses 6 nearest neighbors in each dimension. If you can sacrifice accuracy, you can gain speed with fewer neighbors by altering how you initialize NUFFT objects like KbNufft:

forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)

If you know that you can be less accurate in one dimension (e.g., the z-dimension), then you can use fewer neighbors in only that dimension (here for a 3D im_size):

forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))
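
Each k-space sample touches prod(numpoints) grid points during interpolation, so the gain from fewer neighbors compounds across dimensions. The arithmetic below counts neighbors per sample for the settings above in 3D; actual run times will not scale exactly with these counts, and the helper name is hypothetical.

```python
import math

def neighbors_per_sample(numpoints, ndim):
    """Grid points touched per k-space sample; an int applies to every dimension."""
    if isinstance(numpoints, int):
        numpoints = (numpoints,) * ndim
    return math.prod(numpoints)

default_3d = neighbors_per_sample(6, 3)          # 216
fewer_3d = neighbors_per_sample(4, 3)            # 64
one_axis = neighbors_per_sample((4, 6, 6), 3)    # 144
default_2d = neighbors_per_sample(6, 2)          # 36
```

The same count hints at why more imaging dimensions are costly: with the defaults, a 3D sample gathers 216 neighbors versus 36 in 2D.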

Package Limitations

As mentioned earlier, batches and coils scale well, primarily because they don't add to the package's advanced-indexing bottleneck. Where torchkbnufft does not scale well is:

  1. Very long k-space trajectories.

  2. More imaging dimensions (e.g., 3D).

For these settings, you can first try to use some of the strategies here (lowering precision, fewer neighbors, smaller grid). In some cases, lowering the precision a bit and using a GPU can still give strong performance. If you’re still waiting too long for compute after trying all of these, you may be running into the limits of the package.

torchkbnufft

NUFFT Modules

These are the primary workhorse modules for applying NUFFT operations.

KbInterp

Non-uniform Kaiser-Bessel interpolation layer.

KbInterpAdjoint

Non-uniform Kaiser-Bessel interpolation adjoint layer.

KbNufft

Non-uniform FFT layer.

KbNufftAdjoint

Non-uniform FFT adjoint layer.

ToepNufft

Forward/backward NUFFT with Toeplitz embedding.

Utility Functions

Functions for calculating density compensation and Toeplitz kernels.

calc_density_compensation_function

Numerical density compensation estimation.

calc_tensor_spmatrix

Builds a sparse matrix for interpolation.

calc_toeplitz_kernel

Calculates an FFT kernel for Toeplitz embedding.
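
calc_density_compensation_function estimates weights numerically. One widely used scheme for such estimates is the Pipe and Menon fixed-point iteration, w <- w / (C w), where C holds gridding-kernel overlaps between samples. The 1D NumPy sketch below implements that iteration with a triangular kernel to show the expected behavior (densely sampled regions get smaller weights). It is not necessarily the procedure the package uses; the kernel, its width, and the name `pipe_menon_weights` are assumptions.

```python
import numpy as np

def pipe_menon_weights(locs, width=0.1, num_iter=20):
    """Iterative density compensation weights for 1D sample locations.

    C[i, j] is a triangular kernel of half-width `width` evaluated at the
    distance between samples i and j; the fixed point satisfies C @ w = 1.
    """
    dist = np.abs(locs[:, None] - locs[None, :])
    c = np.clip(1.0 - dist / width, 0.0, None)
    w = np.ones_like(locs)
    for _ in range(num_iter):
        w = w / (c @ w)
    return w

# Dense sampling on [0, 0.5), sparse sampling on [0.5, 1.0).
dense = np.arange(0.0, 0.5, 0.01)
sparse = np.arange(0.5, 1.0, 0.05)
locs = np.concatenate((dense, sparse))
w = pipe_menon_weights(locs)
```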

Math Functions

Complex mathematical operations (gradually being removed as of PyTorch 1.7).

absolute

Complex absolute value.

complex_mult

Complex multiplication.

complex_sign

Complex sign function value.

conj_complex_mult

Complex multiplication, conjugating second input.

imag_exp

Imaginary exponential.

inner_product

Complex inner product.
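
Before PyTorch had native complex tensors, complex data was commonly stored as real arrays with an extra dimension of size 2 holding the real and imaginary parts, and helpers like those above operated on that layout. The NumPy sketch below shows what such helpers compute. The names (`ri_mult` and so on) and the `dim` argument selecting the real/imaginary dimension are hypothetical and do not describe the package's signatures.

```python
import numpy as np

def ri_mult(a, b, dim=-1):
    """Complex product of real arrays whose axis `dim` holds (real, imag)."""
    ar, ai = np.take(a, 0, axis=dim), np.take(a, 1, axis=dim)
    br, bi = np.take(b, 0, axis=dim), np.take(b, 1, axis=dim)
    return np.stack((ar * br - ai * bi, ar * bi + ai * br), axis=dim)

def ri_conj_mult(a, b, dim=-1):
    """a times the complex conjugate of b, in the same (real, imag) layout."""
    b_conj = np.stack((np.take(b, 0, axis=dim), -np.take(b, 1, axis=dim)), axis=dim)
    return ri_mult(a, b_conj, dim)

def ri_abs(a, dim=-1):
    """Complex magnitude; the (real, imag) axis is reduced away."""
    return np.sqrt(np.sum(a ** 2, axis=dim))

def ri_imag_exp(theta, dim=-1):
    """exp(1j * theta) for real theta, returned in (real, imag) layout."""
    return np.stack((np.cos(theta), np.sin(theta)), axis=dim)
```

With native complex dtypes such as torch.complex64, these operations reduce to `*`, `torch.conj`, `torch.abs`, and `torch.exp(1j * theta)`, which is why such helpers can be phased out.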

torchkbnufft.functional

fft_filter

fft_filter(image, kernel, norm='ortho')

FFT-based filtering on an oversampled grid.

This is a wrapper for the operation

\[\text{output} = \text{iFFT}(\text{kernel} \odot \text{FFT}(\text{image}))\]

where \(\text{iFFT}\) and \(\text{FFT}\) are both implemented as oversampled FFTs and \(\odot\) denotes element-wise multiplication.

Parameters
  • image (Tensor) – The image to be filtered.

  • kernel (Tensor) – FFT-domain filter.

  • norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.

Return type

Tensor

Returns

Filtered version of image.
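
The NumPy sketch below spells out this operation for an image of any dimensionality: zero-pad the image to the kernel's oversampled shape, transform, multiply element-wise, transform back, and crop to the original size. It assumes the kernel is already in the FFT domain with the padded grid's shape and that padding goes at the end of each axis; the package's padding, shifting, and normalization conventions may differ, and `fft_filter_sketch` is a hypothetical name.

```python
import numpy as np

def fft_filter_sketch(image, kernel):
    """iFFT(kernel * FFT(image)) on an oversampled grid, cropped to image.shape.

    image:  complex array with shape im_size
    kernel: FFT-domain filter with shape grid_size (each axis >= the image axis)
    """
    # Zero-pad at the end of each axis up to the kernel's grid size.
    pad = [(0, g - n) for g, n in zip(kernel.shape, image.shape)]
    padded = np.pad(image, pad)
    filtered = np.fft.ifftn(kernel * np.fft.fftn(padded))
    # Crop back to the original image extent.
    return filtered[tuple(slice(0, n) for n in image.shape)]
```

An all-ones kernel leaves the image unchanged, which makes a convenient first check when wiring up a kernel.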

Interpolation Functions

kb_spmat_interp

kb_spmat_interp(image, interp_mats)

Kaiser-Bessel sparse matrix interpolation.

See KbInterp for an overall description of interpolation.

To calculate the sparse matrix tuple, see calc_tensor_spmatrix().

Parameters
  • image (Tensor) – Gridded data to be interpolated to scattered data.

  • interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.

Return type

Tensor

Returns

image calculated at scattered locations.

kb_spmat_interp_adjoint

kb_spmat_interp_adjoint(data, interp_mats, grid_size)

Kaiser-Bessel sparse matrix interpolation adjoint.

See KbInterpAdjoint for an overall description of adjoint interpolation.

To calculate the sparse matrix tuple, see calc_tensor_spmatrix().

Parameters
  • data (Tensor) – Scattered data to be interpolated to gridded data.

  • interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

Return type

Tensor

Returns

data calculated at gridded locations.
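
With `interp_mats` holding the real and imaginary parts (A_r, A_i) of an interpolation matrix A, the adjoint is A^H = A_r^T - i*A_i^T, so it is again four real products, now with transposed matrices. A dot-product test, checking <A x, y> = <x, A^H y> for random x and y, is a cheap way to confirm that a forward/adjoint pair is consistent. The NumPy sketch below applies it with dense stand-ins for the sparse matrices; the helper names are hypothetical.

```python
import numpy as np

def pair_forward(interp_mats, x):
    """y = (A_r + i*A_i) x, using real products on the real and imaginary parts."""
    a_r, a_i = interp_mats
    return (a_r @ x.real - a_i @ x.imag) + 1j * (a_r @ x.imag + a_i @ x.real)

def pair_adjoint(interp_mats, y):
    """x = (A_r^T - i*A_i^T) y, the conjugate transpose of pair_forward."""
    a_r, a_i = interp_mats
    return (a_r.T @ y.real + a_i.T @ y.imag) + 1j * (a_r.T @ y.imag - a_i.T @ y.real)

rng = np.random.default_rng(0)
mats = (rng.standard_normal((7, 20)), rng.standard_normal((7, 20)))
x = rng.standard_normal(20) + 1j * rng.standard_normal(20)
y = rng.standard_normal(7) + 1j * rng.standard_normal(7)
# np.vdot conjugates its first argument, so both sides equal y^H A x.
lhs = np.vdot(y, pair_forward(mats, x))
rhs = np.vdot(pair_adjoint(mats, y), x)
```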

kb_table_interp

kb_table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets)

Kaiser-Bessel table interpolation.

See KbInterp for an overall description of interpolation and how to construct the function arguments.

Parameters
  • image (Tensor) – Gridded data to be interpolated to scattered data.

  • omega (Tensor) – k-space trajectory (in radians/voxel).

  • tables (List[Tensor]) – Interpolation tables (one table for each dimension).

  • n_shift (Tensor) – Size for fftshift, usually im_size // 2.

  • numpoints (Tensor) – Number of neighbors to use for interpolation.

  • table_oversamp (Tensor) – Table oversampling factor.

  • offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.

Return type

Tensor

Returns

image calculated at scattered locations.
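
The `offsets` argument enumerates every combination of neighbor indices across dimensions, so it has prod(numpoints) entries. A minimal sketch of such an enumeration with `itertools.product` follows, together with the separable weight for one combination (the product of per-dimension kernel values). The exact tensor layout the package uses for `offsets` is not described here, so the shapes are assumptions, and the function names are hypothetical.

```python
import itertools

import numpy as np

def make_offsets(numpoints):
    """All neighbor-index combinations, shape (prod(numpoints), ndim)."""
    return np.array(list(itertools.product(*[range(j) for j in numpoints])))

def separable_weight(per_dim_weights, offset):
    """Interpolation weight for one neighbor combination.

    per_dim_weights[d][j] is the 1D kernel value for neighbor j along
    dimension d; the N-D weight is the product over dimensions.
    """
    return np.prod([w[j] for w, j in zip(per_dim_weights, offset)])

offsets_2d = make_offsets((6, 6))  # 36 combinations for the 2D default
```

Gathering grid values at each of these combinations for every k-space sample is the kind of advanced indexing that the Performance Tips identify as the main bottleneck.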

kb_table_interp_adjoint

kb_table_interp_adjoint(data, omega, tables, n_shift, numpoints, table_oversamp, offsets, grid_size)

Kaiser-Bessel table interpolation adjoint.

See KbInterpAdjoint for an overall description of adjoint interpolation.

Parameters
  • data (Tensor) – Scattered data to be interpolated to gridded data.

  • omega (Tensor) – k-space trajectory (in radians/voxel).

  • tables (List[Tensor]) – Interpolation tables (one table for each dimension).

  • n_shift (Tensor) – Size for fftshift, usually im_size // 2.

  • numpoints (Tensor) – Number of neighbors to use for interpolation.

  • table_oversamp (Tensor) – Table oversampling factor.

  • offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

Return type

Tensor

Returns

data calculated at gridded locations.

NUFFT Functions

kb_spmat_nufft

kb_spmat_nufft(image, scaling_coef, im_size, grid_size, interp_mats, norm=None)

Kaiser-Bessel NUFFT with sparse matrix interpolation.

See KbNufft for an overall description of the forward NUFFT.

To calculate the sparse matrix tuple, see calc_tensor_spmatrix().

Parameters
  • image (Tensor) – Image to be NUFFT’d to scattered data.

  • scaling_coef (Tensor) – Image-domain coefficients to pre-compensate for interpolation errors.

  • im_size (Tensor) – Size of image with length being the number of dimensions.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

  • interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.

  • norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.

Return type

Tensor

Returns

image calculated at scattered Fourier locations.

kb_spmat_nufft_adjoint

kb_spmat_nufft_adjoint(data, scaling_coef, im_size, grid_size, interp_mats, norm=None)

Kaiser-Bessel adjoint NUFFT with sparse matrix interpolation.

See KbNufftAdjoint for an overall description of the adjoint NUFFT.

To calculate the sparse matrix tuple, see calc_tensor_spmatrix().

Parameters
  • data (Tensor) – Scattered data to be iNUFFT’d to an image.

  • scaling_coef (Tensor) – Image-domain coefficients to compensate for interpolation errors.

  • im_size (Tensor) – Size of image with length being the number of dimensions.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

  • interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.

  • norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.

Return type

Tensor

Returns

data transformed to an image.

kb_table_nufft

kb_table_nufft(image, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm=None)

Kaiser-Bessel NUFFT with table interpolation.

See KbNufft for an overall description of the forward NUFFT.

Parameters
  • image (Tensor) – Image to be NUFFT’d to scattered data.

  • scaling_coef (Tensor) – Image-domain coefficients to pre-compensate for interpolation errors.

  • im_size (Tensor) – Size of image with length being the number of dimensions.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

  • omega (Tensor) – k-space trajectory (in radians/voxel).

  • tables (List[Tensor]) – Interpolation tables (one table for each dimension).

  • n_shift (Tensor) – Size for fftshift, usually im_size // 2.

  • numpoints (Tensor) – Number of neighbors to use for interpolation.

  • table_oversamp (Tensor) – Table oversampling factor.

  • offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.

  • norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.

Return type

Tensor

Returns

image calculated at scattered Fourier locations.
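
The three ingredients in this signature (scaling_coef, an oversampled FFT on grid_size, and interpolation over numpoints neighbors) can be seen together in a compact 1D NumPy sketch of Kaiser-Bessel gridding [1, 2]. It evaluates the kernel directly rather than from a table, uses Beatty's shape parameter, takes scaling coefficients from the closed-form Fourier transform of the Kaiser-Bessel window, and assumes the exp(-1j * omega * (n - n_shift)) sign and shift convention. It is an illustration, not the package's code, and `kb_nufft_1d` is a hypothetical name.

```python
import numpy as np

def kb_nufft_1d(x, omega, oversamp=2.0, numpoints=6):
    """Approximate X(omega) = sum_n x[n] * exp(-1j * omega * (n - n_shift)) in 1D.

    x: complex array of length n; omega: real sample locations in radians/voxel.
    """
    n = x.shape[0]
    grid_size = int(round(oversamp * n))
    n_shift = n // 2
    width = numpoints
    # Kernel shape parameter from Beatty et al. [2].
    beta = np.pi * np.sqrt((width / oversamp) ** 2 * (oversamp - 0.5) ** 2 - 0.8)

    # scaling_coef: reciprocal of the kernel's continuous Fourier transform,
    # FT(nu) = width * sinh(z) / z with z = sqrt(beta**2 - (pi * width * nu)**2).
    nu = (np.arange(n) - n_shift) / grid_size
    z = np.sqrt(beta ** 2 - (np.pi * width * nu) ** 2)
    scaling_coef = z / (width * np.sinh(z))

    # Oversampled FFT of the pre-scaled image; the phase ramp accounts for n_shift.
    k = np.arange(grid_size)
    grid = np.fft.fft(x * scaling_coef, grid_size) * np.exp(2j * np.pi * k * n_shift / grid_size)

    # Interpolation: each sample gathers `numpoints` neighboring grid values,
    # weighted by the Kaiser-Bessel kernel at the offset to each neighbor.
    kappa = omega * grid_size / (2 * np.pi)             # sample positions in grid units
    first = np.floor(kappa - width / 2).astype(int) + 1
    nbrs = first[:, None] + np.arange(width)[None, :]   # (klength, numpoints)
    u = kappa[:, None] - nbrs                           # offsets within [-width/2, width/2)
    weights = np.i0(beta * np.sqrt(np.clip(1 - (2 * u / width) ** 2, 0, None)))
    return np.sum(weights * grid[nbrs % grid_size], axis=1)
```

The scaling step pre-compensates the kernel's smooth image-domain roll-off, and the oversampled grid pushes aliased copies far enough out that the kernel suppresses them; lowering `oversamp` or `numpoints` trades accuracy for speed, as described in the Performance Tips.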

kb_table_nufft_adjoint

kb_table_nufft_adjoint(data, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm=None)

Kaiser-Bessel NUFFT adjoint with table interpolation.

See KbNufftAdjoint for an overall description of the adjoint NUFFT.

Parameters
  • data (Tensor) – Scattered data to be iNUFFT’d to an image.

  • scaling_coef (Tensor) – Image-domain coefficients to compensate for interpolation errors.

  • im_size (Tensor) – Size of image with length being the number of dimensions.

  • grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.

  • omega (Tensor) – k-space trajectory (in radians/voxel).

  • tables (List[Tensor]) – Interpolation tables (one table for each dimension).

  • n_shift (Tensor) – Size for fftshift, usually im_size // 2.

  • numpoints (Tensor) – Number of neighbors to use for interpolation.

  • table_oversamp (Tensor) – Table oversampling factor.

  • offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.

  • norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.

Return type

Tensor

Returns

data transformed to an image.
