import numpy as np
import torch 
from ..utils.tenosr_op import conv1d
from ..comm_tools.fftconv import conv
from scipy.ndimage import uniform_filter1d

def movmean(x, n, weight=None):
    '''
    Moving mean of a 1-D array or tensor with endpoint truncation.

    Returns ``out`` with the same length as ``x``, where ``out[i]`` is the mean of the
    window ``x[i-a : i+b+1]``. The window has ``a = (n-1)//2`` elements before position
    ``i`` and ``b = n//2`` elements after it:

    * odd ``n``: the window is centred on the current element;
    * even ``n``: the window is centred about the current and *next* elements
      (``n/2 - 1`` before, ``n/2`` after).
      NOTE(review): MATLAB's ``movmean`` centres even windows about the current and
      *previous* elements instead. Confirm which convention callers expect.

    Near the ends, the window is truncated to the elements that exist, and the mean is
    taken over only those elements. Inputs shorter than ``n`` are handled the same way,
    with every window truncated.

    Parameters:
    x: numpy.ndarray or torch.Tensor
        1-D input.
    n: int
        Sliding window length. Must be >= 1.
    weight : array_like, optional
        Kernel used for positions whose full window fits inside ``x``.
        For numpy input it is passed to ``np.convolve(x, weight, 'valid')``; for torch
        input it is passed to ``conv1d(x, weight)`` (presumably also a 'valid'
        convolution; TODO confirm). Defaults to a uniform kernel of ``1/n``.
        Truncated edge windows always use a plain, unweighted mean, and ``weight`` is
        unused when ``len(x) < n``.

    Returns:
    out: numpy.ndarray or torch.Tensor
        Moving means, with the same length as ``x``.

    Raises:
    ValueError
        If ``n < 1``.
    '''
    if n < 1:
        raise ValueError(f"window length n must be >= 1, got {n}")
    before = (n - 1) // 2  # elements to the left of the current position
    after = n // 2         # elements to the right of the current position
    length = x.shape[0]
    if length < n:
        # No complete window exists, so every position is an edge case.
        return _movmean_truncated(x, before, after)

    # The first `before` outputs and the last `after` outputs use truncated windows,
    # whose sizes grow from n - before (resp. n - after) up to n - 1.
    # Explicit start indices are used instead of `[-k:]`, so that the slices stay
    # empty when k == 0 (n == 1 or n == 2).
    tail_start = length - (n - 1)  # equals `length` when n == 1, giving an empty tail
    if isinstance(x, np.ndarray):
        denom_a = np.arange(before) + (n - before)
        denom_b = np.arange(after) + (n - after)
        if weight is None:
            weight = np.ones(n) / n
        y2 = np.convolve(x, weight, 'valid')
        start = np.cumsum(x[:n - 1])[n - 1 - before:] / denom_a
        tail = x[tail_start:][::-1]
        stop = (np.cumsum(tail)[n - 1 - after:] / denom_b)[::-1]
        return np.concatenate((start, y2, stop))

    denom_a = torch.arange(before, device=x.device) + (n - before)
    denom_b = torch.arange(after, device=x.device) + (n - after)
    if weight is None:
        weight = torch.ones(n, device=x.device) / n
    y2 = conv1d(x, weight)
    start = torch.cumsum(x[:n - 1], dim=0)[n - 1 - before:] / denom_a
    # Torch slicing does not support negative steps, so reverse with torch.flip.
    tail = torch.flip(x[tail_start:], dims=(0,))
    stop = torch.flip(torch.cumsum(tail, dim=0)[n - 1 - after:] / denom_b, dims=(0,))
    return torch.cat((start, y2, stop))


def _movmean_truncated(x, before, after):
    '''Return the mean of the clipped window x[i-before : i+after+1] for every index of 1-D x.

    This is a direct per-position computation. It is only used by ``movmean`` when
    ``len(x) < n``, where no full window exists.
    '''
    length = x.shape[0]
    xf = x * 1.0  # promote integer input so `.mean()` works for both numpy and torch
    if length == 0:
        return xf
    means = [xf[max(0, i - before):min(length, i + after + 1)].mean() for i in range(length)]
    if isinstance(x, np.ndarray):
        return np.array(means, dtype=xf.dtype)
    return torch.stack(means)


