import torch
import numpy as np
import math
import IFTS.fiber_simulation.comm_tools.calculation as calcu
from IFTS.simulation_main.modul_para.signal_para import Sig_Para
from IFTS.fiber_simulation.channel.optics import wdm, edfa, coherent_receiver
from IFTS.fiber_simulation.channel.channel_trans.fiber.ssfm import ssfm_design
from IFTS.fiber_simulation.channel.channel_trans.fiber.nn import nn_design

class Ch_Para(Sig_Para):
    """
    Ch_Para is used to construct the simulation channel environment.

    Ch_Para inherits from Sig_Para. Its attributes include the basic signal
    parameters as well as the parameters needed to construct each channel
    submodule. Its methods create the submodule objects that build up the
    simulation channel. Before each object is created, the parameters it
    needs are checked first (and filled with defaults where one is given).

    Notes
    -----
    All attributes of a Ch_Para object are saved in the config file
    'Channel_Para.yaml'.
    """
    def __init__(self, rand_seed, simu_configs):
        """Initialize the simulation channel environment.

        Different methods are called to create submodule objects according to
        the configured channel type. The values of 'channel_type' and the
        corresponding methods called are:
          0 (back-to-back): fiber_para
          1 (fiber channel): wss_para (optional), fiber_para, edfa_para,
             cdm (optional), receiver_para (optional)
          2 (AWGN channel): wss_para (optional), awgn_para,
             receiver_para (optional)

        Parameters
        ----------
        rand_seed : int
            Seed to control the random number generators used in the
            simulation of random processes.
        simu_configs : mapping loaded from the yaml config file
            The overall configuration; its 'Ch_Para' entry holds the channel
            configuration.

        Raises
        ------
        RuntimeError('channel_type is not supported')
            If the value of channel_type is not 0, 1 or 2. The check happens
            before any submodule object is created.
        NotImplementedError
            If 'transmitter' is enabled (not available in this version).

        Warnings
        --------
        '    ch_para: you are using a simulation xxx channel'
            Printed (when infor_print is set) to report the channel type.
        """
        super().__init__(rand_seed, simu_configs)
        configs = simu_configs['Ch_Para']
        # With 'ch_random' enabled, the channel seed falls back to an offset
        # of the global seed unless 'rand_seed' is given explicitly.
        if self._check('ch_random', 0, configs):
            self._check('rand_seed', rand_seed + 501, configs)
        self.sam_rate = self._check("ch_sam_rate")  # unit: GHz
        self.awg_memory_depth = self._check("awg_memory_depth")
        self.dt = 1000 / self.sam_rate  # unit: ps
        self.upsam = self.sam_rate / self._check("tx_sam_rate")
        self.fft_num = self._check("ch_fft_num")
        self.simu_time = self.dt * np.arange(self.fft_num)
        # Ch_Para uses element 2 of each 3-element flag array
        # (presumably one entry per simulation module -- confirm).
        self.infor_print = self._check('infor_print_arr', np.ones(3))[2]
        self.fig_plot = self._check('fig_plot_arr', np.ones(3))[2]
        self.save_data = self._check('save_data_arr', np.zeros(3))[2]

        if self._check('data_mode', 'fine_tune') == 'fine_tune':
            # Mixed mode: channel, Rx laser and WDM demultiplexing work in
            # 'tensor' mode, while the Tx laser and WDM multiplexing stay
            # in 'numpy' mode.
            self.ch_data_mode = 'tensor'
            self.tx_laser_data_mode = 'numpy'
            self.rx_laser_data_mode = 'tensor'
            self.wdm_data_mode = 'numpy'
            self.dwdm_data_mode = 'tensor'
        else:
            self.ch_data_mode = self.data_mode
            self.tx_laser_data_mode = self.data_mode
            # Previously missing in this branch, leaving the attribute unset.
            self.rx_laser_data_mode = self.data_mode
            self.wdm_data_mode = self.data_mode
            self.dwdm_data_mode = self.data_mode

        self._check('channel_type', 1, configs)
        # Validate before building any submodule so a bad config fails fast.
        if self.channel_type not in (0, 1, 2):
            raise RuntimeError('channel_type is not supported')

        if self.channel_type == 0:
            if self.infor_print:
                print('    ch_para: you are using a back-to-back channel')
            # __fiber_para__ reads self.fiber_config, so load it here as well.
            self._check('fiber_config', configs=configs)
            self.__fiber_para__()
        else:
            if self._check('transmitter', 0, configs):
                self.__transmitter_para__()

            if self._check('wss', 1, configs):
                self._check('wss_config', configs=configs)  # brickwall_filter, wss_filter
                self.__wss_para__()

            if self.channel_type == 1:  # fiber channel
                if self.infor_print:
                    print('    ch_para: you are using a simulation fiber channel')
                self._check('fiber_config', configs=configs)
                self.__fiber_para__()
                # EDFA defaults depend on fiber parameters, so it comes after.
                self._check('edfa_config', configs=configs)
                self.__edfa_para__()
                if self._check('do_cdm', 0, configs):
                    self.__cdm__()
            else:  # channel_type == 2: AWGN channel
                self._check('awgn_config', configs=configs)
                if self.infor_print:
                    print('    ch_para: you are using a simulation AWGN channel')
                self.__awgn_para__()

            if self._check('receiver', 1, configs):
                self._check('icr_config', configs=configs)
                self.__receiver_para__()

        self.save_configs(self.result_path + 'Channel_Para.yaml', configs)

    def __transmitter_para__(self):
        """Transmitter submodule; not implemented in this version.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError('Transmitter is not available at now version, please shut it down')

    def __wss_para__(self):
        """
        Check that the parameters required for the WDM simulation are all
        given (filling defaults from this object) and create the WSS object
        self.wdm_obj.

        For filter_type 'wss_filter', 'bandwidth' and 'steepness' default
        to 1. 'brickwall_filter' needs no extra arguments.
        """
        self._config_check('wdm_data_mode', self.wdm_data_mode, self.wss_config)
        self._config_check('dwdm_data_mode', self.dwdm_data_mode, self.wss_config)
        self._config_check('device', self.device, self.wss_config)
        wss_args = self.wss_config['args']
        self._config_check('channel_num', self.channel_num, wss_args)
        self._config_check('channel_space', self.channel_space, wss_args)
        self._config_check('cut_idx', self.cut_idx, wss_args)
        self._config_check('sig_p', self.sig_power_dbm, wss_args)
        self._config_check('fft_num', self.fft_num, wss_args)
        self._config_check('sam_rate', self.sam_rate, wss_args)
        if self.wss_config['filter_type'] == 'brickwall_filter':
            pass
        elif self.wss_config['filter_type'] == 'wss_filter':
            self._config_check('bandwidth', 1, wss_args)
            self._config_check('steepness', 1, wss_args)
        self.wdm_obj = wdm.WSS(**self.wss_config)

    def __edfa_para__(self):
        """
        Check that the parameters needed for the EDFA simulation are all given
        (filling defaults) and create the EDFA object self.edfa_obj.

        Must run after __fiber_para__, because the defaults are read from
        self.fiber_config.
        """
        alpha_indB = self._config_check('alpha_indB', configs=self.fiber_config)  # fiber loss (dB/km)
        f_cut = self._config_check('optical_carrier', configs=self.fiber_config)  # carrier frequency (GHz)
        self._config_check('noise_bw', self.sam_rate, self.edfa_config)  # noise bandwidth in simulation (GHz, same unit as sam_rate)
        self._config_check('gain', alpha_indB * self.span_len, self.edfa_config)  # EDFA gain, by default the fiber loss of one span (dB)
        self._config_check('f_cut', f_cut, self.edfa_config)
        self._config_check('nf_db', 5, self.edfa_config)
        self._config_check('data_mode', self.ch_data_mode, self.edfa_config)
        self._config_check('device', self.device, self.edfa_config)
        self.edfa_obj = edfa.EDFA(**self.edfa_config)

    def __cdm__(self):
        """Set self.cd_manage_len to 80 for every span.

        NOTE(review): presumably a fixed 80 km dispersion-managed length per
        span; confirm the unit and consider making it configurable.
        """
        self.cd_manage_len = np.ones(self.span_num) * 80

    def __receiver_para__(self):
        """
        Check the config of the coherent receiver (filling defaults) and
        create the ICR object self.icr_obj.
        """
        self._config_check('data_mode', self.ch_data_mode, self.icr_config)
        self._config_check('device', self.device, self.icr_config)
        args = self.icr_config['args']
        self._config_check('upsam', self.upsam, args)
        self._config_check('add_noise', 1, args)
        self._config_check('n_power_dBm', - 20.6751, args)  # reference value: [-20.6751 -21.6751 -21.7] dBm
        self.icr_obj = coherent_receiver.ICR(**self.icr_config)

    def __awgn_para__(self):
        """Read the AWGN channel SNR ('snr_db') from self.awgn_config."""
        # awgn_config must be passed as `configs`; it used to be passed
        # positionally, i.e. as the *default value* of 'snr_db'.
        self._check('snr_db', configs=self.awgn_config)

    def __fiber_para__(self):
        """Configure and calculate the simulation parameters of the optical fiber channel.

        Derives span lengths, dispersion, loss and nonlinearity parameters
        from the basic ones in self.fiber_config, writing computed defaults
        back into it. It then creates the propagation object selected by
        fiber_config['mode']: 'SSFM' -> self.ssfm_obj, 'NN' -> self.nn_obj.

        Raises
        ------
        ValueError
            When the sum of the span lengths does not equal total_len.

        Warnings
        --------
        '    ch_para: SSFM is used' / '    ch_para: NN is used'
            Printed (when infor_print is set) to report the propagation model.
        """
        # ---- span length config ----
        if self.load_len:  # use an existing per-span length config file
            self._check('len_config_path', configs=self.fiber_config)
            self.len_config = self.read_configs(self.len_config_path)
            self.len_array = np.zeros((self.span_num))
            for i in range(self.span_num):
                self.len_array[i] = self.len_config[f'span{i + 1}']
            self.total_len = np.sum(self.len_array)
        else:  # split total_len into spans of span_len; the last span takes the remainder
            ratio = self.total_len / self.span_len
            n_round = int(round(ratio))
            # A tolerance is used because float division can land just above or
            # below an integer (e.g. 1.1 / 0.1), which would otherwise create a
            # spurious near-zero or negative last span.
            if math.isclose(ratio, n_round, rel_tol=1e-9, abs_tol=1e-9):
                self.span_num = n_round
                self.len_array = np.zeros(self.span_num) + self.span_len
            else:
                self.span_num = int(math.floor(ratio)) + 1
                self.len_array = np.zeros(self.span_num) + self.span_len
                self.len_array[self.span_num - 1] = self.total_len - (self.span_num - 1) * self.span_len
            if not np.isclose(np.sum(self.len_array), self.total_len):
                raise ValueError('The sum of span length is different with total length !')

        wavelength = self._check('wavelength', 1550, self.fiber_config)  # wavelength (nm)
        self.optical_carrier = self._config_check('optical_carrier', self.constant_C / wavelength, self.fiber_config)  # carrier frequency (GHz)

        # ---- dispersion parameters ----
        D = self._config_check('D', 17, self.fiber_config)  # dispersion parameter (ps/(nm.km)), negative for dispersion compensation
        S = self._config_check('S', 0.075, self.fiber_config)  # dispersion slope (ps/(nm^2.km))
        # beta2c, beta3c: second- and third-order dispersion at the carrier
        beta2c = self._config_check(
            'beta2c',
            - 1000 * D * wavelength ** 2 / (2 * self.constant_pi * self.constant_C),
            self.fiber_config)
        beta3c = self._config_check('beta3c', 0, self.fiber_config)
        # another option for beta3c:
        # beta3c = 10 ** 6 * (S - beta2c * (4 * pi * C / wavelength ** 3) / 1000) / (2 * pi * C / wavelength ** 2) ** 2  # (ps^3/km)

        # Frequency offset of each WDM channel from the center (same unit as
        # channel_space), then the Taylor-expanded beta around each channel.
        self.df = ((np.arange(1, self.channel_num + 1) - (self.channel_num + 1) / 2) * self.channel_space)
        two_pi_df = 2 * self.constant_pi * self.df
        self.beta0 = beta2c / 2 * two_pi_df ** 2 * 10 ** -6 + beta3c / 6 * two_pi_df ** 3 * 10 ** -9
        self.beta1 = beta2c * two_pi_df * 10 ** -3 + beta3c * two_pi_df ** 2 * 10 ** -6 / 2
        self.beta2 = beta2c + two_pi_df * beta3c * 10 ** -3
        self.beta3 = beta3c
        pmd = self._config_check('pmd', 1, self.fiber_config)

        # ---- loss parameters ----
        alpha_indB = self._config_check('alpha_indB', 0.2, self.fiber_config)  # unit: dB/km
        alpha_loss = math.log(10 ** (alpha_indB / 10))  # unit: 1/km
        self._config_check('alpha_loss', alpha_loss, self.fiber_config)

        # ---- nonlinearity parameters ----
        n2 = self._config_check('n2', 2.6e-20, self.fiber_config)
        Aeff = self._config_check('Aeff', 80, self.fiber_config)
        gamma = 1000 * 2 * self.constant_pi * n2 / (wavelength * 10 ** -9 * Aeff * 10 ** -12)  # nonlinear coefficient (1/(W.km))
        self._config_check('gamma', gamma, self.fiber_config)

        self.fiber_mode = self.fiber_config['mode']
        if self.fiber_mode == 'SSFM':
            if self.infor_print:
                print('    ch_para: SSFM is used')
            self.ssfm_obj = ssfm_design.SSFM(
                len_arr=self.len_array, beta0=self.beta0[self.cut_idx],
                beta1=self.beta1[self.cut_idx], beta2=self.beta2[self.cut_idx],
                beta3=self.beta3, pmd=pmd, alpha_loss=alpha_loss, gamma=gamma,
                sam_rate=self.sam_rate, fft_num=self.fft_num,
                infor_print=self.infor_print, data_mode=self.ch_data_mode,
                device=self.device)
            ssfm_args = self.fiber_config['args']
            self._config_check('isManakov', 1, ssfm_args)
            self._config_check('step_config', configs=ssfm_args)
            self._config_check('pmd_config', configs=ssfm_args)
            self.ssfm_obj.init(**ssfm_args)
        elif self.fiber_mode == 'NN':
            if self.infor_print:
                print('    ch_para: NN is used')
            self.channel_num = self._check("channel_num")
            self.sym_rate = self._check("sym_rate")
            self.channel_space = self._check("channel_space")
            self.sig_power_dbm = self._check("sig_power_dbm")
            self.nn_model = self.fiber_config['nn_model']
            # Checkpoint name encodes the model and the link configuration.
            nn_model_path = (
                './IFTS/fiber_simulation/channel/channel_trans/fiber/nn/nn_checkpoint/'
                f'{self.nn_model}_{self.channel_num}channels_{self.sym_rate}GBaud_'
                f'{self.channel_space}Gchannel_space_{self.sig_power_dbm}dbm_')
            self.nn_obj = nn_design.NN(
                len_arr=self.len_array, beta0=self.beta0[self.cut_idx],
                beta1=self.beta1[self.cut_idx], beta2=self.beta2[self.cut_idx],
                beta3=self.beta3, pmd=pmd, alpha_loss=alpha_loss, gamma=gamma,
                sam_rate=self.sam_rate, fft_num=self.fft_num,
                nn_model=self.nn_model, model_path=nn_model_path,
                infor_print=self.infor_print, data_mode=self.ch_data_mode,
                device=self.device)
            nn_args = self.fiber_config['args']
            self._config_check('pmd_config', configs=nn_args)
            self.nn_obj.init(**nn_args)

    def ssfm_para(self):
        '''
        Compute the frequency-domain dispersion phase factor for the selected
        channel (self.cut_idx), stored in self.phase_factor_freq.

        beta(w+dwk)=b0 + b1k w + b2k w^2/2 + b3k w^3 /6 + ...
        b1k = d(beta(w))/dw = b2 dwk + b3 dwk^2 / 2
        b2k = d^2(beta(w))/dw^2
        '''
        f = calcu.digital_freq(self.fft_num, self.sam_rate,
                               self.ch_data_mode, self.device)
        self.w = 2 * self.constant_pi * f  # angular frequency grid
        self.phase_factor_freq = (self.beta0[self.cut_idx]
                                  + self.beta1[self.cut_idx] * self.w
                                  + self.beta2[self.cut_idx] * (self.w ** 2) / 2
                                  + self.beta3 * (self.w ** 3) / 6)
    
        
        
        