import torch
import numpy as np
from scipy.fft import next_fast_len


def _centered(arr, newshape):
    """
    Return the central ``newshape`` window of `arr`.

    Used by the convolution helpers to crop the zero-padded 'full' result
    down to the requested size. Along each axis ``k`` the window starts at
    ``(arr.shape[k] - newshape[k]) // 2``, so when the margin being removed
    is odd, the extra element is dropped from the end rather than the start.
    For an odd-length axis cropped to an even length, this places the centre
    element of `arr` (index ``arr.shape[k] // 2``) at index
    ``newshape[k] // 2`` of the result. That is the position
    `np.fft.fftshift` treats as the centre.

    NOTE(review): an earlier version of this docstring said the behaviour
    differs from SciPy's `_centered`. SciPy's helper appears to use the same
    ``(curr - new) // 2`` start index, so confirm before relying on that.

    Parameters
    ----------
    arr : array_like
        Array to crop. Must expose ``.shape`` and accept a tuple of slices
        as an index (e.g. a NumPy array or torch tensor).
    newshape : int or sequence of ints
        Target shape. A single integer is applied to every axis.

    Returns
    -------
    new_array : array_like
        ``arr`` indexed with one basic slice per axis.
    """
    target = np.asarray(newshape)
    current = np.array(arr.shape)
    starts = (current - target) // 2
    stops = starts + target
    window = tuple(slice(lo, hi) for lo, hi in zip(starts, stops))
    return arr[window]

def _freq_domain_conv(in1, in2, axis, shape, calc_fast_len=False):
    """
    Convolve two arrays along a single axis in the frequency domain.

    Only the FFT-related work happens here: both inputs are transformed
    along `axis`, their spectra are multiplied, and the product is
    transformed back. Shape bookkeeping and the convolution mode are handled
    by the callers (`conv` / `_apply_conv_mode`).

    NumPy arrays (including ndarray subclasses) use `np.fft`. Anything else
    is treated as a torch tensor and uses `torch.fft`. If either input is
    complex, the full complex FFT pair is used. Otherwise the real
    ``rfft``/``irfft`` pair is used and the result is real.

    Parameters
    ----------
    in1 : numpy.ndarray or torch.Tensor
        First input.
    in2 : numpy.ndarray or torch.Tensor
        Second input, from the same library as `in1` and with the same
        number of dimensions.
    axis : int
        Axis over which to compute the FFTs.
    shape : sequence of ints
        Shape of the desired result. ``shape[axis]`` is the FFT length,
        which must be at least ``n1 + n2 - 1`` to get a linear (not
        circular) convolution. When `calc_fast_len` is True, every entry is
        also used to crop the result.
    calc_fast_len : bool, optional
        If True, pad the FFT length up to ``scipy.fft.next_fast_len`` of
        ``shape[axis]``, then crop the result back to `shape`. Default is
        False, which uses ``shape[axis]`` as-is.

    Returns
    -------
    ret : numpy.ndarray or torch.Tensor
        The inverse FFT of the spectral product along `axis`, from the same
        library as `in1`.
    """
    if isinstance(in1, np.ndarray):
        # isinstance (not an exact type comparison) so ndarray subclasses take
        # the NumPy path instead of being handed to torch.
        complex_result = in1.dtype.kind == 'c' or in2.dtype.kind == 'c'
        if complex_result:
            fft, ifft = np.fft.fft, np.fft.ifft
        else:
            fft, ifft = np.fft.rfft, np.fft.irfft
    else:
        complex_result = torch.is_complex(in1) or torch.is_complex(in2)
        if complex_result:
            fft, ifft = torch.fft.fft, torch.fft.ifft
        else:
            fft, ifft = torch.fft.rfft, torch.fft.irfft

    # Plain Python int: `shape` may hold NumPy integers, which torch's
    # `n` argument is not guaranteed to accept.
    n_fft = int(shape[axis])
    if calc_fast_len:
        # Speed up the FFT by padding to an efficient length; the real
        # transform has its own set of fast lengths.
        n_fft = next_fast_len(n_fft, not complex_result)

    sp1 = fft(in1, n_fft, axis)
    sp2 = fft(in2, n_fft, axis)
    ret = ifft(sp1 * sp2, n_fft, axis)

    if calc_fast_len:
        # Drop the extra padding. Note that this crops every axis to `shape`,
        # not just `axis`.
        ret = ret[tuple(slice(int(size)) for size in shape)]

    return ret

def _apply_conv_mode(ret, s1, s2, mode, axes):
    """
    Crop a 'full' convolution result to the size implied by `mode`.

    Parameters
    ----------
    ret : array
        The convolution result with the 'full' output shape.
    s1 : sequence of int
        Shape of the first operand. For 'same' it is the target shape of
        the whole result. For 'valid' only its entries on `axes` are used.
    s2 : sequence of int
        Shape of the second operand. Only used by 'valid', on `axes`.
    mode : str {'full', 'valid', 'same'}
        Output size selector; see `conv` for details.
    axes : list of ints
        Non-negative axes over which the convolution was computed.

    Returns
    -------
    ret : array
        `ret` itself for 'full'. For 'same' and 'valid', a centred slice of
        `ret` produced by `_centered`.

    Raises
    ------
    ValueError
        If `mode` is not one of 'full', 'same', 'valid'.
    """
    if mode == "full":
        return ret
    if mode == "same":
        return _centered(ret, s1)
    if mode != "valid":
        raise ValueError("acceptable mode flags are 'valid',"
                         " 'same', or 'full'")

    # Convolution axes shrink to the fully-overlapping length; all other
    # axes keep whatever extent the full result already has.
    valid_shape = []
    for dim, extent in enumerate(ret.shape):
        valid_shape.append(s1[dim] - s2[dim] + 1 if dim in axes else extent)
    return _centered(ret, valid_shape)


