import torch
import numpy as np

def nonlinearity_scalar(sigin, nl_gamma, dz, data_mode='numpy'):
    """Calculate the nonlinear effect of the single-polarization signal.

    This function calculates the nonlinear effect in the propagation of the
    single-polarization signal by solving the nonlinear part of the NLSE for
    one step. Each sample is rotated in phase by an amount proportional to
    its instantaneous power::

        sig_out = sig * exp(1j * nl_gamma * |sig|**2 * dz)

    Because the exponent is purely imaginary (for real ``nl_gamma`` and
    ``dz``), only the phase changes and the magnitude of every sample is
    preserved. Both backends apply exactly the same formula.

    Parameters
    ----------
    sigin : list
        The transmitted signal; only ``sigin[0]`` is used.
    nl_gamma : float
        The nonlinear coefficient.
    dz : float
        The operation length of the nonlinear effect in the current step.
    data_mode : str, {'numpy', 'tensor'}, optional
        The data type of the signal. Default: 'numpy'.

    Returns
    -------
    sigout : list
        A one-element list holding the signal with nonlinearity applied
        in this step.

    Raises
    ------
    AttributeError
        When the value of data_mode is not legitimate.
    """
    sig_x = sigin[0]
    if data_mode == 'tensor':
        power = torch.abs(sig_x) ** 2
        # The 1j factor is required: without it the step would scale the
        # amplitude instead of rotating the phase, and would disagree with
        # the numpy branch below.
        nonlinear_sig_x = sig_x * torch.exp(1j * nl_gamma * power * dz)
    elif data_mode == 'numpy':
        power = np.abs(sig_x) ** 2
        nonlinear_sig_x = sig_x * np.exp(1j * nl_gamma * power * dz)
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
    nonlinear_sig = [nonlinear_sig_x]
    return nonlinear_sig
    
def linearity_scalar(sigin, dz, phase_factor_freq, data_mode='numpy'):
    """Calculate the dispersion effects of the single-polarization signal.

    This function calculates the dispersion effect, i.e. GVD, of the
    single-polarization signal for one step. The signal is transformed to
    the frequency domain, centred with ``fftshift``, multiplied by
    ``exp(-1j * phase_factor_freq * dz)``, un-shifted with ``ifftshift`` and
    transformed back to the time domain.

    Parameters
    ----------
    sigin : list
        The transmitted signal; only ``sigin[0]`` is used.
    dz : float
        The length of the dispersion needed to be calculated this time.
    phase_factor_freq : array_like
        The frequency-domain phase factor of the dispersion operator. It is
        multiplied with the ``fftshift``-ed spectrum, so it must be ordered
        with zero frequency centred. Lists, NumPy arrays and (in 'tensor'
        mode) tensors are accepted; in 'tensor' mode the factor is placed on
        the signal's device.
    data_mode : str, {'numpy', 'tensor'}, optional
        The data type of the signal. Default: 'numpy'.

    Returns
    -------
    sigout : list
        A one-element list holding the signal with dispersion applied in
        this step.

    Raises
    ------
    AttributeError
        When the value of data_mode is not legitimate.
    """
    if data_mode == 'tensor':
        sig_fft_x = torch.fft.fftshift(torch.fft.fft(sigin[0]))
        # as_tensor is a no-op for a tensor already on the right device, and
        # lets list / ndarray factors work (torch.exp rejects ndarrays).
        phase = torch.as_tensor(phase_factor_freq, device=sig_fft_x.device)
        sig_fft_x = torch.exp(-1j * phase * dz) * sig_fft_x
        sig_x = torch.fft.ifft(torch.fft.ifftshift(sig_fft_x))
    elif data_mode == 'numpy':
        sig_fft_x = np.fft.fftshift(np.fft.fft(sigin[0]))
        # asarray is a no-op for arrays; a plain list would otherwise make
        # `-1j * list` raise TypeError.
        phase = np.asarray(phase_factor_freq)
        sig_fft_x = np.exp(-1j * phase * dz) * sig_fft_x
        sig_x = np.fft.ifft(np.fft.ifftshift(sig_fft_x))
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
    sigout = [sig_x]
    return sigout
    
