from typing import Optional, Sequence, Tuple, Union
import torch
import torchkbnufft.functional as tkbnF
from torch import Tensor
from .._nufft.utils import compute_scaling_coefs
from ._kbmodule import KbModule
class KbNufftModule(KbModule):
    """Parent class for KbNufft classes.

    See subclasses for an explanation of inputs.
    """

    def __init__(
        self,
        im_size: Sequence[int],
        grid_size: Optional[Sequence[int]] = None,
        numpoints: Union[int, Sequence[int]] = 6,
        n_shift: Optional[Sequence[int]] = None,
        table_oversamp: Union[int, Sequence[int]] = 2 ** 10,
        kbwidth: float = 2.34,
        order: Union[float, Sequence[float]] = 0.0,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        # The parent constructor is expected to validate these parameters and
        # expose them as tensor attributes (im_size, grid_size, numpoints,
        # alpha, order, table_0, ...) — inferred from the attribute accesses
        # below; confirm against KbModule.
        super().__init__(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints,
            n_shift=n_shift,
            table_oversamp=table_oversamp,
            kbwidth=kbwidth,
            order=order,
            dtype=dtype,
            device=device,
        )

        # Scaling coefficients are computed from the (now-normalized) NUFFT
        # parameters stored by the parent; .tolist() converts the tensor
        # buffers back to plain Python sequences for the helper.
        scaling_coef = compute_scaling_coefs(
            im_size=self.im_size.tolist(),  # type: ignore
            grid_size=self.grid_size.tolist(),  # type: ignore
            numpoints=self.numpoints.tolist(),  # type: ignore
            alpha=self.alpha.tolist(),  # type: ignore
            order=self.order.tolist(),  # type: ignore
        )

        # Register as a buffer so the coefficients move with .to()/.cuda()
        # and are saved in state_dict; cast to the interpolation table's
        # dtype so downstream arithmetic uses a consistent precision.
        self.register_buffer(
            "scaling_coef",
            scaling_coef.to(dtype=self.table_0.dtype, device=device),  # type: ignore
        )
class KbNufft(KbNufftModule):
    r"""Non-uniform FFT layer.

    This object applies the FFT and interpolates a grid of Fourier data to
    off-grid locations using a Kaiser-Bessel kernel. Mathematically, in one
    dimension it estimates :math:`Y_m, m \in [0, ..., M-1]` at frequency
    locations :math:`\omega_m` from :math:`X_k, k \in [0, ..., K-1]`, the
    oversampled DFT of :math:`x_n, n \in [0, ..., N-1]`. To perform the
    estimate, this layer applies

    .. math::
        X_k = \sum_{n=0}^{N-1} s_n x_n e^{-i \gamma k n}

    .. math::
        Y_m = \sum_{j=1}^J X_{\{k_m+j\}_K} u^*_j(\omega_m)

    In the first step, an image-domain signal :math:`x_n` is converted to a
    gridded, oversampled frequency-domain signal, :math:`X_k`. The scaling
    coefficients :math:`s_n` are multiplied to precompensate for NUFFT
    interpolation errors. The oversampling coefficient is
    :math:`\gamma = 2\pi / K, K >= N`.

    In the second step, :math:`u`, the Kaiser-Bessel kernel, is used to
    estimate :math:`X_k` at off-grid frequency locations :math:`\omega_m`.
    :math:`k_m` is the index of the root offset of nearest samples of
    :math:`X` to frequency location :math:`\omega_m`, and :math:`J` is the
    number of nearest neighbors to use from :math:`X_k`. Multiple dimensions
    are handled separably. For a detailed description see
    `Nonuniform fast Fourier transforms using min-max interpolation
    (JA Fessler and BP Sutton)
    <https://doi.org/10.1109/TSP.2002.807005>`_.

    When called, the parameters of this class define properties of the kernel
    and how the interpolation is applied.

    * :attr:`im_size` is the size of the base image, analogous to :math:`N`.
    * :attr:`grid_size` is the size of the grid after forward FFT, analogous
      to :math:`K`. To reduce errors, NUFFT operations are done on an
      oversampled grid to reduce interpolation distances. This will typically
      be 1.25 to 2 times :attr:`im_size`.
    * :attr:`numpoints` is the number of nearest neighbors to use for
      interpolation, i.e., :math:`J`.
    * :attr:`n_shift` is the FFT shift distance, typically
      :attr:`im_size // 2`.

    Args:
        im_size: Size of image with length being the number of dimensions.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``. Default: ``2 * im_size``
        numpoints: Number of neighbors to use for interpolation in each
            dimension.
        n_shift: Size for ``fftshift``. Default: ``im_size // 2``.
        table_oversamp: Table oversampling factor.
        kbwidth: Size of Kaiser-Bessel kernel.
        order: Order of Kaiser-Bessel kernel.
        dtype: Data type for tensor buffers. Default:
            ``torch.get_default_dtype()``
        device: Which device to create tensors on.
            Default: ``torch.device('cpu')``

    Examples:

        >>> image = torch.randn(1, 1, 8, 8) + 1j * torch.randn(1, 1, 8, 8)
        >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
        >>> kb_ob = tkbn.KbNufft(im_size=(8, 8))
        >>> data = kb_ob(image, omega)
    """

    def forward(
        self,
        image: Tensor,
        omega: Tensor,
        interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
        smaps: Optional[Tensor] = None,
        norm: Optional[str] = None,
    ) -> Tensor:
        """Apply FFT and interpolate from gridded data to scattered data.

        Input tensors should be of shape ``(N, C) + im_size``, where ``N`` is
        the batch size and ``C`` is the number of sensitivity coils.
        ``omega``, the k-space trajectory, should be of size
        ``(len(grid_size), klength)`` or ``(N, len(grid_size), klength)``,
        where ``klength`` is the length of the k-space trajectory.

        Note:
            If the batch dimension is included in ``omega``, the interpolator
            will parallelize over the batch dimension. This is efficient for
            many small trajectories that might occur in dynamic imaging
            settings.

        If your tensors are real, ensure that 2 is the size of the last
        dimension.

        Args:
            image: Object to calculate off-grid Fourier samples from.
            omega: k-space trajectory (in radians/voxel).
            interp_mats: 2-tuple of real, imaginary sparse matrices to use
                for sparse matrix NUFFT interpolation (overrides default
                table interpolation).
            smaps: Sensitivity maps. If input, these will be multiplied
                before the forward NUFFT.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            ``image`` calculated at Fourier frequencies specified by
            ``omega``.
        """
        if smaps is not None and smaps.dtype != image.dtype:
            raise TypeError("image dtype does not match smaps dtype.")

        # Real tensors are treated as complex with a trailing size-2
        # real/imag dimension; convert views up front and back at the end.
        is_complex = image.is_complex()
        if not is_complex:
            if image.shape[-1] != 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            if smaps is not None:
                if smaps.shape[-1] != 2:
                    raise ValueError("For real inputs, last dimension must be size 2.")
                smaps = torch.view_as_complex(smaps)
            image = torch.view_as_complex(image)

        # Coil expansion happens before the forward NUFFT.
        if smaps is not None:
            image = image * smaps

        # Buffers are stored as Optional on the module; narrow for TorchScript.
        assert isinstance(self.scaling_coef, Tensor)
        assert isinstance(self.im_size, Tensor)
        assert isinstance(self.grid_size, Tensor)

        if interp_mats is not None:
            # Sparse-matrix interpolation path.
            output = tkbnF.kb_spmat_nufft(
                image=image,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                interp_mats=interp_mats,
                norm=norm,
            )
        else:
            # Table interpolation path: gather the per-dimension tables.
            assert isinstance(self.n_shift, Tensor)
            assert isinstance(self.numpoints, Tensor)
            assert isinstance(self.table_oversamp, Tensor)
            assert isinstance(self.offsets, Tensor)
            tables = [
                getattr(self, f"table_{dim}")
                for dim in range(len(self.im_size))  # type: ignore
            ]
            output = tkbnF.kb_table_nufft(
                image=image,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                omega=omega,
                tables=tables,
                n_shift=self.n_shift,
                numpoints=self.numpoints,
                table_oversamp=self.table_oversamp,
                offsets=self.offsets.to(torch.long),
                norm=norm,
            )

        return output if is_complex else torch.view_as_real(output)
class KbNufftAdjoint(KbNufftModule):
    r"""Non-uniform FFT adjoint layer.

    This object interpolates off-grid Fourier data to on-grid locations
    using a Kaiser-Bessel kernel prior to inverse DFT. Mathematically, in
    one dimension it estimates :math:`x_n, n \in [0, ..., N-1]` from an
    off-grid signal :math:`Y_m, m \in [0, ..., M-1]` where the off-grid
    frequency locations are :math:`\omega_m`. To perform the estimate, this
    layer applies

    .. math::
        X_k = \sum_{j=1}^J \sum_{m=0}^{M-1} Y_m u_j(\omega_m)
        \mathbb{1}_{\{\{k_m+j\}_K=k\}},

    .. math::
        x_n = s_n^* \sum_{k=0}^{K-1} X_k e^{i \gamma k n}

    In the first step, :math:`u`, the Kaiser-Bessel kernel, is used to
    estimate :math:`Y` at on-grid frequency locations from locations at
    :math:`\omega`. :math:`k_m` is the index of the root offset of nearest
    samples of :math:`X` to frequency location :math:`\omega_m`,
    :math:`\mathbb{1}` is an indicator function, and :math:`J` is the number
    of nearest neighbors to use from :math:`X_k, k \in [0, ..., K-1]`.

    In the second step, an image-domain signal :math:`x_n` is estimated from
    a gridded, oversampled frequency-domain signal, :math:`X_k` by applying
    the inverse FFT, after which the complex conjugate scaling coefficients
    :math:`s_n` are multiplied. The oversampling coefficient is
    :math:`\gamma = 2\pi / K, K >= N`. Multiple dimensions are handled
    separably. For a detailed description see
    `Nonuniform fast Fourier transforms using min-max interpolation
    (JA Fessler and BP Sutton)
    <https://doi.org/10.1109/TSP.2002.807005>`_.

    Note:
        This function is not the inverse of :py:class:`KbNufft`; it is the
        adjoint.

    When called, the parameters of this class define properties of the
    kernel and how the interpolation is applied.

    * :attr:`im_size` is the size of the base image, analogous to :math:`N`.
    * :attr:`grid_size` is the size of the grid after adjoint interpolation,
      analogous to :math:`K`. To reduce errors, NUFFT operations are done on
      an oversampled grid to reduce interpolation distances. This will
      typically be 1.25 to 2 times :attr:`im_size`.
    * :attr:`numpoints` is the number of nearest neighbors to use for
      interpolation, i.e., :math:`J`.
    * :attr:`n_shift` is the FFT shift distance, typically
      :attr:`im_size // 2`.

    Args:
        im_size: Size of image with length being the number of dimensions.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``. Default: ``2 * im_size``
        numpoints: Number of neighbors to use for interpolation in each
            dimension.
        n_shift: Size for ``fftshift``. Default: ``im_size // 2``.
        table_oversamp: Table oversampling factor.
        kbwidth: Size of Kaiser-Bessel kernel.
        order: Order of Kaiser-Bessel kernel.
        dtype: Data type for tensor buffers. Default:
            ``torch.get_default_dtype()``
        device: Which device to create tensors on.
            Default: ``torch.device('cpu')``

    Examples:

        >>> data = torch.randn(1, 1, 12) + 1j * torch.randn(1, 1, 12)
        >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
        >>> adjkb_ob = tkbn.KbNufftAdjoint(im_size=(8, 8))
        >>> image = adjkb_ob(data, omega)
    """

    def forward(
        self,
        data: Tensor,
        omega: Tensor,
        interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
        smaps: Optional[Tensor] = None,
        norm: Optional[str] = None,
    ) -> Tensor:
        """Interpolate from scattered data to gridded data and then iFFT.

        Input tensors should be of shape ``(N, C) + klength``, where ``N``
        is the batch size and ``C`` is the number of sensitivity coils.
        ``omega``, the k-space trajectory, should be of size
        ``(len(grid_size), klength)`` or ``(N, len(grid_size), klength)``,
        where ``klength`` is the length of the k-space trajectory.

        Note:
            If the batch dimension is included in ``omega``, the interpolator
            will parallelize over the batch dimension. This is efficient for
            many small trajectories that might occur in dynamic imaging
            settings.

        If your tensors are real, ensure that 2 is the size of the last
        dimension.

        Args:
            data: Data to be gridded and then inverse FFT'd.
            omega: k-space trajectory (in radians/voxel).
            interp_mats: 2-tuple of real, imaginary sparse matrices to use
                for sparse matrix NUFFT interpolation (overrides default
                table interpolation).
            smaps: Sensitivity maps. If input, the coil images are combined
                by a conjugate-weighted sum over the coil dimension after
                the adjoint NUFFT.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            ``data`` transformed to the image domain.
        """
        if smaps is not None and smaps.dtype != data.dtype:
            raise TypeError("data dtype does not match smaps dtype.")

        # Real tensors are treated as complex with a trailing size-2
        # real/imag dimension; convert views up front and back at the end.
        is_complex = data.is_complex()
        if not is_complex:
            if data.shape[-1] != 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            if smaps is not None:
                if smaps.shape[-1] != 2:
                    raise ValueError("For real inputs, last dimension must be size 2.")
                smaps = torch.view_as_complex(smaps)
            data = torch.view_as_complex(data)

        # Buffers are stored as Optional on the module; narrow for TorchScript.
        assert isinstance(self.scaling_coef, Tensor)
        assert isinstance(self.im_size, Tensor)
        assert isinstance(self.grid_size, Tensor)

        if interp_mats is not None:
            # Sparse-matrix interpolation path.
            output = tkbnF.kb_spmat_nufft_adjoint(
                data=data,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                interp_mats=interp_mats,
                norm=norm,
            )
        else:
            # Table interpolation path: gather the per-dimension tables.
            assert isinstance(self.n_shift, Tensor)
            assert isinstance(self.numpoints, Tensor)
            assert isinstance(self.table_oversamp, Tensor)
            assert isinstance(self.offsets, Tensor)
            tables = [
                getattr(self, f"table_{dim}")
                for dim in range(len(self.im_size))  # type: ignore
            ]
            output = tkbnF.kb_table_nufft_adjoint(
                data=data,
                scaling_coef=self.scaling_coef,
                im_size=self.im_size,
                grid_size=self.grid_size,
                omega=omega,
                tables=tables,
                n_shift=self.n_shift,
                numpoints=self.numpoints,
                table_oversamp=self.table_oversamp,
                offsets=self.offsets.to(torch.long),
                norm=norm,
            )

        # Coil combination happens after the adjoint NUFFT.
        if smaps is not None:
            output = torch.sum(output * smaps.conj(), dim=1, keepdim=True)

        return output if is_complex else torch.view_as_real(output)
class ToepNufft(torch.nn.Module):
    r"""Forward/backward NUFFT with Toeplitz embedding.

    This module applies :math:`Tx`, where :math:`T` is a matrix such that
    :math:`T \approx A'A`, where :math:`A` is a NUFFT matrix. Using Toeplitz
    embedding, this module approximates the :math:`A'A` operation without
    interpolations, which is extremely fast.

    The module is intended to be used in combination with an FFT kernel
    computed as frequency response of an embedded Toeplitz matrix. You can
    use :py:meth:`~torchkbnufft.calc_toeplitz_kernel` to calculate the
    kernel.

    The FFT kernel should be passed to this module's forward operation, which
    applies a (zero-padded) FFT filter using the kernel.

    Examples:

        >>> image = torch.randn(1, 1, 8, 8) + 1j * torch.randn(1, 1, 8, 8)
        >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
        >>> toep_ob = tkbn.ToepNufft()
        >>> kernel = tkbn.calc_toeplitz_kernel(omega, im_size=(8, 8))
        >>> image = toep_ob(image, kernel)
    """

    def __init__(self):
        super().__init__()

    def toep_batch_loop(
        self, image: Tensor, smaps: Tensor, kernel: Tensor, norm: Optional[str]
    ) -> Tensor:
        """Apply the Toeplitz filter one batch element at a time.

        Each batch element is coil-expanded with its sensitivity map,
        filtered, then coil-combined with the conjugate map. Previously this
        method duplicated the whole loop body across the batched- and
        shared-kernel cases; the cases differ only in which kernel is used,
        so they are unified here.

        Args:
            image: Batch of images, shape ``(N, C) + im_size``.
            smaps: Sensitivity maps, one per batch element.
            kernel: FFT filter kernel; may carry a leading batch dimension
                (one kernel per batch element) or be shared by all elements.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            The filtered, coil-combined batch, stacked along dim 0.
        """
        # A kernel with more dims than the image's spatial part carries a
        # per-batch leading dimension (same test as in forward()).
        kernel_is_batched = len(kernel.shape) > len(image.shape[2:])

        output = []
        for index, (mini_image, smap) in enumerate(zip(image, smaps)):
            mini_kernel = kernel[index] if kernel_is_batched else kernel
            # coil expansion -> FFT filter -> conjugate coil combination
            mini_image = mini_image.unsqueeze(0) * smap.unsqueeze(0)
            mini_image = tkbnF.fft_filter(
                image=mini_image, kernel=mini_kernel, norm=norm
            )
            mini_image = torch.sum(
                mini_image * smap.unsqueeze(0).conj(),
                dim=1,
                keepdim=True,
            )
            output.append(mini_image.squeeze(0))

        return torch.stack(output)

    def forward(
        self,
        image: Tensor,
        kernel: Tensor,
        smaps: Optional[Tensor] = None,
        norm: Optional[str] = "ortho",
    ) -> Tensor:
        """Toeplitz NUFFT forward function.

        Args:
            image: The image to apply the forward/backward Toeplitz-embedded
                NUFFT to.
            kernel: The filter response taking into account Toeplitz
                embedding. May include a leading batch dimension matching
                ``image`` (or of size 1, which is squeezed out).
            smaps: Sensitivity maps. If input, each batch element is
                coil-expanded before filtering and coil-combined (conjugate
                sum over the coil dimension) afterwards.
            norm: Whether to apply normalization with the FFT operation.
                Options are ``"ortho"`` or ``None``.

        Returns:
            ``image`` after applying the Toeplitz forward/backward NUFFT.

        Raises:
            TypeError: If ``kernel`` or ``smaps`` dtype does not match
                ``image``.
            ValueError: On malformed real inputs or mismatched kernel batch
                size.
        """
        if not kernel.dtype == image.dtype:
            raise TypeError("kernel and image must have same dtype.")
        if smaps is not None:
            if not smaps.dtype == image.dtype:
                raise TypeError("image dtype does not match smaps dtype.")

        # Real tensors are treated as complex with a trailing size-2
        # real/imag dimension; convert views up front and back at the end.
        is_complex = True
        if not image.is_complex():
            if not image.shape[-1] == 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            if not kernel.shape[-1] == 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            if smaps is not None:
                if not smaps.shape[-1] == 2:
                    raise ValueError("For real inputs, last dimension must be size 2.")
                smaps = torch.view_as_complex(smaps)
            is_complex = False
            image = torch.view_as_complex(image)
            kernel = torch.view_as_complex(kernel)

        # If the kernel carries a batch dimension, it must either be
        # singleton (squeeze it away) or match the image batch size.
        if len(kernel.shape) > len(image.shape[2:]):
            if kernel.shape[0] == 1:
                kernel = kernel[0]
            elif not kernel.shape[0] == image.shape[0]:
                raise ValueError(
                    "If using batch dimension, "
                    "kernel must have same batch size as image"
                )

        if smaps is None:
            output = tkbnF.fft_filter(image=image, kernel=kernel, norm=norm)
        else:
            output = self.toep_batch_loop(
                image=image, smaps=smaps, kernel=kernel, norm=norm
            )

        if not is_complex:
            output = torch.view_as_real(output)

        return output