import torch
import numpy as np


def diff(x, axis=-1, data_mode='tensor'):
    '''
    Calculate the first discrete difference along the given axis.

    The first difference is given by ``out[n] = a[n+1] - a[n]`` along
    the given axis, higher differences can be calculated by using `diff`
    recursively.

    Parameters:
    x: array_like
        Input data: a numpy array when ``data_mode == 'numpy'``, otherwise
        a torch tensor.
    axis: int, optional
        Axis along which the discrete difference is computed. The default
        is -1 (the last axis).
    data_mode: str, optional
        Input data format, ``'numpy'`` or ``'tensor'``. Any value other than
        ``'numpy'`` is treated as ``'tensor'``.

    Returns:
    d: array_like
        numpy mode: ``np.diff(x, axis=axis)``.
        tensor mode: a tensor of shape ``(M, 1, L - 1)``, where ``L`` is
        ``x.shape[axis]`` and ``M`` is the product of all other dimensions.
        Row ``i`` holds the differences of the i-th 1-D slice along ``axis``
        (other dimensions keep their relative order).
    '''
    if data_mode == 'numpy':
        return np.diff(x, axis=axis)
    # Bring the requested axis to the end so every row is one 1-D signal;
    # a plain ``view(-1, 1, L)`` is only correct when ``axis`` is the last one.
    moved = torch.movedim(x, axis, -1)
    # Slice-and-subtract instead of conv1d with a [-1, 1] kernel: works for
    # any dtype/device (a hard-coded float32 kernel breaks float64 inputs)
    # and is not subject to cuDNN TF32 rounding on GPU.
    d = moved[..., 1:] - moved[..., :-1]
    # Preserve the conv1d-style (batch, 1, length - 1) output layout.
    return d.reshape(moved.shape[:-1].numel(), 1, d.shape[-1])

def unwrap(p, discont=np.pi, axis=-1):
    """
    Unwrap by changing deltas between values to 2*pi complement.

    Unwrap radian phase `p` by changing absolute jumps greater than
    `discont` to their 2*pi complement along the given axis.

    Parameters
    ----------
    p : numpy.ndarray or torch.Tensor
        Input phase in radians. A numpy array is passed to ``np.unwrap``;
        anything else is treated as a torch tensor and processed with torch
        operations on its own device (autograd is preserved).
    discont : float, optional
        Maximum discontinuity between values, default is ``pi``.
    axis : int, optional
        Axis along which unwrap will operate, default is the last axis.

    Returns
    -------
    up : numpy.ndarray or torch.Tensor
        Unwrapped phase with the same shape as `p`.

    See Also
    --------
    numpy.unwrap

    Notes
    -----
    If the discontinuity in `p` is smaller than ``pi``, but larger than
    `discont`, no unwrapping is done because taking the 2*pi complement
    would only make the discontinuity larger.

    Examples
    --------
    >>> phase = np.linspace(0, np.pi, num=5)
    >>> phase[3:] += np.pi
    >>> phase
    array([ 0.        ,  0.78539816,  1.57079633,  5.49778714,  6.28318531])
    >>> unwrap(phase)
    array([ 0.        ,  0.78539816,  1.57079633, -0.78539816,  0.        ])

    """
    if isinstance(p, np.ndarray):
        return np.unwrap(p, discont, axis)

    # Index tuples selecting p[1:] and p[:-1] along `axis`.
    nd = p.ndim
    slice1 = [slice(None)] * nd
    slice0 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice0[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice0 = tuple(slice0)

    # Same-rank first difference, so each batch row is unwrapped on its own
    # (flattening would make the cumulative sum leak across rows).
    dd = p[slice1] - p[slice0]
    # Wrap every step into [-pi, pi). torch.remainder follows np.mod's sign
    # convention (result has the sign of the divisor); torch.fmod does not.
    ddmod = torch.remainder(dd + np.pi, 2 * np.pi) - np.pi
    # A positive step of exactly pi (mod 2*pi) lands on -pi; map it back to
    # +pi so such jumps are kept, matching np.unwrap.
    ddmod = ddmod.masked_fill((ddmod == -np.pi) & (dd > 0), np.pi)
    ph_correct = ddmod - dd
    # Steps smaller than `discont` are left untouched.
    ph_correct = ph_correct.masked_fill(dd.abs() < discont, 0.0)

    up = p.to(ph_correct.dtype, copy=True)
    up[slice1] = p[slice1] + torch.cumsum(ph_correct, dim=axis)
    return up