def conv(inputs, kernel, mode="full", axis=-1):
    """
    Convolve two N-dimensional arrays along a single axis using the FFT.

    Works on NumPy arrays or torch tensors; both inputs must come from the
    same library. The FFT length is padded to a fast size with
    `scipy.fft.next_fast_len`, and the result is cropped back afterwards.

    Along `axis`, the longer input is always used as the first operand.
    Convolution is commutative, so this only matters for the output length
    in the 'same' and 'valid' modes. Those lengths match the convention of
    `numpy.convolve` (based on the longer input), not the `in1`-based
    convention of `scipy.signal.fftconvolve`. Along every other axis, the
    inputs are broadcast against each other: each pair of extents must be
    equal, or one of them must be 1.

    Parameters
    ----------
    inputs : numpy.ndarray or torch.Tensor
        First input.
    kernel : numpy.ndarray or torch.Tensor
        Second input. Must have the same number of dimensions as `inputs`.
    mode : str {'full', 'valid', 'same'}, optional
        Size of the output along `axis`. Let ``N`` and ``M`` be the longer
        and shorter input lengths along `axis`:

        ``full``
           The full discrete linear convolution, of length ``N + M - 1``.
           (Default)
        ``valid``
           Only the elements that do not rely on zero-padding, of length
           ``N - M + 1``.
        ``same``
           Length ``N``, centred with respect to the 'full' output.
    axis : int, optional
        Axis over which to compute the convolution. Negative values count
        from the last axis. Default is -1.

    Returns
    -------
    out : numpy.ndarray or torch.Tensor
        The (cropped) linear convolution along `axis`. It is real when both
        inputs are real and complex otherwise. Special cases:

        - If both inputs are 0-d, returns ``inputs * kernel``.
        - If either input is empty, returns an empty 1-D ``np.array([])``
          for NumPy inputs, or an empty tensor for torch inputs.

    Raises
    ------
    ValueError
        If the inputs have different numbers of dimensions, or `mode` is not
        one of 'full', 'valid', 'same'.
    IndexError
        If `axis` is out of range for the inputs.

    Examples
    --------
    >>> import numpy as np
    >>> out = conv(np.array([1.0, 2.0, 3.0]), np.array([0.0, 1.0, 0.5]))
    >>> np.allclose(out, [0.0, 1.0, 2.5, 4.0, 1.5])
    True
    >>> conv(np.ones(5), np.ones(3), mode="same").shape
    (5,)
    >>> conv(np.ones(3), np.ones(5), mode="valid").shape
    (3,)
    >>> conv(np.ones((4, 3)), np.ones((1, 10))).shape
    (4, 12)
    """
    if inputs.ndim == kernel.ndim == 0:  # scalar inputs
        return inputs * kernel
    if inputs.ndim != kernel.ndim:
        raise ValueError("in1 and in2 should have the same dimensionality")
    # Test the shapes rather than `.size`: on torch tensors `.size` is a
    # method, so `tensor.size == 0` is always False.
    if 0 in tuple(inputs.shape) or 0 in tuple(kernel.shape):  # empty arrays
        if isinstance(inputs, torch.Tensor):
            return torch.empty(0, dtype=inputs.dtype, device=inputs.device)
        return np.array([])

    # Normalise every negative axis (not just -1). `_apply_conv_mode` matches
    # axes against non-negative dimension indices, so e.g. axis=-2 would
    # otherwise never be cropped in 'valid' mode.
    ndim = inputs.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for array of dimension {ndim}")
    axis %= ndim

    # Put the longer operand (along `axis`) first, so the 'valid' length
    # N - M + 1 is never negative.
    if inputs.shape[axis] < kernel.shape[axis]:
        inputs, kernel = kernel, inputs

    s1 = tuple(inputs.shape)
    s2 = tuple(kernel.shape)
    # The spectral product broadcasts the non-convolution axes, so size them
    # with the larger extent. Using `s1` alone would silently crop away
    # broadcast rows whenever the second operand has the bigger extent.
    shape = [max(n1, n2) for n1, n2 in zip(s1, s2)]
    shape[axis] = s1[axis] + s2[axis] - 1

    ret = _freq_domain_conv(inputs, kernel, axis, shape, calc_fast_len=True)

    # 'same' mode crops every axis to the shape it is given. Keep the
    # broadcast extents so that only `axis` is cropped (to the longer input).
    mode_shape = list(shape)
    mode_shape[axis] = s1[axis]
    return _apply_conv_mode(ret, mode_shape, s2, mode, [axis])