import torch
import numpy as np
from ...base.base_optics import Optics_Base_Module

class EDFA(Optics_Base_Module):
    """Erbium-doped fiber amplifier (EDFA) with optional ASE noise."""

    def __init__(self, mode, nf_db, f_cut, gain, noise_bw,
                 data_mode='numpy', **kwargs):
        """
        Model of the optical device erbium-doped fiber amplifier (EDFA).

        EDFA is derived from Optics_Base_Module. Its attributes hold the
        parameters of the amplifier and its methods operate on the signal.
        An EDFA object is intended to amplify the signal at the end of each
        span, as in a real system. Users can choose whether to add amplified
        spontaneous emission (ASE) noise.

        Attributes
        ----------
        mode : str, {"naive_pass", "no_noise_pass"}
            Mode of the EDFA.
            naive_pass    : amplify the signal and add ASE noise to it.
            no_noise_pass : only amplify the signal.
        nf_db : float
            The amplifier noise figure (dB), defined as the ratio between
            the input OSNR and the output OSNR.
        f_cut : float
            The carrier frequency of the amplified signal (GHz).
        gain : float
            The gain of the EDFA in dB. It is stored in ``self.gain`` as a
            linear (unit-less) power ratio.
        noise_bw : float
            The bandwidth of the ASE noise (GHz).
        data_mode : str, {"numpy", "tensor"}, optional
            Backend used to generate the ASE noise. Defaults to 'numpy'.
        **kwargs : dict
            Other arguments: 'device' (default 'cpu'), 'constant_pi'
            (default np.pi) and 'constant_h' (Planck constant, default
            6.626068e-34).
        """
        self.mode = mode
        self.nf_db = nf_db
        self.f_cut = f_cut
        self.gain = 10 ** ((gain) / 10)  # convert gain [dB] to linear gain [unit-less]
        self.noise_bw = noise_bw
        self.data_mode = data_mode
        # -1 is the sentinel for "no fixed seed": fresh random noise per call.
        self.rand_seed = -1
        self.device = kwargs.get('device', 'cpu')
        self.constant_pi = kwargs.get('constant_pi', np.pi)
        self.constant_h = kwargs.get('constant_h', 6.626068e-34)
        self.__calcu_noise_power__()

    def forward_pass(self, sigin, rand_seed=None):
        """
        Let the signal pass through the EDFA.

        Dispatches on ``self.mode``. Presumably invoked when the object is
        called (via the base class) -- that dispatch is not part of this class.

        Parameters
        ----------
        sigin : list
            The input signal, one array/tensor per polarization.
        rand_seed : int, optional
            Seed for the noise generators, only used in 'naive_pass' mode.
            Defaults to None, which means ``self.rand_seed`` (-1, unseeded).

        Returns
        -------
        sigout : list
            The output signal from the EDFA, one entry per polarization.

        Raises
        ------
        ValueError
            If ``self.mode`` is neither 'naive_pass' nor 'no_noise_pass'.
        """
        if self.mode == 'naive_pass':
            return self.__naive_pass__(sigin, rand_seed)
        if self.mode == 'no_noise_pass':
            return self.__no_noise_pass__(sigin)
        # Without this, an unknown mode surfaced as an UnboundLocalError.
        raise ValueError(
            f"ERROR:mode {self.mode!r} is not defined as 'naive_pass' or 'no_noise_pass' ")

    def __naive_pass__(self, sigin, rand_seed=None):
        """Amplify each polarization of the signal and add ASE noise.

        With a fixed seed ``s``, polarization ``i`` is seeded with ``s + i`` so
        the polarizations receive different noise. The sentinel -1 (no fixed
        seed) is passed unchanged to every polarization.
        """
        if rand_seed is None:
            rand_seed = self.rand_seed
        sigout = []
        for i_p, pol_sig in enumerate(sigin):
            # Offsetting the -1 sentinel would hand polarization 1 the fixed
            # seed 0, silently making its "random" noise reproducible.
            seed = rand_seed if rand_seed == -1 else rand_seed + i_p
            rx_sig = self.__amplifier__(pol_sig)
            sigout.append(self.__add_noise__(rx_sig, seed))
        return sigout

    def __no_noise_pass__(self, sigin):
        """Amplify each polarization of the signal without adding ASE noise."""
        return [self.__amplifier__(pol_sig) for pol_sig in sigin]

    def __calcu_noise_power__(self):
        r"""Calculate the power of the ASE noise.

        Sets ``self.nf`` (linear noise figure), ``self.noise_psd`` and
        ``self.noise_power`` (= PSD times the noise bandwidth in Hz).

        Notes
        -----
        The power spectral density (PSD) of ASE is nearly constant (white
        noise) and can be written as

        .. math:: S_{ASE} = 2 n_{sp} h f_{cut} (G - 1)

        where the spontaneous emission factor :math:`n_{sp}` can be
        approximated by :math:`n_{sp} \approx NF / 2` when :math:`G \gg 1`,
        which gives the implemented form
        :math:`S_{ASE} \approx NF \, h f_{cut} (G - 1)`.

        .. [1] Govind P. Agrawal. Fiber-Optic Communication Systems.
           New Jersey: John Wiley & Sons, 2010: 305-307.
        """
        self.nf = 10 ** (self.nf_db / 10)
        # f_cut and noise_bw are given in GHz; scale to Hz.
        self.noise_psd = self.constant_h * self.f_cut * (10 ** 9) * (self.gain - 1) * self.nf
        self.noise_power = self.noise_psd * self.noise_bw * (10 ** 9)

    def __amplifier__(self, sigin):
        r"""Amplify one polarization of the signal.

        Parameters
        ----------
        sigin : tensor or ndarray
            The signal of one polarization to be amplified.

        Returns
        -------
        sigout : tensor or ndarray
            The amplified signal.

        Notes
        -----
        ``self.gain`` is a linear power gain :math:`G`. Expressed in nepers,
        :math:`g = \ln G`, and the field amplitude is scaled by
        :math:`e^{g/2} = \sqrt{G}`.
        """
        g = np.log(self.gain)  # convert linear gain [unit-less] to gain [Np]
        return sigin * np.exp(g / 2)

    def __add_noise__(self, sig, rand_seed):
        """Add complex Gaussian ASE noise to the amplified signal.

        Each quadrature (I and Q) gets independent standard-normal samples
        scaled by ``sqrt(noise_power / 2)``; the noise has the same shape as
        ``sig``.

        Parameters
        ----------
        sig : tensor or ndarray
            The amplified signal of one polarization.
        rand_seed : int
            Seed for the generators; -1 means unseeded. When seeded, the Q
            generator uses ``rand_seed + 2``.

        Returns
        -------
        sigout : tensor or ndarray
            The amplified signal with ASE noise.

        Raises
        ------
        AttributeError
            If the value of data_mode is not 'numpy' or 'tensor'.
        """
        if self.data_mode == 'numpy':
            if rand_seed == -1:
                rng_i = np.random.default_rng()
                rng_q = np.random.default_rng()
            else:
                rng_i = np.random.default_rng(rand_seed)
                rng_q = np.random.default_rng(rand_seed + 2)
            # Full shape (not len) so multi-dimensional signals get matching
            # noise; for 1-D input this draws the same stream as before.
            shape = np.shape(sig)
            n = np.sqrt(self.noise_power / 2) * (rng_i.standard_normal(shape)
                                                 + 1j * rng_q.standard_normal(shape))
        elif self.data_mode == 'tensor':
            rng_i = torch.Generator(device=self.device)
            rng_q = torch.Generator(device=self.device)
            if rand_seed == -1:
                rng_i.seed()
                rng_q.seed()
            else:
                rng_i.manual_seed(rand_seed)
                rng_q.manual_seed(rand_seed + 2)
            shape = tuple(sig.shape)
            n = np.sqrt(self.noise_power / 2) * (
                torch.randn(shape, generator=rng_i, device=self.device)
                + 1j * torch.randn(shape, generator=rng_q, device=self.device))
        else:
            raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
        return sig + n