TorchKbNufft Documentation
About
torchkbnufft implements a non-uniform Fast Fourier Transform [1, 2] with Kaiser-Bessel gridding in PyTorch. The implementation is completely in Python, facilitating flexible deployment in readable code with no compilation. NUFFT functions are each wrapped as a torch.autograd.Function, allowing backpropagation through NUFFT operators for training neural networks.
This package was inspired in large part by the NUFFT implementation in the Michigan Image Reconstruction Toolbox (Matlab).
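For reference, the exact transform that a NUFFT approximates is the non-uniform DFT. The NumPy sketch below evaluates it directly for a 2D image and a trajectory shaped like the package's omega, i.e. (2, klength) in radians/voxel. Centering the image coordinates on im_size // 2 is an assumption meant to mirror the n_shift parameter in the API section, and nudft_2d is a hypothetical helper, not part of torchkbnufft. Its cost is O(klength · ny · nx), which is exactly what the NUFFT avoids.

```python
import numpy as np


def nudft_2d(image, omega):
    """Direct non-uniform DFT (hypothetical reference helper, not a torchkbnufft API).

    image: complex array of shape (ny, nx)
    omega: real array of shape (2, klength), k-space locations in radians/voxel
    returns: complex array of shape (klength,)
    """
    ny, nx = image.shape
    # Image coordinates centered on im_size // 2 (assumed to mirror n_shift).
    yy, xx = np.meshgrid(np.arange(ny) - ny // 2, np.arange(nx) - nx // 2, indexing="ij")
    # Phase for every (sample, pixel) pair: shape (klength, ny, nx).
    phase = omega[0][:, None, None] * yy + omega[1][:, None, None] * xx
    return np.sum(image[None] * np.exp(-1j * phase), axis=(1, 2))


# Example: one horizontal radial spoke through a small random image.
rng = np.random.default_rng(0)
image = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
klength = 8
omega = np.stack((np.zeros(klength), np.linspace(-np.pi, np.pi, klength)))
kdata = nudft_2d(image, omega)
print(kdata.shape)  # (8,)
```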
Installation
Simple installation can be done via PyPI:
pip install torchkbnufft
torchkbnufft only requires numpy, scipy, and torch as dependencies.
Operation Modes and Stages
The package has three major classes of NUFFT operation mode: table-based NUFFT interpolation, sparse matrix-based NUFFT interpolation, and forward/backward operators with Toeplitz-embedded FFTs [3]. Table interpolation is the standard operation mode, whereas the Toeplitz method is always the fastest for forward/backward NUFFTs. For some problems, sparse matrices may be faster than table interpolation. It is generally best to start with table interpolation and then experiment with the other modes for your problem.
Sensitivity maps can be incorporated by passing them to a KbNufft or KbNufftAdjoint object through the smaps argument. Auxiliary functions for calculating sparse interpolation matrices, density compensation functions, and Toeplitz filter kernels are also included. For examples, see Basic Usage.
References
[1] Fessler, J. A., & Sutton, B. P. (2003). Nonuniform fast Fourier transforms using min-max interpolation. IEEE Transactions on Signal Processing, 51(2), 560-574.
[2] Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005). Rapid gridding reconstruction with a minimal oversampling ratio. IEEE Transactions on Medical Imaging, 24(6), 799-808.
[3] Feichtinger, H. G., Gröchenig, K., & Strohmer, T. (1995). Efficient numerical methods in non-uniform sampling theory. Numerische Mathematik, 69(4), 423-440.
Basic Usage
torchkbnufft works primarily via PyTorch modules. You create a module with the properties of your imaging setup, and the module calculates a Kaiser-Bessel kernel and some interpolation parameters based on your inputs. Then you apply the module to your data stored as PyTorch tensors. NUFFT operations are wrapped in torch.autograd.Function classes for backpropagation and training neural networks.
The following code loads a Shepp-Logan phantom and computes a single radial spoke of k-space data:
import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom
# np.complex was removed in NumPy 1.24; use np.complex128
x = shepp_logan_phantom().astype(np.complex128)
im_size = x.shape
# convert to tensor, unsqueeze batch and coil dimension
# output size: (1, 1, ny, nx)
x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
klength = 64
ktraj = np.stack(
    (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
)
# convert to tensor (no batch dimension: one trajectory shared by the batch)
# output size: (2, klength)
ktraj = torch.tensor(ktraj).to(torch.float)
nufft_ob = tkbn.KbNufft(im_size=im_size)
# outputs a (1, 1, klength) vector of k-space data
kdata = nufft_ob(x, ktraj)
The package also includes utilities for working with SENSE-NUFFT operators. The above code can be modified to include sensitivity maps.
# random sensitivity maps for 8 coils, matching the image size
smaps = torch.rand(1, 8, *im_size) + 1j * torch.rand(1, 8, *im_size)
# output size: (1, 8, klength)
sense_data = nufft_ob(x, ktraj, smaps=smaps.to(x))
This code first multiplies the image by the coil sensitivity maps in smaps, then computes a 64-length radial spoke for each coil. All operations are broadcast across coils, which minimizes interaction with the Python interpreter and helps computation speed.
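The broadcasting described above can be illustrated without the package. In the NumPy sketch below, a (1, 1, ny, nx) image is multiplied by (1, ncoil, ny, nx) sensitivity maps, and one precomputed non-uniform DFT matrix is then applied to every coil in a single batched matrix product. The dense matrix dft_mat is a stand-in for the NUFFT chosen purely for illustration (torchkbnufft does not build such a matrix), and the image is kept small so the matrix fits in memory.

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx, ncoil, klength = 16, 16, 8, 64

# Image with singleton batch and coil dimensions: (1, 1, ny, nx).
x = rng.standard_normal((1, 1, ny, nx)) + 1j * rng.standard_normal((1, 1, ny, nx))
# Random sensitivity maps for illustration: (1, ncoil, ny, nx).
smaps = rng.random((1, ncoil, ny, nx)) + 1j * rng.random((1, ncoil, ny, nx))

# One radial spoke in radians/voxel: (2, klength).
omega = np.stack((np.zeros(klength), np.linspace(-np.pi, np.pi, klength)))

# Dense non-uniform DFT matrix (stand-in for the NUFFT): (klength, ny * nx).
yy, xx = np.meshgrid(np.arange(ny) - ny // 2, np.arange(nx) - nx // 2, indexing="ij")
phase = np.outer(omega[0], yy.ravel()) + np.outer(omega[1], xx.ravel())
dft_mat = np.exp(-1j * phase)

# Broadcasting: (1, 1, ny, nx) * (1, ncoil, ny, nx) -> (1, ncoil, ny, nx).
coil_images = x * smaps
# One matmul covers every coil: (1, ncoil, ny*nx) @ (ny*nx, klength) -> (1, ncoil, klength).
sense_data = coil_images.reshape(1, ncoil, -1) @ dft_mat.T
print(sense_data.shape)  # (1, 8, 64)
```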
Sparse matrices are an alternative to table interpolation. Their speed can vary, but they are a bit more accurate than standard table mode. The following code calculates sparse interpolation matrices and uses them in an adjoint NUFFT to reconstruct an image from the radial spoke data kdata computed above:
adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)
# precompute the sparse interpolation matrices
interp_mats = tkbn.calc_tensor_spmatrix(
    ktraj,
    im_size=im_size,
)
# use sparse matrices in adjoint
image = adjnufft_ob(kdata, ktraj, interp_mats)
Sparse matrix multiplication is only implemented for real numbers in PyTorch, so the complex interpolation matrices are stored as separate real and imaginary parts, which can limit their speed.
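To see why the real/imaginary split costs extra work, the sketch below uses SciPy rather than PyTorch to show the underlying arithmetic: one complex sparse product is rebuilt from four real sparse products. The random matrix and the helper name complex_spmv are hypothetical, for illustration only.

```python
import numpy as np
import scipy.sparse as sp


def complex_spmv(mat_real, mat_imag, vec):
    """Apply (mat_real + 1j * mat_imag) to a complex vector using only real sparse products.

    (A_r + i A_i)(x_r + i x_i) = (A_r x_r - A_i x_i) + i (A_r x_i + A_i x_r)
    """
    x_r, x_i = vec.real, vec.imag
    out_r = mat_real @ x_r - mat_imag @ x_i
    out_i = mat_real @ x_i + mat_imag @ x_r
    return out_r + 1j * out_i


rng = np.random.default_rng(0)
nsamples, ngrid, nnz = 64, 256, 400
# Hypothetical interpolation-like matrix: the real and imaginary parts share one
# sparsity pattern, as a complex matrix split into two real ones would.
rows = rng.integers(0, nsamples, size=nnz)
cols = rng.integers(0, ngrid, size=nnz)
mat_real = sp.csr_matrix((rng.standard_normal(nnz), (rows, cols)), shape=(nsamples, ngrid))
mat_imag = sp.csr_matrix((rng.standard_normal(nnz), (rows, cols)), shape=(nsamples, ngrid))
vec = rng.standard_normal(ngrid) + 1j * rng.standard_normal(ngrid)

out = complex_spmv(mat_real, mat_imag, vec)
```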
The package includes routines for calculating embedded Toeplitz kernels and using them as FFT filters for the forward/backward NUFFT operations. This is very useful for gradient descent algorithms that must use the forward/backward ops in calculating the gradient. The following code shows an example:
toep_ob = tkbn.ToepNufft()
# precompute the embedded Toeplitz FFT kernel
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size)
# use FFT kernel from embedded Toeplitz matrix
image = toep_ob(image, kernel)
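The reason this works: for a NUFFT operator A, the normal operator A^H A depends only on differences between image coordinates, so it is a Toeplitz matrix that can be embedded in a circulant matrix of twice the size and applied with FFTs. The 1D NumPy sketch below checks this identity against the explicit product, using an exact non-uniform DFT matrix in place of the NUFFT and omitting density compensation. It illustrates the idea behind calc_toeplitz_kernel and ToepNufft but is not their implementation; the helper names are hypothetical.

```python
import numpy as np


def toeplitz_kernel_1d(omega, n):
    """FFT of the first column of a 2n-point circulant that embeds A^H A.

    With A[m, k] = exp(-1j * omega[m] * k), (A^H A)[j, k] = t[j - k] where
    t[d] = sum_m exp(1j * omega[m] * d), so A^H A is Toeplitz.
    """
    d = np.arange(-(n - 1), n)                       # d = -(n-1), ..., n-1
    t = np.exp(1j * np.outer(d, omega)).sum(axis=1)  # t[d] stored at index d + n - 1
    col = np.zeros(2 * n, dtype=complex)
    col[:n] = t[n - 1:]       # t[0], ..., t[n-1]
    col[n + 1:] = t[:n - 1]   # t[-(n-1)], ..., t[-1]; col[n] is free, left at 0
    return np.fft.fft(col)


def toeplitz_apply_1d(x, kernel):
    """Apply A^H A: zero-pad to 2n, multiply by the kernel in Fourier space, crop."""
    n = x.shape[0]
    return np.fft.ifft(kernel * np.fft.fft(x, n=2 * n))[:n]


rng = np.random.default_rng(0)
n, klength = 32, 100
omega = rng.uniform(-np.pi, np.pi, klength)
A = np.exp(-1j * np.outer(omega, np.arange(n)))  # exact non-uniform DFT matrix
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)

kernel = toeplitz_kernel_1d(omega, n)  # computed once per trajectory
fast = toeplitz_apply_1d(x, kernel)
direct = A.conj().T @ (A @ x)
print(np.allclose(fast, direct))  # True
```

Once the kernel is precomputed, each application of A^H A costs two FFTs of size 2n, independent of the number of k-space samples, which is why this mode pays off inside iterative solvers.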
All of the examples included in this repository can be run on the GPU by sending the NUFFT object and data to the GPU prior to the function call, e.g.,
adjnufft_ob = adjnufft_ob.to(torch.device('cuda'))
kdata = kdata.to(torch.device('cuda'))
ktraj = ktraj.to(torch.device('cuda'))
image = adjnufft_ob(kdata, ktraj)
As with low-level code, PyTorch will raise errors if the dtypes and devices of the objects in an operation do not match. Make sure your data and NUFFT objects are on the same device and in compatible formats to avoid these errors.
For more details, please examine the API in torchkbnufft or check out the notebooks below on Google Colab.
Performance Tips
torchkbnufft is primarily written with the goal of scaling parallelism within the PyTorch framework. Its performance bottlenecks come from two sources: 1) advanced indexing and 2) multiplications. Multiplications are handled in a way that scales well, but advanced indexing does not, due to limitations in PyTorch.
As a result, growth in problem size that does not touch the indexing bottleneck is handled very well by the package, for example:
Scaling the batch dimension.
Scaling the coil dimension.
Generally, you can increase these dimensions without adding much compute time. If you’re chasing more speed, some strategies that might help are listed below.
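A rough NumPy illustration of why these dimensions are cheap: interpolation gathers grid values at neighbor indices that depend only on the trajectory, so one index array (the advanced-indexing part) serves every batch and coil entry, and the extra dimensions only enlarge the multiply-and-sum. The random indices and weights below are placeholders, not the package's actual kernel or indexing scheme.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, ncoil, grid_len, klength, numpoints = 2, 8, 128, 64, 6

grid = rng.standard_normal((batch, ncoil, grid_len)) + 0j
# Placeholder neighbor indices and weights: one set per k-space sample, shared
# by every batch and coil entry. Shape: (klength, numpoints).
idx = rng.integers(0, grid_len, size=(klength, numpoints))
weights = rng.standard_normal((klength, numpoints))

# A single advanced-indexing gather covers all leading dimensions:
# (batch, ncoil, grid_len)[..., idx] -> (batch, ncoil, klength, numpoints)
gathered = grid[..., idx]
kdata = (gathered * weights).sum(axis=-1)  # (batch, ncoil, klength)
print(kdata.shape)  # (2, 8, 64)
```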
Using Batched K-space Trajectories
As of version 1.1.0, torchkbnufft can use batched k-space trajectories. If you pass in a variable for omega with dimensions (N, len(im_size), klength), the package will parallelize the execution of all trajectories in the N dimension. This is useful when N is very large, as might occur in dynamic imaging settings. The following shows an example:
import torch
import torchkbnufft as tkbn
import numpy as np
from skimage.data import shepp_logan_phantom
batch_size = 12
# np.complex was removed in NumPy 1.24; use np.complex128
x = shepp_logan_phantom().astype(np.complex128)
im_size = x.shape
# convert to tensor, unsqueeze batch and coil dimension,
# then repeat along the batch dimension
# output size: (batch_size, 1, ny, nx)
x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
x = x.repeat(batch_size, 1, 1, 1)
klength = 64
ktraj = np.stack(
    (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
)
# convert to tensor, unsqueeze batch dimension, repeat for each batch element
# output size: (batch_size, 2, klength)
ktraj = torch.tensor(ktraj).to(torch.float)
ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)
nufft_ob = tkbn.KbNufft(im_size=im_size)
# outputs a (batch_size, 1, klength) vector of k-space data
kdata = nufft_ob(x, ktraj)
This code computes the same radial spoke 12 times, once per batch element, while parallelizing as much as possible. In practice, each batch element would typically have its own trajectory.
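To give each batch element its own trajectory, the (N, 2, klength) array can be built from spokes at different angles. A minimal NumPy sketch, assuming evenly spaced angles and the same row convention as the examples above (first row along the first image dimension). radial_spokes is a hypothetical helper; its output would be converted with torch.tensor(...).to(torch.float) as before.

```python
import numpy as np


def radial_spokes(num_spokes, klength):
    """Radial trajectory of shape (num_spokes, 2, klength) in radians/voxel.

    Each spoke runs from -pi to pi through the k-space center; spoke i is
    rotated by i * pi / num_spokes. Angle 0 matches the single-spoke example,
    whose first row is all zeros.
    """
    radius = np.linspace(-np.pi, np.pi, klength)
    angles = np.arange(num_spokes) * np.pi / num_spokes
    # Row 0: first image dimension; row 1: second image dimension.
    k0 = np.sin(angles)[:, None] * radius[None, :]
    k1 = np.cos(angles)[:, None] * radius[None, :]
    return np.stack((k0, k1), axis=1)


ktraj_np = radial_spokes(12, 64)
print(ktraj_np.shape)  # (12, 2, 64)
```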
Lowering the Precision
A simple way to save both memory and compute time is to decrease the precision. PyTorch normally operates at a default 32-bit floating-point precision, but if you’re converting data from NumPy, some of your data may be at 64-bit floating-point precision. To use 32-bit precision, simply do the following:
forw_ob = tkbn.KbNufft(im_size=im_size)
image = image.to(dtype=torch.complex64)
ktraj = ktraj.to(dtype=torch.float32)
forw_ob = forw_ob.to(image)
data = forw_ob(image, ktraj)
The forw_ob.to(image) call automatically determines the dtype for both the real and complex tensors registered as buffers under forw_ob, so you should be able to do this safely in your code.
In many cases, the accuracy lost in going from 64-bit to 32-bit precision is minor, so 32-bit precision is usually a safe choice.
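As a back-of-the-envelope check on one array (not a benchmark of the package), the NumPy sketch below compares the memory of a 400-by-400 complex image at both precisions and measures the rounding error introduced by casting down.

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (400, 400)
image64 = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)  # complex128
image32 = image64.astype(np.complex64)

# complex128 uses 16 bytes per element, complex64 uses 8.
print(image64.nbytes, image32.nbytes)  # 2560000 1280000

# Relative error introduced by rounding to single precision.
rel_err = np.linalg.norm(image32.astype(np.complex128) - image64) / np.linalg.norm(image64)
print(rel_err < 1e-6)  # True
```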
Lowering the Oversampling Ratio
If you create a KbNufft object using the following code:
forw_ob = tkbn.KbNufft(im_size=im_size)
then by default it will use a grid oversampled by a factor of 2. For some applications, this can be overkill. If you can sacrifice some accuracy, you can use a smaller grid with 1.25-factor oversampling by passing grid_size when you initialize NUFFT objects like KbNufft:
grid_size = tuple([int(el * 1.25) for el in im_size])
forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)
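Most of the gain comes from the size of the grid that the FFTs and interpolation run on. For the 400-by-400 phantom used above, this plain-Python sketch computes both grid sizes with the same expression as the code above and the fraction of grid points the 1.25 factor keeps; it says nothing about the accuracy cost, which depends on the application.

```python
import math

im_size = (400, 400)
default_grid = tuple(int(el * 2) for el in im_size)    # 2x oversampling (default)
small_grid = tuple(int(el * 1.25) for el in im_size)   # 1.25x oversampling

ratio = math.prod(small_grid) / math.prod(default_grid)
print(default_grid, small_grid, ratio)  # (800, 800) (500, 500) 0.390625
```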
Using Fewer Interpolation Neighbors
Another major speed factor is how many neighbors you use for interpolation. By default, torchkbnufft uses 6 nearest neighbors in each dimension. If you can sacrifice accuracy, you can get more speed by using fewer neighbors, set through numpoints when you initialize NUFFT objects like KbNufft:
forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)
If you know that you can be less accurate in one dimension (e.g., the z-dimension), you can use fewer neighbors in only that dimension:
forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))
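Since interpolation visits every combination of neighbors across dimensions (the offsets argument in the API section loops over all such combinations), the number of grid points touched per k-space sample is the product of the per-dimension neighbor counts, so the work per sample shrinks roughly in proportion. A small sketch of that count, using a hypothetical helper neighbors_per_sample:

```python
import math


def neighbors_per_sample(numpoints, ndim):
    """Grid points touched per k-space sample (hypothetical helper).

    numpoints may be a single int (same in every dimension) or a per-dimension tuple.
    """
    if isinstance(numpoints, int):
        numpoints = (numpoints,) * ndim
    return math.prod(numpoints)


print(neighbors_per_sample(6, 3))          # 216, the default in 3D
print(neighbors_per_sample(4, 3))          # 64
print(neighbors_per_sample((4, 6, 6), 3))  # 144
```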
Package Limitations
As mentioned earlier, batches and coils scale well, primarily because they don’t affect the package’s advanced-indexing bottleneck. Where torchkbnufft does not scale well is:
Very long k-space trajectories.
More imaging dimensions (e.g., 3D).
For these settings, first try the strategies above (lowering precision, using fewer neighbors, using a smaller grid). In some cases, lowering the precision and using a GPU can still give strong performance. If compute is still too slow after trying all of these, you may be running into the limits of the package.
torchkbnufft
NUFFT Modules
These are the primary workhorse modules for applying NUFFT operations.
- KbInterp: Non-uniform Kaiser-Bessel interpolation layer.
- KbInterpAdjoint: Non-uniform Kaiser-Bessel interpolation adjoint layer.
- KbNufft: Non-uniform FFT layer.
- KbNufftAdjoint: Non-uniform FFT adjoint layer.
- ToepNufft: Forward/backward NUFFT with Toeplitz embedding.
Utility Functions
Functions for calculating density compensation and Toeplitz kernels.
- Numerical density compensation estimation.
- calc_tensor_spmatrix: Builds a sparse matrix for interpolation.
- calc_toeplitz_kernel: Calculates an FFT kernel for Toeplitz embedding.
Math Functions
Complex mathematical operations (gradually being removed as of PyTorch 1.7).
- Complex absolute value.
- Complex multiplication.
- Complex sign function value.
- Complex multiplication, conjugating second input.
- Imaginary exponential.
- Complex inner product.
torchkbnufft.functional
fft_filter
- fft_filter(image, kernel, norm='ortho')
FFT-based filtering on an oversampled grid.
This is a wrapper for the operation
\[ \text{output} = \mathrm{iFFT}\left(\text{kernel} \odot \mathrm{FFT}(\text{image})\right), \]
where \(\mathrm{iFFT}\) and \(\mathrm{FFT}\) are both implemented as oversampled FFTs and \(\odot\) denotes elementwise multiplication.
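A NumPy sketch of this operation, assuming the image is zero-padded to the kernel's (oversampled) shape before the FFT and cropped back afterwards, and that norm='ortho' means orthonormal FFTs. The padding position and normalization are assumptions of this sketch and the package's internal conventions may differ; fft_filter_sketch is a hypothetical name.

```python
import numpy as np


def fft_filter_sketch(image, kernel, norm="ortho"):
    """Compute iFFT(kernel * FFT(image)) on an oversampled grid.

    image:  complex array whose trailing dims are im_size (leading dims are batch)
    kernel: complex array of shape grid_size, e.g. 2 * im_size
    """
    ndim = kernel.ndim
    axes = tuple(range(-ndim, 0))
    im_size = image.shape[-ndim:]
    # Zero-pad the trailing dims to the kernel size, FFT, filter, inverse FFT.
    spectrum = np.fft.fftn(image, s=kernel.shape, axes=axes, norm=norm)
    filtered = np.fft.ifftn(kernel * spectrum, axes=axes, norm=norm)
    # Crop back to the original image size.
    crop = tuple(slice(0, n) for n in im_size)
    return filtered[(...,) + crop]


rng = np.random.default_rng(0)
image = rng.standard_normal((2, 16, 16)) + 0j  # leading batch dimension
kernel = np.ones((32, 32), dtype=complex)      # all-pass kernel on a 2x grid
out = fft_filter_sketch(image, kernel)
print(np.allclose(out, image))  # True: an all-pass kernel leaves the image unchanged
```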
Interpolation Functions
kb_spmat_interp
- kb_spmat_interp(image, interp_mats)
Kaiser-Bessel sparse matrix interpolation.
See KbInterp for an overall description of interpolation. To calculate the sparse matrix tuple, see calc_tensor_spmatrix().
kb_spmat_interp_adjoint
- kb_spmat_interp_adjoint(data, interp_mats, grid_size)
Kaiser-Bessel sparse matrix interpolation adjoint.
See KbInterpAdjoint for an overall description of adjoint interpolation. To calculate the sparse matrix tuple, see calc_tensor_spmatrix().
kb_table_interp
- kb_table_interp(image, omega, tables, n_shift, numpoints, table_oversamp, offsets)
Kaiser-Bessel table interpolation.
See KbInterp for an overall description of interpolation and how to construct the function arguments.
- Parameters
  - image (Tensor) – Gridded data to be interpolated to scattered data.
  - omega (Tensor) – k-space trajectory (in radians/voxel).
  - tables (List[Tensor]) – Interpolation tables (one table for each dimension).
  - n_shift (Tensor) – Size for fftshift, usually im_size // 2.
  - numpoints (Tensor) – Number of neighbors to use for interpolation.
  - table_oversamp (Tensor) – Table oversampling factor.
  - offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.
- Return type
  Tensor
- Returns
  image calculated at scattered locations.
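Table interpolation replaces kernel evaluations with lookups into a precomputed, finely sampled copy of the kernel. The NumPy/SciPy sketch below builds a 1D table with table_oversamp samples per grid unit and looks up weights by rounding to the nearest entry. The Kaiser-Bessel form and the Beatty et al. [2] choice of the shape parameter beta are assumptions of this sketch, and nearest-entry lookup is simply the easiest rule to show; torchkbnufft's own table construction and lookup may differ.

```python
import numpy as np
from scipy.special import i0


def kaiser_bessel(u, width, beta):
    """Kaiser-Bessel kernel on |u| <= width / 2 (zero outside), peak-normalized to 1."""
    u = np.asarray(u, dtype=float)
    arg = 1.0 - (2.0 * u / width) ** 2
    vals = i0(beta * np.sqrt(np.clip(arg, 0.0, None))) / i0(beta)
    return np.where(arg >= 0.0, vals, 0.0)


def build_table(width, beta, table_oversamp):
    """Kernel sampled at spacing 1 / table_oversamp on [0, width / 2] (kernel is symmetric)."""
    dist = np.arange(int(width / 2 * table_oversamp) + 1) / table_oversamp
    return kaiser_bessel(dist, width, beta)


def table_lookup(table, dist, table_oversamp):
    """Nearest-entry lookup of kernel values at distances dist from a grid point."""
    pos = np.rint(np.abs(dist) * table_oversamp).astype(int)
    inside = pos < table.shape[0]
    return np.where(inside, table[np.minimum(pos, table.shape[0] - 1)], 0.0)


numpoints, oversamp, table_oversamp = 6, 2.0, 2**10
# Shape parameter from Beatty et al. [2] (assumed here, not taken from the package).
beta = np.pi * np.sqrt((numpoints / oversamp) ** 2 * (oversamp - 0.5) ** 2 - 0.8)

table = build_table(numpoints, beta, table_oversamp)
dist = np.linspace(-3.0, 3.0, 101)
approx = table_lookup(table, dist, table_oversamp)
exact = kaiser_bessel(dist, numpoints, beta)
print(np.max(np.abs(approx - exact)))  # small, limited by the table spacing
```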
kb_table_interp_adjoint
- kb_table_interp_adjoint(data, omega, tables, n_shift, numpoints, table_oversamp, offsets, grid_size)
Kaiser-Bessel table interpolation adjoint.
See KbInterpAdjoint for an overall description of adjoint interpolation.
- Parameters
  - data (Tensor) – Scattered data to be interpolated to gridded data.
  - omega (Tensor) – k-space trajectory (in radians/voxel).
  - tables (List[Tensor]) – Interpolation tables (one table for each dimension).
  - n_shift (Tensor) – Size for fftshift, usually im_size // 2.
  - numpoints (Tensor) – Number of neighbors to use for interpolation.
  - table_oversamp (Tensor) – Table oversampling factor.
  - offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.
  - grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.
- Return type
  Tensor
- Returns
  data calculated at gridded locations.
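The adjoint reverses the gather: each scattered sample is spread onto its neighbors, and contributions from samples that share a grid point must be summed. In NumPy that accumulation needs np.add.at (plain fancy-index assignment would drop duplicates), the kind of indexing operation the Performance Tips identify as the bottleneck. A 1D sketch with placeholder indices and weights that checks the adjoint identity <Ax, y> = <x, A^H y>; the helper names are hypothetical.

```python
import numpy as np


def interp_forward(grid, idx, weights):
    """Gather: kdata[m] = sum_j weights[m, j] * grid[idx[m, j]]."""
    return (grid[idx] * weights).sum(axis=-1)


def interp_adjoint(kdata, idx, weights, grid_len):
    """Scatter: accumulate conj(weights) * kdata onto the grid, summing duplicates."""
    grid = np.zeros(grid_len, dtype=complex)
    np.add.at(grid, idx, np.conj(weights) * kdata[:, None])
    return grid


rng = np.random.default_rng(0)
grid_len, klength, numpoints = 64, 40, 6
# Placeholder neighbor indices (consecutive, wrapping around) and complex weights.
start = rng.integers(0, grid_len, size=klength)
idx = (start[:, None] + np.arange(numpoints)) % grid_len
weights = rng.standard_normal((klength, numpoints)) + 1j * rng.standard_normal((klength, numpoints))

x = rng.standard_normal(grid_len) + 1j * rng.standard_normal(grid_len)
y = rng.standard_normal(klength) + 1j * rng.standard_normal(klength)
lhs = np.vdot(interp_forward(x, idx, weights), y)            # <A x, y>
rhs = np.vdot(x, interp_adjoint(y, idx, weights, grid_len))  # <x, A^H y>
print(np.isclose(lhs, rhs))  # True
```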
NUFFT Functions
kb_spmat_nufft
- kb_spmat_nufft(image, scaling_coef, im_size, grid_size, interp_mats, norm=None)
Kaiser-Bessel NUFFT with sparse matrix interpolation.
See KbNufft for an overall description of the forward NUFFT. To calculate the sparse matrix tuple, see calc_tensor_spmatrix().
- Parameters
  - image (Tensor) – Image to be NUFFT’d to scattered data.
  - scaling_coef (Tensor) – Image-domain coefficients to pre-compensate for interpolation errors.
  - im_size (Tensor) – Size of image with length being the number of dimensions.
  - grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.
  - interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.
  - norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.
- Return type
  Tensor
- Returns
  image calculated at scattered Fourier locations.
kb_spmat_nufft_adjoint
- kb_spmat_nufft_adjoint(data, scaling_coef, im_size, grid_size, interp_mats, norm=None)
Kaiser-Bessel adjoint NUFFT with sparse matrix interpolation.
See KbNufftAdjoint for an overall description of the adjoint NUFFT. To calculate the sparse matrix tuple, see calc_tensor_spmatrix().
- Parameters
  - data (Tensor) – Scattered data to be iNUFFT’d to an image.
  - scaling_coef (Tensor) – Image-domain coefficients to compensate for interpolation errors.
  - im_size (Tensor) – Size of image with length being the number of dimensions.
  - grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.
  - interp_mats (Tuple[Tensor, Tensor]) – 2-tuple of real, imaginary sparse matrices to use for sparse matrix KB interpolation.
  - norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.
- Return type
  Tensor
- Returns
  data transformed to an image.
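Putting the pieces together, the NumPy/SciPy sketch below runs, in 1D, the forward pipeline these functions describe (scale by scaling_coef, zero-pad to grid_size, FFT, then interpolate with a sparse matrix) and the adjoint (the transposed steps in reverse). It compares the forward result with an exact non-uniform DFT and runs a dot-product test on the adjoint. Assumptions of this sketch, not necessarily the package's choices: image coordinates centered on im_size // 2, 2x oversampling, a 6-point Kaiser-Bessel kernel with the Beatty et al. [2] shape parameter, scaling coefficients computed by numerical quadrature of the kernel's Fourier transform, and a single complex sparse matrix instead of a real/imaginary pair. All names are hypothetical.

```python
import numpy as np
import scipy.sparse as sp
from scipy.special import i0

N, K, W = 64, 128, 6  # im_size, grid_size (2x oversampling), numpoints
alpha = K / N
# Kaiser-Bessel shape parameter from Beatty et al. [2] (an assumption of this sketch).
beta = np.pi * np.sqrt((W / alpha) ** 2 * (alpha - 0.5) ** 2 - 0.8)
n = np.arange(N) - N // 2  # centered image coordinates


def kb(u):
    """Kaiser-Bessel kernel, supported on |u| <= W / 2."""
    arg = np.clip(1.0 - (2.0 * u / W) ** 2, 0.0, None)
    return np.where(np.abs(u) <= W / 2, i0(beta * np.sqrt(arg)), 0.0)


# scaling_coef: inverse of the kernel's Fourier transform at n / K (midpoint quadrature).
dt = 1e-3
t = np.arange(-W / 2 + dt / 2, W / 2, dt)
kernel_ft = (kb(t)[None, :] * np.cos(2 * np.pi * np.outer(n / K, t))).sum(axis=1) * dt
scaling_coef = 1.0 / kernel_ft


def interp_matrix(omega):
    """Sparse (klength, K) matrix of kernel weights on the W nearest grid points."""
    u = omega * K / (2 * np.pi)  # trajectory in grid units
    k = np.floor(u)[:, None] - W // 2 + 1 + np.arange(W)  # neighbor grid indices
    rows = np.repeat(np.arange(u.size), W)
    cols = (k.astype(int) % K).ravel()  # wrap around the periodic grid
    vals = kb(u[:, None] - k).ravel()
    return sp.csr_matrix((vals, (rows, cols)), shape=(u.size, K))


def nufft(image, omega):
    """Forward: scale, zero-pad to K, FFT, interpolate to the trajectory."""
    padded = np.zeros(K, dtype=complex)
    padded[n % K] = scaling_coef * image  # negative indices wrap, like an fftshift
    return interp_matrix(omega) @ np.fft.fft(padded)


def nufft_adjoint(kdata, omega):
    """Adjoint: transposed interpolation, K * inverse FFT (adjoint of fft), crop, scale."""
    grid = interp_matrix(omega).conj().T @ kdata
    return scaling_coef * (K * np.fft.ifft(grid))[n % K]


rng = np.random.default_rng(0)
omega = rng.uniform(-np.pi, np.pi, 200)
x = rng.standard_normal(N) + 1j * rng.standard_normal(N)
y = rng.standard_normal(200) + 1j * rng.standard_normal(200)

exact = np.exp(-1j * np.outer(omega, n)) @ x  # direct non-uniform DFT
rel_err = np.linalg.norm(nufft(x, omega) - exact) / np.linalg.norm(exact)
print(rel_err)
# Dot-product test: <A x, y> should match <x, A^H y>.
print(np.vdot(nufft(x, omega), y), np.vdot(x, nufft_adjoint(y, omega)))
```

The table-based functions below follow the same pipeline; only the interpolation step changes from a sparse matrix product to table lookups.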
kb_table_nufft
- kb_table_nufft(image, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm=None)
Kaiser-Bessel NUFFT with table interpolation.
See KbNufft for an overall description of the forward NUFFT.
- Parameters
  - image (Tensor) – Image to be NUFFT’d to scattered data.
  - scaling_coef (Tensor) – Image-domain coefficients to pre-compensate for interpolation errors.
  - im_size (Tensor) – Size of image with length being the number of dimensions.
  - grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.
  - omega (Tensor) – k-space trajectory (in radians/voxel).
  - tables (List[Tensor]) – Interpolation tables (one table for each dimension).
  - n_shift (Tensor) – Size for fftshift, usually im_size // 2.
  - numpoints (Tensor) – Number of neighbors to use for interpolation.
  - table_oversamp (Tensor) – Table oversampling factor.
  - offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.
  - norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.
- Return type
  Tensor
- Returns
  image calculated at scattered Fourier locations.
kb_table_nufft_adjoint
- kb_table_nufft_adjoint(data, scaling_coef, im_size, grid_size, omega, tables, n_shift, numpoints, table_oversamp, offsets, norm=None)
Kaiser-Bessel NUFFT adjoint with table interpolation.
See KbNufftAdjoint for an overall description of the adjoint NUFFT.
- Parameters
  - data (Tensor) – Scattered data to be iNUFFT’d to an image.
  - scaling_coef (Tensor) – Image-domain coefficients to compensate for interpolation errors.
  - im_size (Tensor) – Size of image with length being the number of dimensions.
  - grid_size (Tensor) – Size of grid to use for interpolation, typically 1.25 to 2 times im_size.
  - omega (Tensor) – k-space trajectory (in radians/voxel).
  - tables (List[Tensor]) – Interpolation tables (one table for each dimension).
  - n_shift (Tensor) – Size for fftshift, usually im_size // 2.
  - numpoints (Tensor) – Number of neighbors to use for interpolation.
  - table_oversamp (Tensor) – Table oversampling factor.
  - offsets (Tensor) – A list of offsets, looping over all possible combinations of numpoints.
  - norm (Optional[str]) – Whether to apply normalization with the FFT operation. Options are "ortho" or None.
- Return type
  Tensor
- Returns
  data transformed to an image.