Source code for torchkbnufft.functional.interp

from typing import List, Tuple

import torch
from torch import Tensor

from .._autograd.interp import (
    KbSpmatInterpAdjoint,
    KbSpmatInterpForward,
    KbTableInterpAdjoint,
    KbTableInterpForward,
)


def kb_spmat_interp(image: Tensor, interp_mats: Tuple[Tensor, Tensor]) -> Tensor:
    """Kaiser-Bessel sparse matrix interpolation.

    See :py:class:`~torchkbnufft.KbInterp` for an overall description of
    interpolation. To calculate the sparse matrix tuple, see
    :py:meth:`~torchkbnufft.calc_tensor_spmatrix`.

    Args:
        image: Gridded data to be interpolated to scattered data.
        interp_mats: 2-tuple of real, imaginary sparse matrices to use for
            sparse matrix KB interpolation.

    Returns:
        ``image`` calculated at scattered locations.
    """
    is_complex = True
    if not image.is_complex():
        if not image.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        is_complex = False
        image = torch.view_as_complex(image)

    data = KbSpmatInterpForward.apply(image, interp_mats)

    if is_complex is False:
        data = torch.view_as_real(data)

    return data
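A minimal usage sketch (not part of the module): build the sparse matrix tuple with :py:meth:`~torchkbnufft.calc_tensor_spmatrix` and interpolate an oversampled grid onto scattered k-space locations. The sizes, the random trajectory, and the ``(batch, coil, *grid_size)`` input layout below are illustrative assumptions.

    import math

    import torch
    import torchkbnufft as tkbn

    from torchkbnufft.functional.interp import kb_spmat_interp

    im_size = (8, 8)
    grid_size = (16, 16)  # oversampled grid, here 2x the image size

    # random k-space trajectory in radians/voxel, shape (ndims, klength)
    omega = (torch.rand(2, 50) - 0.5) * 2 * math.pi

    # 2-tuple of (real, imaginary) sparse interpolation matrices
    interp_mats = tkbn.calc_tensor_spmatrix(omega, im_size, grid_size=grid_size)

    # complex gridded data on the oversampled grid
    gridded = torch.randn(1, 1, *grid_size, dtype=torch.complex64)

    # scattered k-space samples, shape (1, 1, 50)
    kdata = kb_spmat_interp(gridded, interp_mats)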
def kb_spmat_interp_adjoint(
    data: Tensor, interp_mats: Tuple[Tensor, Tensor], grid_size: Tensor
) -> Tensor:
    """Kaiser-Bessel sparse matrix interpolation adjoint.

    See :py:class:`~torchkbnufft.KbInterpAdjoint` for an overall description
    of adjoint interpolation. To calculate the sparse matrix tuple, see
    :py:meth:`~torchkbnufft.calc_tensor_spmatrix`.

    Args:
        data: Scattered data to be interpolated to gridded data.
        interp_mats: 2-tuple of real, imaginary sparse matrices to use for
            sparse matrix KB interpolation.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``.

    Returns:
        ``data`` calculated at gridded locations.
    """
    is_complex = True
    if not data.is_complex():
        if not data.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        is_complex = False
        data = torch.view_as_complex(data)

    image = KbSpmatInterpAdjoint.apply(data, interp_mats, grid_size)

    if is_complex is False:
        image = torch.view_as_real(image)

    return image
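The adjoint direction, sketched under the same assumptions as above, grids scattered k-space samples back onto the oversampled grid; ``grid_size`` is passed as a tensor, matching the signature.

    import math

    import torch
    import torchkbnufft as tkbn

    from torchkbnufft.functional.interp import kb_spmat_interp_adjoint

    im_size = (8, 8)
    grid_size = (16, 16)

    omega = (torch.rand(2, 50) - 0.5) * 2 * math.pi
    interp_mats = tkbn.calc_tensor_spmatrix(omega, im_size, grid_size=grid_size)

    # complex scattered data, shape (batch, coil, klength)
    kdata = torch.randn(1, 1, omega.shape[-1], dtype=torch.complex64)

    # gridded output, shape (1, 1, 16, 16)
    gridded = kb_spmat_interp_adjoint(kdata, interp_mats, torch.tensor(grid_size))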
def kb_table_interp(
    image: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
) -> Tensor:
    """Kaiser-Bessel table interpolation.

    See :py:class:`~torchkbnufft.KbInterp` for an overall description of
    interpolation and how to construct the function arguments.

    Args:
        image: Gridded data to be interpolated to scattered data.
        omega: k-space trajectory (in radians/voxel).
        tables: Interpolation tables (one table for each dimension).
        n_shift: Size for fftshift, usually ``im_size // 2``.
        numpoints: Number of neighbors to use for interpolation.
        table_oversamp: Table oversampling factor.
        offsets: A list of offsets, looping over all possible combinations of
            ``numpoints``.

    Returns:
        ``image`` calculated at scattered locations.
    """
    is_complex = True
    if not image.is_complex():
        if not image.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        is_complex = False
        image = torch.view_as_complex(image)

    data = KbTableInterpForward.apply(
        image, omega, tables, n_shift, numpoints, table_oversamp, offsets
    )

    if is_complex is False:
        data = torch.view_as_real(data)

    return data
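In normal use the table arguments are assembled by :py:class:`~torchkbnufft.KbInterp`, whose forward calls this function; the sketch below pulls them out of a module instance instead. The buffer names (``table_0``, ``table_1``, ``n_shift``, ``numpoints``, ``table_oversamp``, ``offsets``) are assumptions about the module's internals and may differ between versions.

    import math

    import torch
    import torchkbnufft as tkbn

    from torchkbnufft.functional.interp import kb_table_interp

    im_size = (8, 8)
    grid_size = (16, 16)
    interp_ob = tkbn.KbInterp(im_size=im_size, grid_size=grid_size)

    # assumed buffer names for the precomputed table arguments
    tables = [interp_ob.table_0, interp_ob.table_1]  # one table per dimension

    omega = (torch.rand(2, 50) - 0.5) * 2 * math.pi
    gridded = torch.randn(1, 1, *grid_size, dtype=torch.complex64)

    kdata = kb_table_interp(
        gridded,
        omega,
        tables,
        interp_ob.n_shift,
        interp_ob.numpoints,
        interp_ob.table_oversamp,
        interp_ob.offsets,
    )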
def kb_table_interp_adjoint(
    data: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    grid_size: Tensor,
) -> Tensor:
    """Kaiser-Bessel table interpolation adjoint.

    See :py:class:`~torchkbnufft.KbInterpAdjoint` for an overall description
    of adjoint interpolation.

    Args:
        data: Scattered data to be interpolated to gridded data.
        omega: k-space trajectory (in radians/voxel).
        tables: Interpolation tables (one table for each dimension).
        n_shift: Size for fftshift, usually ``im_size // 2``.
        numpoints: Number of neighbors to use for interpolation.
        table_oversamp: Table oversampling factor.
        offsets: A list of offsets, looping over all possible combinations of
            ``numpoints``.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``.

    Returns:
        ``data`` calculated at gridded locations.
    """
    is_complex = True
    if not data.is_complex():
        if not data.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        is_complex = False
        data = torch.view_as_complex(data)

    image = KbTableInterpAdjoint.apply(
        data, omega, tables, n_shift, numpoints, table_oversamp, offsets, grid_size
    )

    if is_complex is False:
        image = torch.view_as_real(image)

    return image
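The adjoint sketch mirrors the forward one, again assuming the module buffer names noted above; ``grid_size`` is supplied explicitly as a tensor, matching the signature.

    import math

    import torch
    import torchkbnufft as tkbn

    from torchkbnufft.functional.interp import kb_table_interp_adjoint

    im_size = (8, 8)
    grid_size = (16, 16)
    adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=grid_size)

    tables = [adj_ob.table_0, adj_ob.table_1]  # assumed buffer names, as above

    omega = (torch.rand(2, 50) - 0.5) * 2 * math.pi
    kdata = torch.randn(1, 1, omega.shape[-1], dtype=torch.complex64)

    # gridded output, shape (1, 1, 16, 16)
    gridded = kb_table_interp_adjoint(
        kdata,
        omega,
        tables,
        adj_ob.n_shift,
        adj_ob.numpoints,
        adj_ob.table_oversamp,
        adj_ob.offsets,
        torch.tensor(grid_size),
    )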