from matplotlib.pyplot import xlim
import torch
import numpy as np
from scipy.stats import maxwell
from ......fiber_simulation.utils.show_progress import progress_info_return
from .....base.base_optics import Optics_Base_Module
from .step_size import Step_Size
from .ssfm_matrix import linearity_matrix, nonlinearity_matrix
from .ssfm_scalar import linearity_scalar, nonlinearity_scalar
from ..... utils.define_freq import calcu_f

class SSFM(Optics_Base_Module):
    """
    Implement the split-step Fourier method (SSFM) algorithm for fiber simulation.

    This class uses the split-step Fourier method (SSFM) to solve the nonlinear
    Schrödinger equation (NLSE) and thereby simulate transmission over an optical
    fiber channel. Its attributes hold the parameters of the different channel
    effects, and its methods calculate step sizes and solve the linear and
    nonlinear parts of the equation. Each call of an SSFM object propagates the
    signal through one span.
      
    """
    def __init__(self, len_arr, beta0, beta1, beta2, beta3, pmd, alpha_loss, gamma,
                 fft_num, sam_rate, infor_print=1, data_mode='numpy', *args, **kwargs):
        """Record the basic fiber and sampling parameters.

        Configuration is split in two stages. This constructor only stores the
        basic parameters (sample rate, dispersion coefficients, loss, ...);
        the separate ``init`` method then derives the frequency grid and the
        dispersion operator and stores the step-size / PMD configuration.

        Parameters
        ----------
        len_arr : ndarray
            The length of each span (km).
        beta0, beta1, beta2, beta3 : float
            Zero- to third-order Taylor coefficients of the propagation
            constant. beta2 is the GVD parameter (ps^2/km); the other
            coefficients presumably use ps^n/km -- confirm against the
            frequency units produced by ``calcu_f``.
        pmd : int, {0, 1}
            Whether polarization-mode dispersion (PMD) is simulated.
        alpha_loss : float
            The fiber loss parameter (Np/km).
        gamma : float
            The nonlinear coefficient (km^-1 W^-1).
        fft_num : int
            Number of FFT points, normally the length of the input sequence.
        sam_rate : int
            The channel sample rate (channel count * symbol rate * 4).
        infor_print : int, {0, 1}, optional
            Whether to print the progress of the algorithm. Default: 1.
        data_mode : str, {'numpy', 'tensor'}, optional
            Data type used in the computation. Default: 'numpy'.
        **kwargs : dict
            Optional extras. 'device' (default 'cpu') and 'constant_pi'
            (default np.pi) are read here; any other key (e.g. 'constant_h')
            is accepted but not used by this constructor.
        """
        super().__init__()
        # Span geometry and the Taylor coefficients of the dispersion relation.
        self.len_arr = len_arr
        self.beta0, self.beta1, self.beta2, self.beta3 = beta0, beta1, beta2, beta3
        # PMD switch, attenuation and Kerr nonlinearity.
        self.pmd = pmd
        self.alpha_loss = alpha_loss
        self.gamma = gamma
        # Sampling grid.
        self.sam_rate = sam_rate
        self.fft_num = fft_num
        # Runtime options.
        self.infor_print = infor_print
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.constant_pi = kwargs.get('constant_pi', np.pi)

        
    def init(self, Manakov, step_config, pmd_config, *args, **kwargs):
        """Finish configuring the object from the basic parameters.

        Computes the angular-frequency grid and the frequency-domain dispersion
        operator, stores the step-size configuration and, when PMD is enabled,
        reads the PMD settings from ``pmd_config``.

        Parameters
        ----------
        Manakov : int, {0, 1}
            Whether the Manakov equation is used. If 0, the coupled nonlinear
            Schrödinger equation (CNLSE) is used instead.
        step_config : dict
            Keyword arguments used to construct a Step_Size object. Refer to
            the description of class Step_Size for details.
        pmd_config : dict
            PMD simulation parameters; only read when ``self.pmd`` is truthy.
            Looked up via ``_config_check`` with the fallback values shown:
            'dgd_manual', 'psp_manual', 'pmd_coeff_random', 'pmd_dz_random'
            (all 0), 'dgd_total' (0.2, only when dgd_manual) and 'phi'
            (pi/4, only when psp_manual). 'pmd_coeff' is required unless
            dgd_manual is set, and 'pmd_trunk_num' is needed by
            ``__pmd_init__``. Every key is also copied onto the instance.
        **kwargs
            Extra attributes copied verbatim onto the instance.
        """
        # Angular-frequency grid. The 1e-3 factor presumably rescales sam_rate
        # into the unit expected by calcu_f (e.g. GHz -> THz) -- confirm.
        self.w = 2 * self.constant_pi * calcu_f(self.fft_num, self.sam_rate * 1e-3)
        # Fourier transform of the chromatic-dispersion operator: Taylor
        # expansion of beta(w) up to third order.
        phase_factor_freq = (
            self.beta0
            + self.beta1 * self.w
            + self.beta2 * self.w ** 2 / 2
            + self.beta3 * self.w ** 3 / 6
        )
        self.phase_factor_freq = self.data_mode_convert(phase_factor_freq)
        self.Manakov = Manakov
        self.step_config = step_config

        if self.pmd:
            self.pmd_config = pmd_config
            self.dgd_manual = self._config_check('dgd_manual', 0, pmd_config)
            self.psp_manual = self._config_check('psp_manual', 0, pmd_config)
            self.pmd_coeff_random = self._config_check('pmd_coeff_random', 0, pmd_config)
            self.pmd_dz_random = self._config_check('pmd_dz_random', 0, pmd_config)
            # Copy every PMD setting (e.g. pmd_coeff, pmd_trunk_num) onto the instance.
            self.__dict__.update(pmd_config)

            if self.dgd_manual:
                # Bind the result like the calls above so that self.dgd_total,
                # which __pmd_init__ reads, is defined even when the key is
                # missing from pmd_config.
                self.dgd_total = self._config_check('dgd_total', 0.2, pmd_config)
            else:
                # Presumably converts the mean-DGD coefficient into its RMS
                # value (Maxwellian DGD: <dt^2> = 3*pi/8 * <dt>^2).
                self.dgd_rms = np.sqrt(3 * self.constant_pi / 8 * self.pmd_coeff ** 2)

            if self.psp_manual:
                # The phase rotation of the SOP is fixed to this angle.
                self.phi = self._config_check('phi', self.constant_pi / 4, pmd_config)

        self.__dict__.update(kwargs)
    
    def forward_pass(self, sigin, span_len):
        """Propagate the signal through one span with the SSFM algorithm.

        Presumably invoked by the base class when the object is called.
        Dispatches to ``__ssfm_scalar__`` for a single polarization and to
        ``__ssfm_matrix__`` for dual polarization.

        Parameters
        ----------
        sigin : list
            The signal to be transmitted in this span, one entry per
            polarization (1 or 2 entries).
        span_len : float
            The length of the current span (km).

        Returns
        -------
        sigout : list
            The signal after one span of transmission.

        Raises
        ------
        ValueError
            If ``sigin`` does not contain exactly 1 or 2 polarizations.
        """
        # Stored on the instance because the step-size and PMD setup read them.
        self.span_len = span_len
        self.nPol = len(sigin)
        if self.nPol == 1:
            return self.__ssfm_scalar__(sigin)
        if self.nPol == 2:
            return self.__ssfm_matrix__(sigin)
        # Without this, an unsupported polarization count would surface as an
        # UnboundLocalError on the result variable.
        raise ValueError(
            f"SSFM supports 1 or 2 polarizations, got {self.nPol}")

    def __calcu_p__(self, x):
        """Calculate the average signal power in dBm.

        The per-sample power |x|^2 is summed over polarizations, averaged over
        samples and converted with ``10 * log10(p / 0.001)``. This is dBm
        assuming |x|^2 is expressed in W (confirm against the signal scaling).

        Parameters
        ----------
        x : list or ndarray
            The signal, one entry (row) per polarization. Entries may be NumPy
            arrays or torch tensors.

        Returns
        -------
        p : float
            The average signal power (dBm).
        """
        total = 0
        for pol in x:
            # Dispatch on each element rather than on the container, so a plain
            # list of NumPy arrays works as well as a list of tensors.
            if isinstance(pol, torch.Tensor):
                total = total + pol.abs() ** 2
            else:
                total = total + np.abs(pol) ** 2
        if isinstance(total, torch.Tensor):
            p = total.mean().item()
        else:
            p = np.mean(total)
        return 10 * np.log10(p / 0.001)

    def __step_init__(self):
        """Create the Step_Size object for the current span.

        The user-supplied step configuration is merged with the
        span-dependent arguments (polarization count, nonlinear coefficient,
        loss, span length, PMD settings) and passed to Step_Size. Refer to
        Step_Size.__init__ for details.

        The configuration is copied before the span-specific keys are added.
        A dict passed to ``init`` (and possibly shared with other SSFM
        objects) is therefore never modified in place.
        """
        config = dict(self.step_config)
        config['nPol'] = self.nPol
        config['nl_gamma'] = self.nl_gamma
        config['alpha_loss'] = self.alpha_loss
        config['span_len'] = self.span_len
        config['pmd'] = self.pmd
        if self.pmd:
            config['pmd_dz_arr'] = self.pmd_dz_arr
        # Keep the merged configuration visible on the instance, as before.
        self.step_config = config
        self.step_size = Step_Size(**config)

    def __ssfm_matrix__(self, sig):
        """Apply SSFM to a dual-polarization signal over one span.

        Solves the NLSE (Manakov or CNLSE, according to ``self.Manakov``) step
        by step until the accumulated step sizes reach the span length. Each
        step applies the linear operator, then the nonlinear operator, then the
        field attenuation. A final linear step closes the span.

        Parameters
        ----------
        sig : list
            The two polarization components of the input signal. The sequence
            itself is not modified.

        Returns
        -------
        sigout : list
            The signal after one span of transmission.
        """
        # The Manakov equation scales the nonlinear coefficient by 8/9.
        self.nl_gamma = 8 / 9 * self.gamma if self.Manakov else self.gamma
        if self.pmd:
            self.__pmd_init__()
        else:
            self.pmd_dz_arr = None
            self.pmd_arr = None
            self.psp_theta = None
            self.psp_phi = None
        self.__step_init__()
        # Build a new list instead of assigning into the caller's sequence (which
        # mutated it and failed outright for tuples).
        sig = [self.data_mode_convert(pol, self.data_mode) for pol in sig]
        while not self.step_size.last_prop:
            # 1> Calculate the step size from the current field.
            self.step_size.step(sig)
            # 2> Refresh the operator parameters for this step size.
            self.__func_para_init__()
            # 3> Solve the linear and nonlinear parts of the NLSE.
            sig = linearity_matrix(sig, **self.lin_func_para)
            sig = nonlinearity_matrix(sig, **self.nlin_func_para)
            # 4> Field attenuation over this step (half the power loss rate).
            att = np.exp(-self.alpha_loss * self.step_size.dz_now / 2)
            sig = [sig[0] * att, sig[1] * att]
        self.step_size.step(sig, final_dz=True)
        self.__func_para_init__()
        return linearity_matrix(sig, **self.lin_func_para)

    def __ssfm_scalar__(self, sigin):
        """Apply SSFM to a single-polarization signal over one span.

        Solves the NLSE step by step until the accumulated step sizes reach
        the span length. Each step applies the linear operator, then the
        nonlinear operator, then the field attenuation, always to the field
        propagated so far. A final linear step closes the span.

        Parameters
        ----------
        sigin : list
            The signal to be transmitted in this span (one polarization entry).
            The sequence itself is not modified.

        Returns
        -------
        sigout : list
            The signal after one span of transmission.
        """
        self.nl_gamma = self.gamma
        # PMD needs two polarizations. Define the attributes read by
        # __step_init__ / __func_para_init__, as the dual-polarization path does
        # when PMD is off, so they are never missing here.
        # NOTE(review): if self.pmd is truthy, Step_Size receives pmd=True with
        # pmd_dz_arr=None -- confirm how that should be handled.
        self.pmd_dz_arr = None
        self.pmd_arr = None
        self.psp_theta = None
        self.psp_phi = None
        self.__step_init__()
        # Running field, converted like the dual-polarization path does. Defined
        # before the loop so the final step works even if the loop never runs.
        x = [self.data_mode_convert(pol, self.data_mode) for pol in sigin]
        while not self.step_size.last_prop:
            # 1> Calculate the step size from the current (propagated) field.
            self.step_size.step(x)
            # 2> Refresh the operator parameters for this step size.
            self.__func_para_init__()
            # 3> Solve the linear and nonlinear parts of the NLSE, continuing
            #    from the field propagated so far (not from the input signal).
            x = linearity_scalar(x, **self.lin_func_para)
            x = nonlinearity_scalar(x, **self.nlin_func_para)
            # 4> Field attenuation over this step, applied to every entry (there
            #    is only one polarization, so indexing x[1] would fail).
            att = np.exp(-self.alpha_loss * self.step_size.dz_now / 2)
            x = [pol * att for pol in x]
        self.step_size.step(x, final_dz=True)
        self.__func_para_init__()
        return linearity_scalar(x, **self.lin_func_para)

    def __pmd_init__(self):
        """Configure the per-trunk PMD parameters for the current span.

        Sets ``pmd_dz_arr`` (length of each PMD trunk), ``pmd_arr`` (PMD
        coefficient of each trunk), and ``psp_theta`` / ``psp_phi`` (angles of
        the principal states of each trunk). Each has ``pmd_trunk_num``
        entries. Reads ``self.span_len`` and the PMD settings stored by
        ``init``.
        """
        if self.dgd_manual:
            # Derive the coefficient from the requested total DGD of the span.
            pmd_coeff = self.dgd_total / np.sqrt(self.span_len)
        elif self.pmd_coeff_random:
            # Sample the PMD coefficient from a Maxwellian distribution: the
            # norm of a 3-D vector with i.i.d. zero-mean Gaussian components.
            scale = np.sqrt(self.dgd_rms ** 2 / 3)
            vx = np.random.normal(loc=0, scale=scale)
            vy = np.random.normal(loc=0, scale=scale)
            vz = np.random.normal(loc=0, scale=scale)
            pmd_coeff = np.sqrt(vx ** 2 + vy ** 2 + vz ** 2)
        else:
            pmd_coeff = self.pmd_coeff

        n_trunk = self.pmd_trunk_num
        l_corr = self.span_len / n_trunk
        pmd_per_trunk = pmd_coeff / np.sqrt(n_trunk)  # PMD coefficient per trunk.
        if self.pmd_dz_random:
            # One random length per trunk, matching the other per-trunk arrays.
            # NOTE(review): these lengths do not sum exactly to span_len and could
            # in principle be non-positive -- confirm Step_Size tolerates this.
            self.pmd_dz_arr = np.random.normal(loc=l_corr, scale=l_corr / 5, size=n_trunk)
        else:
            self.pmd_dz_arr = np.ones(n_trunk) * l_corr  # equal trunk lengths
        self.pmd_arr = pmd_per_trunk * np.ones(n_trunk)

        if self.psp_manual:
            self.psp_theta = self.phi * np.ones(n_trunk)
            self.psp_phi = self.phi * np.ones(n_trunk)
        else:
            # Azimuth: uniform random variable on [-pi, pi).
            self.psp_theta = np.random.rand(n_trunk) * 2 * np.pi - np.pi
            # Ellipticity drawn so the state is uniform over the Poincare sphere.
            self.psp_phi = 0.5 * np.arcsin(np.random.rand(n_trunk) * 2 - 1)
    
    def __func_para_init__(self):
        """Bundle the keyword arguments for the linear and nonlinear solvers.

        Rebuilds ``self.lin_func_para`` and ``self.nlin_func_para`` from the
        current step sizes of ``self.step_size`` together with the dispersion,
        PMD, nonlinearity and data-mode settings of this object.
        """
        step = self.step_size
        # Arguments of the linear (dispersion / PMD) operator.
        self.lin_func_para = dict(
            dz=step.dz_l,
            phase_factor_freq=self.phase_factor_freq,
            pmd=self.pmd,
            pmd_dz_arr=self.pmd_dz_arr,
            pmd_arr=self.pmd_arr,
            psp_theta=self.psp_theta,
            psp_phi=self.psp_phi,
            trunk_list=step.trunk_list,
            plates_idx_list=step.plates_idx_list,
            data_mode=self.data_mode,
        )
        # Arguments of the nonlinear (Kerr) operator.
        self.nlin_func_para = dict(
            dz=step.dz_nl,
            nl_gamma=self.nl_gamma,
            Manakov=self.Manakov,
            data_mode=self.data_mode,
        )