Source code for torchkbnufft._nufft.dcomp

from typing import Optional, Sequence, Union

import torch
import torchkbnufft.functional as tkbnF
from torch import Tensor

from .utils import init_fn


[docs]def calc_density_compensation_function( ktraj: Tensor, im_size: Sequence[int], num_iterations: int = 10, grid_size: Optional[Sequence[int]] = None, numpoints: Union[int, Sequence[int]] = 6, n_shift: Optional[Sequence[int]] = None, table_oversamp: Union[int, Sequence[int]] = 2**10, kbwidth: float = 2.34, order: Union[float, Sequence[float]] = 0.0, ) -> Tensor: """Numerical density compensation estimation. This function has optional parameters for initializing a NUFFT object. See :py:class:`~torchkbnufft.KbInterp` for details. * :attr:`ktraj` should be of size ``(len(grid_size), klength)`` or ``(N, len(grid_size), klength)``, where ``klength`` is the length of the k-space trajectory. Based on the `method of Pipe <https://doi.org/10.1002/(SICI)1522-2594(199901)41:1%3C179::AID-MRM25%3E3.0.CO;2-V>`_. This code was contributed by Chaithya G.R. Args: ktraj: k-space trajectory (in radians/voxel). im_size: Size of image with length being the number of dimensions. num_iterations: Number of iterations. grid_size: Size of grid to use for interpolation, typically 1.25 to 2 times ``im_size``. Default: ``2 * im_size`` numpoints: Number of neighbors to use for interpolation in each dimension. Default: ``6`` n_shift: Size for fftshift. Default: ``im_size // 2``. table_oversamp: Table oversampling factor. kbwidth: Size of Kaiser-Bessel kernel. order: Order of Kaiser-Bessel kernel. Returns: The density compensation coefficients for ``ktraj``. Examples: >>> data = torch.randn(1, 1, 12) + 1j * torch.randn(1, 1, 12) >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi >>> dcomp = tkbn.calc_density_compensation_function(omega, (8, 8)) >>> adjkb_ob = tkbn.KbNufftAdjoint(im_size=(8, 8)) >>> image = adjkb_ob(data * dcomp, omega) """ device = ktraj.device batch_size = 1 if ktraj.ndim not in (2, 3): raise ValueError("ktraj must have 2 or 3 dimensions") if ktraj.ndim == 3: if ktraj.shape[0] == 1: ktraj = ktraj[0] else: batch_size = ktraj.shape[0] # init nufft variables ( tables, _, grid_size_t, n_shift_t, numpoints_t, offsets_t, table_oversamp_t, _, _, ) = init_fn( im_size=im_size, grid_size=grid_size, numpoints=numpoints, n_shift=n_shift, table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, dtype=ktraj.dtype, device=device, ) test_sig = torch.ones( [batch_size, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device ) for _ in range(num_iterations): new_sig = tkbnF.kb_table_interp( image=tkbnF.kb_table_interp_adjoint( data=test_sig, omega=ktraj, tables=tables, n_shift=n_shift_t, numpoints=numpoints_t, table_oversamp=table_oversamp_t, offsets=offsets_t, grid_size=grid_size_t, ), omega=ktraj, tables=tables, n_shift=n_shift_t, numpoints=numpoints_t, table_oversamp=table_oversamp_t, offsets=offsets_t, ) test_sig = test_sig / torch.abs(new_sig) return test_sig