import torch
import numpy as np
from numpy import sign

def nonlinearity_matrix(sigin, nl_gamma, dz, Manakov = 1, data_mode = 'numpy'):
    """Calculate the nonlinear effect of the dual-polarization signal.

    This function solves the nonlinear step of the split-step propagation of a
    dual-polarization signal, using either the Manakov equation or the coupled
    NLSE (CNLSE).

    Both models first apply the common self-phase rotation
    ``exp(-1j * nl_gamma * (|x|^2 + |y|^2) * dz)`` to each polarization.

    For the CNLSE (``Manakov`` falsy) the nonlinear operator solved is::

        dx/dz = -1j*nl_gamma*[(|x|^2 + 2/3|y|^2) x + 1/3 y^2 conj(x)]
        dy/dz = -1j*nl_gamma*[(|y|^2 + 2/3|x|^2) y + 1/3 x^2 conj(y)]

    Besides the common phase above, this contains a real rotation of the
    (x, y) pair by the angle ``nl_gamma * s3 * dz / 3``. Here
    ``s3 = 2 * Im(conj(x) * y)`` is constant along the step, so the rotation
    is applied exactly.

    Parameters
    ----------
    sigin : list
       The transmitted signal ``[x, y]``. The two polarization components are
       NumPy arrays (``data_mode='numpy'``) or torch tensors
       (``data_mode='tensor'``).
    nl_gamma : float
       The nonlinear coefficient.
    dz : float
       The effective operating length of the nonlinear effect in this step.
       Note that the presence of the fiber loss has been considered.
    Manakov : int, {0, 1}, optional
       If truthy, solve the Manakov equation (SPM phase only). Otherwise solve
       the CNLSE. Default: 1.
    data_mode : str, {'numpy', 'tensor'}, optional
       The data type of the signal. Default: 'numpy'.

    Returns
    -------
    sigout : list
       The signal ``[x, y]`` with nonlinearity applied in this step, in the
       same backend as the input.

    Raises
    ------
    AttributeError
       When the value of data_mode is not 'numpy' or 'tensor'.
    """
    sig_x = sigin[0]
    sig_y = sigin[1]
    if data_mode == 'tensor':
        xp = torch
    elif data_mode == 'numpy':
        xp = np
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")

    # Common (polarization-independent) nonlinear phase.
    power = xp.abs(sig_x) ** 2 + xp.abs(sig_y) ** 2
    spm = xp.exp(-1j * (nl_gamma * power * dz))
    nonlinear_act_x = sig_x * spm
    nonlinear_act_y = sig_y * spm

    if not Manakov:
        # Solve the remaining CNLSE coupling term (does not support xpm).
        # s3 = 2*Im(conj(x)*y). The common phase above leaves it unchanged.
        s3 = 2 * (nonlinear_act_x.real * nonlinear_act_y.imag
                  - nonlinear_act_x.imag * nonlinear_act_y.real)
        theta = nl_gamma * s3 * dz / 3
        cosphi = xp.cos(theta)
        sinphi = xp.sin(theta)
        # Tuple assignment: both outputs must use the pre-rotation x and y,
        # otherwise the rotation is not orthogonal and power is not conserved.
        nonlinear_act_x, nonlinear_act_y = (
            cosphi * nonlinear_act_x + sinphi * nonlinear_act_y,
            -sinphi * nonlinear_act_x + cosphi * nonlinear_act_y,
        )

    sigout = [nonlinear_act_x, nonlinear_act_y]
    return sigout
    
def _psp_rotation_matrix(theta, phi):
    """Return the 2x2 complex unitary change-of-basis matrix R for one PMD trunk.

    ``R = mat_theta @ mat_epsilon``. ``mat_theta`` is the real rotation
    ``[[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]``.
    ``mat_epsilon = cos(phi) I + 1j sin(phi) sig2``, which equals
    ``expm(1j * phi * sig2)`` because ``sig2 @ sig2 = I``.
    """
    sig0 = np.eye(2)
    sig2 = np.array([[0, 1], [1, 0]])
    sig3i = np.array([[0, 1], [-1, 0]])  # real antisymmetric generator of mat_theta
    mat_theta = np.cos(theta) * sig0 - np.sin(theta) * sig3i + 0.0j
    mat_epsilon = np.cos(phi) * sig0 + 1j * np.sin(phi) * sig2
    return mat_theta @ mat_epsilon


