To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 2af4137e authored by Jingwei Tang's avatar Jingwei Tang
Browse files

Change torch.fft and torch.ifft to torch.fft.fft and torch.fft.ifft

parent 14694a27
......@@ -13,7 +13,7 @@ class FourierDecomposition(object):
self.dim = 2 if self.solver.is_2d() else 3
def fft(self, grid, signal_ndim, normalized=True):
"""Wrapper for torch.fft to perform 2D/3D fft on input grid.
"""Wrapper for torch.fft.fft to perform 2D/3D fft on input grid.
Also shifts the zero freq to the center of the frequency domain.
Args:
......@@ -31,14 +31,14 @@ class FourierDecomposition(object):
)
grid = torch.stack((grid, zero_grid), dim=-1)
if signal_ndim < 4:
f = torch.fft(grid, signal_ndim=signal_ndim, normalized=normalized)
f = torch.fft.fft(grid, signal_ndim=signal_ndim, normalized=normalized)
else:
f = self._fft4_noshift(grid, normalized=normalized)
fshift = self._batch_fftshift(f)
return fshift
def ifft(self, fshift, signal_ndim, normalized=True):
"""Wrapper for torch.ifft to perform 2D/3D inverse fft on shifted frequency
"""Wrapper for torch.fft.ifft to perform 2D/3D inverse fft on shifted frequency
Args:
fshift(torch.Tensor): of size [C, (T), (D), H, W, 2]
......@@ -52,11 +52,10 @@ class FourierDecomposition(object):
assert signal_ndim > 0 and signal_ndim < 5
f = self._batch_ifftshift(fshift)
if signal_ndim < 4:
grid = torch.ifft(f, signal_ndim=signal_ndim, normalized=normalized)
grid = torch.fft.ifft(f, signal_ndim=signal_ndim, normalized=normalized)
else:
grid = self._ifft4_noshift(f, normalized=normalized)
return grid[..., 0]
# return grid.norm(p=2, dim=-1)
def get_spectral(self, fshift):
"""Computes the manitude of fourier frequency of given grid,
......@@ -177,8 +176,8 @@ class FourierDecomposition(object):
representing real and complex components. fftshift is
not performed.
"""
f_p = torch.fft(grid, signal_ndim=3, normalized=normalized)
f = torch.fft(
f_p = torch.fft.fft(grid, signal_ndim=3, normalized=normalized)
f = torch.fft.fft(
f_p.permute(0, 2, 3, 4, 1, 5), signal_ndim=1, normalized=normalized
).permute(0, 4, 1, 2, 3, 5)
# ############################## ALTERNATIVE ###########################
......@@ -205,8 +204,8 @@ class FourierDecomposition(object):
representing real and complex components. ifftshift is
not performed.
"""
x_p = torch.ifft(f, signal_ndim=3, normalized=normalized)
grid = torch.ifft(
x_p = torch.fft.ifft(f, signal_ndim=3, normalized=normalized)
grid = torch.fft.ifft(
x_p.permute(0, 2, 3, 4, 1, 5), signal_ndim=1, normalized=normalized
).permute(0, 4, 1, 2, 3, 5)
return grid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment