# Source code for torchkbnufft._nufft.dcomp

from typing import Optional, Sequence, Union

import torch
import torchkbnufft.functional as tkbnF
from torch import Tensor

from .utils import init_fn

def calc_density_compensation_function(
    ktraj: Tensor,
    im_size: Sequence[int],
    num_iterations: int = 10,
    grid_size: Optional[Sequence[int]] = None,
    numpoints: Union[int, Sequence[int]] = 6,
    n_shift: Optional[Sequence[int]] = None,
    table_oversamp: Union[int, Sequence[int]] = 2 ** 10,
    kbwidth: float = 2.34,
    order: Union[float, Sequence[float]] = 0.0,
) -> Tensor:
    """Numerical density compensation estimation.

    This function has optional parameters for initializing a NUFFT object. See
    :py:class:`~torchkbnufft.KbInterp` for details.

    * :attr:`ktraj` should be of size ``(len(im_size), klength)``,
      where ``klength`` is the length of the k-space trajectory.

    Based on the `method of Pipe
    <https://doi.org/10.1002/(SICI)1522-2594(199901)41:1%3C179::AID-MRM25%3E3.0.CO;2-V>`_
    (Pipe and Menon, Magn. Reson. Med. 41(1):179-186, 1999).

    This code was contributed by Chaithya G.R.

    The weights start as a tensor of ones of shape ``(1, 1, klength)``. Each
    iteration spreads the current weights onto the grid with the adjoint
    table interpolation, samples the grid back at the trajectory locations
    with the forward table interpolation, and divides the weights by the
    magnitude of that result.

    Args:
        ktraj: k-space trajectory (in radians/voxel).
        im_size: Size of image with length being the number of dimensions.
        num_iterations: Number of iterations. With a value ``<= 0`` the loop
            does not run and the initial all-ones weights are returned.
        grid_size: Size of grid to use for interpolation, typically 1.25 to
            2 times ``im_size``. Default: ``2 * im_size``
        numpoints: Number of neighbors to use for interpolation in each
            dimension. Default: ``6``
        n_shift: Size for fftshift. Default: ``im_size // 2``.
        table_oversamp: Table oversampling factor.
        kbwidth: Size of Kaiser-Bessel kernel.
        order: Order of Kaiser-Bessel kernel.

    Returns:
        The density compensation coefficients for ``ktraj``.

    Examples:

        >>> data = torch.randn(1, 1, 12) + 1j * torch.randn(1, 1, 12)
        >>> omega = torch.rand(2, 12) * 2 * np.pi - np.pi
        >>> dcomp = tkbn.calc_density_compensation_function(omega, (8, 8))
        >>> adjkb_ob = tkbn.KbNufftAdjoint(im_size=(8, 8))
        >>> image = adjkb_ob(data * dcomp, omega)
    """
    device = ktraj.device

    # init nufft variables; only the interpolation-related outputs of
    # init_fn are needed here, the remaining ones are discarded
    (
        tables,
        _,
        grid_size_t,
        n_shift_t,
        numpoints_t,
        offsets_t,
        table_oversamp_t,
        _,
        _,
    ) = init_fn(
        im_size=im_size,
        grid_size=grid_size,
        numpoints=numpoints,
        n_shift=n_shift,
        table_oversamp=table_oversamp,
        kbwidth=kbwidth,
        order=order,
        dtype=ktraj.dtype,
        device=device,
    )

    # Arguments shared by the forward and adjoint interpolation calls; they
    # do not change between iterations, so build them once.
    interp_kwargs = {
        "omega": ktraj,
        "tables": tables,
        "n_shift": n_shift_t,
        "numpoints": numpoints_t,
        "table_oversamp": table_oversamp_t,
        "offsets": offsets_t,
    }

    # Start from uniform weights, using the dtype of the interpolation tables
    # so the signal is compatible with the table interpolation routines.
    test_sig = torch.ones(
        [1, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device
    )

    for _ in range(num_iterations):
        # k-space samples -> grid (adjoint interpolation)
        gridded = tkbnF.kb_table_interp_adjoint(
            data=test_sig,
            grid_size=grid_size_t,
            **interp_kwargs,
        )
        # grid -> k-space samples (forward interpolation)
        new_sig = tkbnF.kb_table_interp(gridded, **interp_kwargs)

        # Rescale the weights by the magnitude of the re-sampled signal.
        test_sig = test_sig / torch.abs(new_sig)

    return test_sig