import torch
import numpy as np
from IFTS.fiber_simulation.base.base_optics import Optics_Base_Module

class Step_Size(Optics_Base_Module):
    """
        Step-size calculator for the split-step Fourier method (SSFM).

        Class Step_Size is derived from Optics_Base_Module. Users can choose different
        schemes to calculate step sizes. The schemes supported are as follows:
         (1) constant step-size method
         (2) maximum nonlinear phase-rotation method
         (3) logarithmic distribution method
         (4) nonlinear phase-rotation combined with logarithmic distribution.

        After every call to ``step`` the object exposes ``dz_now`` (size of the
        current step), ``dz_l`` (linear operation distance), ``dz_nl`` (effective
        nonlinear operation distance), ``prop_dz`` (distance covered so far) and
        ``last_prop`` (True once the end of the span has been reached).
    """
    def __init__(self, dz_mode, nl_gamma, nPol, pmd, alpha_loss, span_len, *args, **kwargs):
        """
        Initialize the object.

        Parameters
        ----------
        dz_mode : str,{'np','log','np_log','c'}
           The name of the method for calculating the step size.
           The supported methods and the corresponding dz_mode values are as follows:
           {'c': constant step-size method
           'np': nonlinear phase-rotation method
           'log' : logarithmic distribution method
           'np_log': nonlinear phase-rotation combined with logarithmic distribution
           }
        nl_gamma : float
           The nonlinear coefficient.
        nPol : int
           The number of polarizations.
        pmd : int
           Determines whether PMD is considered. The value is 0 or 1.
        alpha_loss : float
           The fiber loss coefficient (Np/km).
        span_len : float
           The length of the current span (km).
        *args, **kwargs:
           Other variable arguments including PMD config information.
           Recognised keys: ``dz_max`` (upper bound of a single step, default
           None = no bound), ``phi_max`` (maximum nonlinear phase rotation per
           step, default 0.005), ``total_step_num`` (number of steps of the
           logarithmic method, default 1000), ``constant_step_size`` (default
           0.01), ``Constant_pi`` (default ``np.pi``) and, when ``pmd`` is set,
           ``pmd_dz_arr`` (trunk lengths). Every keyword argument is also
           copied onto the instance as an attribute of the same name.
        """
        self.nPol = nPol
        self.dz_mode = dz_mode
        self.alpha_loss = alpha_loss
        self.nl_gamma = nl_gamma
        self.step_num = 0
        self.span_len = span_len
        self.pmd = pmd 
        self.dz_now = 0.0
        self.dz_previous = 0.0
        self.prop_dz = 0.0
        self.last_prop = False
        self.dz_max = kwargs.get('dz_max', None)
        self.phi_max = kwargs.get('phi_max', 0.005)
        self.total_step_num = kwargs.get('total_step_num', 1000)
        self.constant_step_size = kwargs.get('constant_step_size', 0.01)
        self.Constant_pi = kwargs.get('Constant_pi', np.pi)
        if self.pmd:
            self.pmd_prop_dz = 0.0  # The PMD propagation length record
            self.pmd_dz_arr = kwargs.get('pmd_dz_arr') # trunk length list, the sum equals the span length
            self.prop_dz_in_trunck = 0.0 # The length already covered in the current trunk
            self.trunk_idx = 0 # index of the current trunk
        
        # Expose every keyword argument as an attribute; this runs last, so a
        # keyword whose name matches an attribute set above overrides it.
        for key in kwargs:
            self.__dict__[key] = kwargs[key]
    
    def step(self, x, final_dz = False):
        """Calculate the size of a single step.

        Calculate the size of the next step as well as the linear and 
        nonlinear operation distances (dz_l, dz_nl) for the next time,
        called in ssfm_scalar or ssfm_matrix. 

        Parameters
        ----------
        x : list
           The transmitted signal.
        final_dz : bool
           A flag to tell whether the next step is the last step of the span. 
          
        Returns
        -------
        None.

        Raises
        ------
        ValueError
           If ``dz_mode`` is not one of 'np', 'log', 'np_log' or 'c'.
        """
        if not final_dz:
            # Reject an unknown mode before touching any state: silently keeping
            # the previous dz_now (initially 0.0) would never advance prop_dz.
            if self.dz_mode not in ('np', 'log', 'np_log', 'c'):
                raise ValueError(
                    "Unsupported dz_mode {!r}; expected 'np', 'log', 'np_log' or 'c'".format(self.dz_mode))
            self.dz_previous = self.dz_now
            self.step_num += 1
            if self.dz_mode == 'np':  # choose the method according to dz_mode
                self.__nonlinear_phase_step__(x)  
            elif self.dz_mode == 'log':
                self.__log_step__()
            elif self.dz_mode == 'np_log':
                self.__np_log_step__(x)
            else:
                self.__constant_step__()
            # dz_max is an upper bound of a single step: clamp the step to it.
            if self.dz_max is not None and self.dz_now > self.dz_max:
                self.dz_now = self.dz_max
            # never step beyond the end of the span
            if self.dz_now + self.prop_dz >= self.span_len: 
                self.dz_now = self.span_len - self.prop_dz
                self.last_prop = True
            self.prop_dz = self.dz_now + self.prop_dz 
            # merge the dispersion operation distance in two steps
            self.dz_l = (self.dz_now + self.dz_previous) / 2  
            # calculate the effective operating length of the nonlinear effect in the presence of the fiber loss
            if self.alpha_loss == 0:
                self.dz_nl = self.dz_now
            else:
                self.dz_nl = (1 - np.exp(- self.alpha_loss * self.dz_now)) / self.alpha_loss 
            # DBP:
            # dz_eff = - (1 - np.exp(fiber_para.alpha_loss * dz_nl)) / fiber_para.alpha_loss
        else:
            # dispersion operation distance is half of the step at the end of a span
            self.dz_l = self.dz_now / 2
        if self.pmd:
            self.__check_step_size__() 
        else:
            self.trunk_list = None
            self.trunk_idx_list = None

    def __calcu_power__(self, x):
        """Calculate the maximum instantaneous power of the signal.

        For more than one polarization the power is the sum of the squared
        magnitudes of ``x[0]`` and ``x[1]``. The signal may be made of numpy
        arrays or of torch tensors; the result of the torch path is converted
        to a Python number with ``.item()``.

        Notes
        -----
        The power unit is W.
        """
        if self.nPol > 1:
            if isinstance(x[0], np.ndarray):
                p_max = np.max(np.abs(x[0]) ** 2 + np.abs(x[1]) ** 2) 
            else:
                p_max = torch.max((x[0]).abs() ** 2 + (x[1]).abs() ** 2).item()

        else: 
            if isinstance(x[0], np.ndarray):
                p_max = np.max(np.abs(x) ** 2) 
            else:
                # peak of |x|^2 (not the sum of x^2 over all samples)
                p_max = torch.max((x[0]).abs() ** 2).item()
            
        return p_max

    def _phase_limited_length(self, x):
        """Return phi_max / (nl_gamma * p_max), or inf when that product is zero."""
        gp_max = self.nl_gamma * self.__calcu_power__(x)
        if gp_max == 0:
            # No nonlinear phase is accumulated at all, so this criterion does
            # not limit the step; ``step`` clips it to dz_max / the span end.
            return np.inf
        return self.phi_max / gp_max

    def __nonlinear_phase_step__(self, x):
        """Apply the nonlinear phase-rotation method to calculate the step size.

        The method first gets the current maximum power of the transmitted signal, 
        then calculates the step size dz_now under the maximum nonlinear phase rotation 
        constraint.
        
        Parameters
        ----------
        x : list
          The transmitted signal.
        """
        self.dz_now = self._phase_limited_length(x)

    def __log_step__(self):
        """Apply the logarithmic distribution method to calculate step size."""
        if self.alpha_loss == 0:
            # Lossless limit of the logarithmic distribution: uniform steps.
            self.dz_now = self.span_len / self.total_step_num
            return
        sigma = (1 - np.exp(- 2 * self.alpha_loss * self.span_len)) / self.total_step_num
        self.dz_now = -1 / (2 * self.alpha_loss) * np.log( (1 - self.step_num * sigma) / (1 - (self.step_num - 1) * sigma) )   
       
    def __np_log_step__(self, x):
        """Combine the nonlinear phase and the logarithmic method.

        The step is chosen so that its effective nonlinear length
        (1 - exp(-alpha_loss * dz)) / alpha_loss equals phi_max / (nl_gamma * p_max).

        Parameters
        ----------
        x : list
          The transmitted signal.
        """
        l_eff = self._phase_limited_length(x)
        if self.alpha_loss == 0:
            self.dz_now = l_eff
            return
        loss_product = self.alpha_loss * l_eff
        if loss_product >= 1:
            # The effective length saturates at 1 / alpha_loss, so the target
            # cannot be reached with any finite step (the logarithm below would
            # give NaN or inf). Leave the step unbounded; ``step`` clips it.
            self.dz_now = np.inf
        else:
            self.dz_now = - 1 / self.alpha_loss * np.log(1 - loss_product)  
    
    def __constant_step__(self):
        """Apply the constant step-size method: dz_now = constant_step_size."""
        self.dz_now = self.constant_step_size 

    def __check_step_size__(self):
        """Check the step size and update the record of PMD simulation steps.

        Given the linear operation distance dz_l, the method evaluates the birefringence steps inside.
        The whole span is divided into several trunks to model the random birefringence in the fiber channel,
        and the PMD coefficient is considered constant in each trunk.
        The method checks if the next step is to step over different trunks and accordingly updates 
        the trunk index and the propagation length in a new trunk.

        Results are stored in ``trunk_idx_list`` / ``trunk_list`` (index and covered
        length of every trunk piece inside dz_l), and ``trunk_idx``,
        ``prop_dz_in_trunck`` and ``pmd_prop_dz`` are advanced accordingly.
        """
        trunk_len = self.pmd_dz_arr[self.trunk_idx]  # length of the current trunk
        trunk_idx = self.trunk_idx   # index of the current trunk
        self.trunk_idx_list = [] # record of the indexes of trunks to be covered by dz_l
        self.trunk_list = []  # record of the length of the trunks covered
        
        prop_dz_in_trunck = self.prop_dz_in_trunck
        prop_dz = 0.0
        
        # evaluate the pmd trunk propagation state in the new dz_l and update record.
        # NOTE: self.trunk_idx is only synchronised with the local trunk_idx at the
        # end of each iteration, so inside the loop body it still refers to the
        # trunk the iteration started in.
        while prop_dz < self.dz_l :
            trunk_len = self.pmd_dz_arr[trunk_idx]  
            dz = trunk_len - prop_dz_in_trunck
            # dz_l can't finish the current trunk
            if prop_dz + dz > self.dz_l: 
                dz = self.dz_l - prop_dz 
                prop_dz_in_trunck += dz
                self.trunk_idx_list.append(self.trunk_idx)
                self.trunk_list.append(dz)    
            # dz_l finish the current trunk and move on to a new one
            else: 
                prop_dz_in_trunck = 0.0
                trunk_idx += 1
                if trunk_idx < len(self.pmd_dz_arr) :
                    # self.trunk_idx is the trunk that has just been completed
                    self.trunk_idx_list.append(self.trunk_idx)
                    self.trunk_list.append(dz)    
                else: # finish all trunks while dz_l not finish
                    # stay on the last trunk and assign the whole remainder of dz_l to it
                    trunk_idx -=1
                    dz = self.dz_l-prop_dz
                    self.trunk_list.append(self.dz_l-prop_dz)
                    self.trunk_idx_list.append(trunk_idx)
            prop_dz += dz
            self.trunk_idx = trunk_idx
        self.prop_dz_in_trunck = prop_dz_in_trunck
        self.pmd_prop_dz += np.sum(np.array(self.trunk_list)) # pmd_prop_dz: the total PMD propagation length covered so far.
        