# @Author: Minghui Shi, Zekun Niu
# @Datetime: 2021/12/03
# @File: resample_fft.py
# @Last Modify Time: 2021/12/03
# @Contact : 1032263160@sjtu.edu.cn, zekunniu@sjtu.edu.cn

import torch
import numpy as np
from scipy import signal

def _resample_tensor(x, out_num):
    """Fourier-resample tensor ``x`` to ``out_num`` samples along dim 0.

    Follows the complex-input path of ``scipy.signal.resample``; the result
    is the complex output of ``torch.fft.ifft`` (trailing dims preserved).
    """
    num = x.shape[0]                        # number of samples before resampling
    X = torch.fft.fft(x, dim=0)
    # Output spectrum: out_num bins along dim 0, every trailing dim kept as-is.
    Y = torch.zeros((out_num,) + tuple(x.shape[1:]), dtype=X.dtype, device=X.device)

    N = min(num, out_num)
    nyq = N // 2 + 1                        # slice end that includes the Nyquist bin if present
    # Copy positive-frequency bins (and Nyquist, if present).
    Y[:nyq] = X[:nyq]
    # Copy negative-frequency bins. For N <= 2 the start index nyq - N is 0,
    # which would copy whole (differently sized) spectra, so skip it.
    if N > 2:
        Y[nyq - N:] = X[nyq - N:]

    # With an even N, the shorter spectrum has a Nyquist bin that must be
    # split (upsampling) or merged (downsampling) so the result stays consistent.
    if N % 2 == 0:
        if out_num < num:                   # downsampling
            # Y[+N/2] currently holds X[+N/2]; add the X component at -N/2.
            Y[-N // 2] += X[-N // 2]
        elif out_num > num:                 # upsampling
            # Halve the +N/2 component and mirror it to -N/2.
            Y[N // 2] *= 0.5
            Y[out_num - N // 2] = Y[N // 2]

    y = torch.fft.ifft(Y, dim=0)
    # ifft normalises by out_num while fft did not normalise by num; rescale
    # so amplitudes are preserved.
    y *= float(out_num) / float(num)
    return y


def resample(x, rate):
    """
    Resample ``x`` by ``rate`` along axis 0 using the Fourier method.

    Unlike ``scipy.signal.resample``, this function accepts both
    ``numpy.ndarray`` and ``torch.Tensor`` input. Tensors are processed with
    ``torch.fft`` on the tensor's own device.

    Parameters
    ----------
    x : numpy.ndarray or torch.Tensor
        The data to be resampled, including subclasses (e.g. ``np.memmap``,
        ``torch.nn.Parameter``). Resampling is done along the first axis
        (axis 0); any trailing axes are treated as independent signals.
    rate : float
        Sampling ratio ``after_samples / before_samples``. The output has
        ``int(x.shape[0] * rate)`` samples along axis 0 (the product is
        truncated, not rounded).

    Returns
    -------
    y : numpy.ndarray or torch.Tensor
        The resampled data, of the same container type as ``x``.
        For an ndarray this is ``scipy.signal.resample(x, num)`` (real
        output for real input). For a tensor it is the complex output of
        ``torch.fft.ifft``; use ``y.real`` when a real signal is needed.

    Raises
    ------
    AttributeError
        If ``x`` is neither a ``torch.Tensor`` nor a ``numpy.ndarray``
        (message: 'Data type is not supported').
    ValueError
        If ``int(x.shape[0] * rate)`` is less than 1.

    Notes
    -----
    Because a Fourier method is used, the signal is assumed to be periodic.
    The first output sample lies at the same position as the first input
    sample, and the sample spacing changes from ``dx`` to
    ``dx * len(x) / num``. FFTs can be very slow when the number of input
    or output samples is large and prime (see ``scipy.fft.fft``).

    The tensor implementation is ported from the complex-input path of
    ``scipy.signal.resample``.
    """
    if isinstance(x, torch.Tensor):
        resampler = _resample_tensor
    elif isinstance(x, np.ndarray):
        resampler = signal.resample         # defaults to axis=0
    else:
        raise AttributeError('Data type is not supported')

    num = x.shape[0]                        # number of samples before resampling
    out_num = int(num * rate)               # number of samples after resampling
    if out_num < 1:
        raise ValueError(
            f"resampling {num} samples by rate {rate} yields {out_num} samples; "
            "at least 1 is required"
        )
    return resampler(x, out_num)