def movmean_cycle(x, n):
    '''
    Loop-based moving mean along the first axis (axis 0) of ``x``.

    For every index ``i`` along axis 0, ``out[i, ...]`` is the mean over axis 0 of
    ``x[i-a : i+b+1, ...]``. The window has ``a = (n-1)//2`` elements before ``i`` and
    ``b = n//2`` elements after it, so even windows have ``n/2 - 1`` elements before and
    ``n/2`` after. Near the ends, the window is clipped to the valid index range, and the
    mean is taken over only the elements that remain. The window does NOT wrap around to
    the other end of the array.

    This is an explicit Python loop that costs O(len(x) * n), but it works for arrays or
    tensors of any rank.

    Parameters:
    x: numpy.ndarray or torch.Tensor
        Input. The moving mean is taken along axis 0.
    n: int
        Sliding window length. Must be >= 1.

    Returns:
    out: numpy.ndarray or torch.Tensor
        Floating-point result with the same shape as ``x``.

    Raises:
    ValueError
        If ``n < 1``.
    '''
    if n < 1:
        raise ValueError(f"window length n must be >= 1, got {n}")
    is_numpy = isinstance(x, np.ndarray)
    # `+ 0.0` promotes integer inputs to a floating dtype for the output buffer.
    out = (np.zeros_like(x) if is_numpy else torch.zeros_like(x)) + 0.0

    before = (n - 1) // 2   # elements to the left of position i
    after = n - 1 - before  # elements to the right of position i
    length = x.shape[0]
    for i in range(length):
        lo = max(0, i - before)
        hi = min(length, i + after + 1)
        window = x[lo:hi, ...]
        if is_numpy:
            out[i, ...] = np.mean(window, axis=0)
        else:
            # torch.mean rejects integer tensors, so compute in the output's float dtype.
            out[i, ...] = torch.mean(window.to(out.dtype), dim=0)
    return out


def movmean_fast(x, n, weight=None):
    '''
    Vectorised moving mean along the last axis of ``x`` with endpoint truncation.

    For every index ``i`` along the last axis, ``out[..., i]`` is the mean of
    ``x[..., i-a : i+b+1]``. The window has ``a = (n-1)//2`` elements before ``i`` and
    ``b = n//2`` elements after it:

    * odd ``n``: the window is centred on the current element;
    * even ``n``: the window is centred about the current and *next* elements.

    The first ``a`` and last ``b`` positions use truncated windows, averaged over only
    the elements that exist. These edge values come from cumulative sums. The interior
    positions come from ``conv`` (``comm_tools.fftconv``, presumably an FFT-based
    convolution; TODO confirm) in 'valid' mode along the last axis.

    Parameters:
    x: numpy.ndarray or torch.Tensor
        Input of any rank. The moving mean is taken along the last axis.
        Assumes ``x.shape[-1] >= n``; shorter inputs are not handled here.
    n: int
        Sliding window length. Must be >= 1.
    weight : array_like, optional
        Kernel passed to ``conv`` for the interior positions. Defaults to a uniform
        kernel of ``1/n`` with shape ``x.shape[:-1] + (n,)``. Truncated edge windows
        always use a plain, unweighted mean.

    Returns:
    out: numpy.ndarray or torch.Tensor
        Moving means, with the same shape as ``x``.

    Raises:
    ValueError
        If ``n < 1``.
    '''
    if n < 1:
        raise ValueError(f"window length n must be >= 1, got {n}")
    before = (n - 1) // 2  # elements to the left of the current position
    after = n // 2         # elements to the right of the current position
    length = x.shape[-1]
    # First index of the last n-1 samples. It equals `length` when n == 1, which makes
    # the tail empty; the original `[..., -n+1:]` selected the whole axis in that case.
    tail_start = length - (n - 1)
    # Explicit start indices are used instead of `[..., -k:]`, so that the slices stay
    # empty when k == 0.
    if isinstance(x, np.ndarray):
        denom_a = np.arange(before) + (n - before)
        denom_b = np.arange(after) + (n - after)
        if weight is None:
            weight = np.ones(tuple(x.shape[:-1]) + (n,)) / n
        y2 = conv(x, weight, 'valid', axis = -1)
        start = np.cumsum(x[..., :n - 1], axis=-1)[..., n - 1 - before:] / denom_a
        tail = x[..., tail_start:][..., ::-1]
        stop = (np.cumsum(tail, axis=-1)[..., n - 1 - after:] / denom_b)[..., ::-1]
        return np.concatenate((start, y2, stop), axis=-1)

    denom_a = torch.arange(before, device=x.device) + (n - before)
    denom_b = torch.arange(after, device=x.device) + (n - after)
    if weight is None:
        weight = torch.ones(tuple(x.shape[:-1]) + (n,), device=x.device) / n
    y2 = conv(x, weight, 'valid', axis = -1)
    start = torch.cumsum(x[..., :n - 1], dim=-1)[..., n - 1 - before:] / denom_a
    # Torch slicing does not support negative steps, so reverse with torch.flip.
    tail = torch.flip(x[..., tail_start:], dims=(-1,))
    stop = torch.flip(torch.cumsum(tail, dim=-1)[..., n - 1 - after:] / denom_b, dims=(-1,))
    return torch.cat((start, y2, stop), dim=-1)