def linearity_matrix(sigin, dz, phase_factor_freq, pmd, differential_factor_freq= None,data_mode = 'numpy', **kwargs):
    """Calculate the dispersion effects of the dual-polarization signal.

    This function calculates the dispersion effects of the dual-polarization
    signal, which include GVD and optionally PMD. Both polarizations are
    Fourier transformed and the spectrum is centred with ``fftshift``. The
    dispersion phase is then applied element-wise in the frequency domain, and
    the result is ``ifftshift``-ed and inverse transformed.

    With PMD enabled, the step is split into trunks. Each trunk applies
    ``R @ D @ R^H`` to the spectrum, where ``R`` is the trunk's PSP
    change-of-basis matrix and ``D`` is diagonal, holding the common GVD
    phase and the differential (DGD) phase.

    Parameters
    ----------
    sigin : list
       The transmitted signal ``[x, y]``. The components are NumPy arrays
       (``data_mode='numpy'``) or torch tensors (``data_mode='tensor'``).
    dz : float
       The length of the dispersion needed to be calculated this time.
    phase_factor_freq : array or tensor
       Dispersion phase per unit length for each frequency bin. It multiplies
       the fftshift-ed spectrum element-wise, so it must use that (centred)
       ordering.
    pmd : int, {0, 1}
       Determine if PMD is considered in simulation.
    differential_factor_freq : array or tensor, optional
       Per-bin differential (DGD) phase factor. It is scaled by
       ``sign(pmd_dz) * sqrt(|pmd_dz|)`` in every trunk. Required when
       ``pmd`` is truthy.
    data_mode : str, {'numpy', 'tensor'}, optional
       The data type of the signal. Default: 'numpy'.
    **kwargs : dict
       Parameters of the PMD simulation, read only when ``pmd`` is truthy:

       ``trunk_list``
          The length stepped in each trunk, in sequence.
       ``trunk_idx_list``
          The index into ``psp_theta``/``psp_phi`` for each entry of
          ``trunk_list``.
       ``psp_theta``, ``psp_phi``
          The angles defining each trunk's PSP rotation matrix.

       Other keys are ignored.

    Returns
    -------
    sigout : list
       The signal ``[x, y]`` with dispersion applied in this step, in the same
       backend as the input.

    Raises
    ------
    AttributeError
       When the value of data_mode is not 'numpy' or 'tensor'.
    RuntimeError
       When the lengths in ``trunk_list`` do not sum to ``dz``. The comparison
       uses a relative tolerance of 1e-9 and an absolute tolerance of 1e-12.
    """
    if data_mode == 'tensor':
        xp = torch
    elif data_mode == 'numpy':
        xp = np
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")

    # Centred spectrum: phase_factor_freq is laid out in fftshift-ed order.
    sig_fft_x = xp.fft.fftshift(xp.fft.fft(sigin[0]))
    sig_fft_y = xp.fft.fftshift(xp.fft.fft(sigin[1]))

    if pmd:
        # GVD + PMD: pass the spectrum through the PMD trunks in sequence.
        trunk_idx_list = kwargs.get('trunk_idx_list')
        trunk_list = kwargs.get('trunk_list')
        psp_theta = kwargs.get('psp_theta')
        psp_phi = kwargs.get('psp_phi')
        prop_dz = 0.0
        for i, pmd_dz in enumerate(trunk_list):
            trunk_idx = trunk_idx_list[i]
            mat_rot = _psp_rotation_matrix(psp_theta[trunk_idx], psp_phi[trunk_idx])
            if xp is torch:
                mat_rot = torch.from_numpy(mat_rot).to(sig_fft_x.device)
            mat_rot_conj = xp.conj(mat_rot)
            # 1> move onto the PSP basis: [uux; uuy] = R^H @ [x; y]
            uux = mat_rot_conj[0, 0] * sig_fft_x + mat_rot_conj[1, 0] * sig_fft_y
            uuy = mat_rot_conj[0, 1] * sig_fft_x + mat_rot_conj[1, 1] * sig_fft_y
            # 2> apply GVD (common) and DGD (+/- on the two PSPs) as one diagonal step
            gvd_beta = phase_factor_freq * pmd_dz
            pmd_scale = float(np.sign(pmd_dz) * np.sqrt(abs(pmd_dz)))
            pmd_beta = pmd_scale * differential_factor_freq
            uux = xp.exp(-1j * (gvd_beta + pmd_beta)) * uux
            uuy = xp.exp(-1j * (gvd_beta - pmd_beta)) * uuy
            # 3> come back to the original basis: [x; y] = R @ [uux; uuy]
            sig_fft_x = mat_rot[0, 0] * uux + mat_rot[0, 1] * uuy
            sig_fft_y = mat_rot[1, 0] * uux + mat_rot[1, 1] * uuy
            prop_dz += pmd_dz
        # The trunk lengths are summed in floating point (e.g. 0.1 + 0.2 != 0.3),
        # so an exact comparison would reject valid partitions of dz.
        if not np.isclose(prop_dz, dz, rtol=1e-9, atol=1e-12):
            raise RuntimeError('Total PMD dz does not equal to SSFM dz')
    else:
        # GVD only
        phase = xp.exp(-1j * phase_factor_freq * dz)
        sig_fft_x = phase * sig_fft_x
        sig_fft_y = phase * sig_fft_y

    # Undo the centring before the inverse transform (both backends).
    sig_x = xp.fft.ifft(xp.fft.ifftshift(sig_fft_x))
    sig_y = xp.fft.ifft(xp.fft.ifftshift(sig_fft_y))
    sigout = [sig_x, sig_y]
    return sigout