import warnings

import numpy as np
import torch

from IFTS.simulation_main.modul_para.signal_para import Sig_Para 
from IFTS.fiber_simulation.rx_dsp.low_pass_filter import lpf_design
from IFTS.fiber_simulation.rx_dsp.adaptive_filter import adaptive_filter_design
from IFTS.fiber_simulation.rx_dsp.IQ_compensation import imbalance_eq
from IFTS.fiber_simulation.rx_dsp.carrier_phase_estimation import cpe_design
from IFTS.fiber_simulation.rx_dsp.cd_compensation import cdc_design
from IFTS.fiber_simulation.rx_dsp.synchronization import synchron_design

class Rx_para(Sig_Para):
    def __init__(self, rand_seed, simu_configs):
        r"""
            Check receiver's config parameters.
            Initialize functional modules.
            Save config results.

            Args:
                rand_seed: random seed, forwarded unchanged to
                    ``Sig_Para.__init__``.
                simu_configs: full simulation config; its ``'Rx_Para'``
                    section configures the receiver DSP and its
                    ``'Ch_Para'`` section supplies ``fiber_config``.

            Raises:
                ValueError: if the computed per-span lengths do not add up
                    to ``total_len``.
        """
        super().__init__(rand_seed, simu_configs)

        # Check receiver's config parameters
        configs = simu_configs['Rx_Para']
        # The *_arr switch arrays are read at index 3 here, so their defaults
        # need at least 4 entries; a 3-entry default raised IndexError
        # whenever the key was missing from the config.
        self.infor_print    = self._check('infor_print_arr', np.ones(4))[3]
        self.fig_plot       = self._check('fig_plot_arr', np.ones(4))[3]
        self.save_data      = self._check('save_data_arr', np.zeros(4))[3]
        # 'fine_tune' forces tensor data on the receiver side
        if self._check('data_mode', 'fine_tune') == 'fine_tune':
            self.rx_data_mode = 'tensor'
        else:
            self.rx_data_mode = self.data_mode    # numpy, tensor
        self._check('fiber_config', configs=simu_configs['Ch_Para'])
        self.oscop_sam_num  = self._check('rx_fft_num', 2000000, configs)
        self.frame_num      = self._check('frame_num', 1, configs)
        self.sam_rate       = self._check("rx_sam_rate")    # GHz
        self.ch_sam_rate    = self._check("ch_sam_rate")    # GHz
        self.sym_rate       = self._check("sym_rate")       # GHz   (1/GHz = ns)
        self.dt             = 1000 / self.sam_rate          # ps    (1 ns = 1000 ps)
        self.sym_time       = 1000 / self.sym_rate          # ps
        self.awg_memory_depth = self._check("awg_memory_depth")
        # Symbols per frame: AWG memory depth divided by the Tx samples per
        # symbol (tx_sam_rate / sym_rate).
        self.frame_sym_num  = int(self.awg_memory_depth / (self.tx_sam_rate / self.sym_rate))
        self.sym_simu_time  = np.arange(0, self.frame_sym_num) * self.sym_time
        self.sam_now        = self.sam_rate / self.sym_rate

        # Receiver DSP chain (processing order):
        #   Coarse LPF
        #     -> IQ imbalance
        #     -> chromatic dispersion compensation
        #     -> MIMO (or multi-sample nonlinearity compensation)
        #     -> synchronization
        #     -> carrier phase estimation
        # Each stage below is built only when its switch in the Rx config is set.

        # Initialize LPF
        if self._check('lpf', 1, configs):
            # brickwall, bessel, butter, gaussian, rc, rrc
            self._check('lpf_config', configs=configs)
            self.__low_pass_filter__()

        # Initialize iq_balance
        if self._check('iq_balance', 1, configs):
            self._check('iq_balance_config', configs=configs)
            self.__iq_balance__()

        # Per-span fiber lengths (total_len feeds the CD compensation length)
        if self.load_len:
            # Span lengths come from a separate config file, keys span1, span2, ...
            self._check('len_config_path', configs=self.fiber_config)
            self.len_config = self.read_configs(self.len_config_path)
            self.len_array = np.zeros((self.span_num))
            for i in range(self.span_num):
                self.len_array[i] = self.len_config['span' + str(i + 1)]
            self.total_len = np.sum(self.len_array)
        else:
            # Equal spans of span_len; any remainder becomes a shorter last span
            self.span_num = int(self.total_len / self.span_len)
            if self.span_num * self.span_len != self.total_len:
                self.span_num = self.span_num + 1
                self.len_array = np.zeros(self.span_num) + self.span_len
                self.len_array[self.span_num - 1] = self.total_len - (self.span_num - 1) * self.span_len
            else:
                self.len_array = np.zeros(self.span_num) + self.span_len
            # Tolerant comparison: exact float equality can fail from rounding
            # alone when the lengths are not integers.
            if not np.isclose(np.sum(self.len_array), self.total_len):
                raise ValueError('The sum of span length is different with total length !')

        # Initialize chromatic dispersion compensation
        if self._check('cdcom', 1, configs):
            pre_cd_len = 0    # cdc length = total_len - pre_cd_len
            self._check('cd_com_config', configs=configs)
            self.__cdcom__(pre_cd_len)

        # Initialize mimo
        if self._check('mimo', 1, configs):
            self._check('mimo_config', configs=configs)
            self.__mimo__()

        # Initialize carrier phase estimation
        if self._check('cpe', 1, configs):
            self._check('cpe_config', configs=configs)   # bps, coarse
            self.__cpe__()

        # Initialize synchronization
        if self._check('synchronization', 1, configs):
            self._check('synchron_config', configs=configs)   # 2x2, 4x4
            self.__synchronization__()

        # Save config results
        self.save_configs(self.result_path + 'Rx_Para.yaml', configs)

    def __iq_balance__(self):
        r"""
            IQ-imbalance compensator initialization.
            Builds ``self.iq_balance_obj`` from ``iq_balance_config['type']``,
            operating on tensors on ``self.device``.
        """
        self.iq_balance_obj = imbalance_eq.IQ_Balance(
            self.iq_balance_config['type'],
            data_mode='tensor',
            device=self.device,
        )

    def __low_pass_filter__(self):
        r"""
            LPF initialization function.
            Optional LPF type: brickwall, bessel, butter, gaussian, rc, rrc.

            Checks ``self.lpf_config`` (supplying the defaults shown below),
            builds ``self.lpf_obj`` from it and calls ``init`` with the
            ``'args'`` sub-dict.
        """
        # Check LPF's config parameters
        config = self.lpf_config
        self._config_check('mode', 'freq', config)
        self._config_check('sam_rate', self.sam_rate, config)
        self._config_check('fft_num', self.oscop_sam_num, config)
        self._config_check('data_mode', 'tensor', config)
        self._config_check('device', self.device, config)
        self.lpf_obj = lpf_design.LPF(**config)
        lpf_args = config['args']
        max_lofo = self._config_check('max_lofo', 0.5, lpf_args)
        roll_off = self._config_check('roll_off', 0.1, lpf_args)
        if self.lpf_obj.mode == 'rrc':
            self._config_check('beta', roll_off, lpf_args)    # Roll-off zone width
            self._config_check('upsam', self.sam_per_sym, lpf_args)
            if max_lofo != 0:
                # RuntimeWarning is a warning category with no .warn() method;
                # the warning has to be emitted through the warnings module.
                warnings.warn('RRC will filter the signal because of LOFO at rx_para',
                              RuntimeWarning)
            # The RRC branch has no 'filter_bw' setting, but the S21 default
            # below still needs a bandwidth (previously an unbound name here).
            # Use the same bandwidth formula as the other filter types without
            # adding a 'filter_bw' key to lpf_args.
            filter_bw = self.sym_rate * (1 + roll_off) / 2 + max_lofo
        else:
            filter_bw = self._config_check('filter_bw',
                self.sym_rate * (1 + roll_off) / 2 + max_lofo, lpf_args)    # brickwall para
            if self.lpf_obj.mode == 'bessel':
                self._config_check('order', 15, lpf_args)     # bessel para
            elif self.lpf_obj.mode == 'butter':
                self._config_check('gpass', 5, lpf_args)     # butterworth para
                self._config_check('gstop', 40, lpf_args)    # butterworth para
            self._config_check('cut_off', filter_bw, lpf_args)
            if self.infor_print:
                print ("    rx_para LPF: {} filter and bandwidth {:.2f} GHz".format(config['mode'],filter_bw))
        # Optional S21 (frequency-response) compensation settings
        self._config_check('comp_s21', 1, lpf_args)
        self._config_check('s21_bw', filter_bw, lpf_args)
        self._config_check('s21_path', configs=lpf_args)
        self.lpf_obj.init(**lpf_args)

    def __cdcom__(self, pre_cd_len):
        r"""
            CDC initialization function.
            Calculate and check dispersion parameters, then build and
            initialize ``self.cdcom_obj``.

            Args:
                pre_cd_len: fiber length excluded from compensation; when
                    ``cdc_length`` is ``'auto'`` the compensated length
                    becomes ``self.total_len - pre_cd_len``.
        """
        # wavelength (nm)=10^-9(m)
        wavelength = self._check('wavelength', 1550, self.fiber_config)

        # Calculate dispersion parameters
        D = self._config_check('D', 17, self.fiber_config)    # dispersion parameter (ps/nm.km), e.g. 16.75; positive D gives a negative beta2c below (anomalous regime)
        S = self._config_check('S', 0.075, self.fiber_config)    # S slope ps/(nm^2.km); defaulted here but not used below
        # beta2 = -D * lambda^2 / (2*pi*c); the factor 1000 presumably converts
        # units to ps^2/km (assumes constant_C is in m/s — TODO confirm)
        beta2c = self._config_check('beta2c',\
            - 1000 * D * wavelength ** 2 / (2 * self.constant_pi * self.constant_C), self.fiber_config) # beta2 (ps^2/km);
        # beta3 defaults to 0; the S-based derivation kept below is disabled
        beta3c = self._config_check('beta3c', 0, self.fiber_config)
        # self.beta3c = 10 ** 6 * (self.S - self.beta2c * (4 *self.constant_pi * C / self.wavelength ** 3) / 1000) / (2 *self.constant_pi * C / self.wavelength ** 2) ** 2 # (ps^3/km)
        # Offset of each WDM channel from the band center: channels 1..N are
        # spaced by channel_space and centered on 0 (units presumably GHz — TODO confirm)
        self.df = ((np.arange(1, self.channel_num + 1) - (self.channel_num + 1) / 2) * self.channel_space)
        # For WDM, the receiver-side CDC depends on each channel's offset df,
        # so the dispersion coefficients are re-expanded around
        # omega = 2*pi*df (Taylor series of beta; the powers of 10 are
        # presumably unit conversions):
        #   beta0 = beta2c/2 * omega^2 + beta3c/6 * omega^3
        #   beta1 = beta2c * omega + beta3c/2 * omega^2
        #   beta2 = beta2c + beta3c * omega
        self.beta0 = beta2c / 2 * (2 * self.constant_pi * self.df)**2 * 10**-6 + beta3c / 6 * (2 * self.constant_pi * self.df)**3 * 10**-9
        self.beta1 = beta2c * 2 * self.constant_pi * self.df * 10**-3 + beta3c * (2 * self.constant_pi * self.df)**2 * 10**-6 / 2      
        self.beta2 = beta2c + 2 * self.constant_pi * self.df * beta3c * 10**-3
        self.beta3 = beta3c

        # Check CDC's config parameters
        config = self.cd_com_config
        # CDC operates at sam_per_sym samples per symbol (default 2); the
        # value is also kept on the instance as block_upsam
        self.block_upsam = sam_per_sym = self._config_check('sam_per_sym', 2, config)    
        sam_rate = self.sym_rate * sam_per_sym
        # FFT length spans all frames: frame_num * frame_sym_num symbols
        fft_num = self.frame_num * self.frame_sym_num * sam_per_sym
        beta = [self.beta0, self.beta1, self.beta2, self.beta3]
        self.cdcom_obj = cdc_design.CDC(mode = config['type'],\
            beta = beta, sam_per_sym = sam_per_sym,\
                sam_rate = sam_rate, fft_num = fft_num,\
                    data_mode = 'tensor', device = self.device)
        cdc_length = self._config_check('cdc_length', 'auto', config['args']) 
        self._config_check('cut_idx', self.cut_idx, config['args'])

        # calculate cdc_length
        if cdc_length == 'auto':
            # Resolve 'auto' to the length after pre_cd_len and write the
            # number back into the args passed to init()
            cdc_length = self.total_len - pre_cd_len
            config['args']['cdc_length'] = cdc_length
        else:
            # NOTE(review): 'cdc_length' is already present in config['args']
            # here, so this presumably just returns the user's value again (a
            # no-op) — confirm against _config_check
            cdc_length = self._config_check('cdc_length', self.total_len, config['args'])
        # initialization
        self.cdcom_obj.init(**config['args'])

    def __synchronization__(self):
        r"""
            Frame-synchronizer initialization.
            Checks the entries of ``self.synchron_config`` (with the defaults
            listed below), builds ``self.synchron_obj`` from it and calls its
            ``init`` with the ``'args'`` sub-dict.
        """
        sync_cfg = self.synchron_config
        sync_defaults = (
            ('mode', '4x4'),                        # 4-way correlation
            ('frame_num', self.frame_num),
            ('frame_size', self.frame_sym_num),
            ('data_mode', 'tensor'),
            ('device', self.device),
        )
        for key, default in sync_defaults:
            self._config_check(key, default, sync_cfg)

        self.synchron_obj = synchron_design.synchron(**sync_cfg)
        # initialization
        self.synchron_obj.init(**sync_cfg['args'])
     
    def __mimo__(self):
        r"""
            Adaptive MIMO equalizer initialization.
            Checks the top-level entries of ``self.mimo_config`` and of its
            ``'args'`` sub-dict (with the defaults listed below), builds
            ``self.mimo_obj`` and calls its ``init`` with the args.
        """
        mimo_cfg = self.mimo_config
        obj_defaults = (
            ('mode', 'TD_2x2'),             # time domain, 2 in 2 out
            ('lr_optim', 'constant'),
            ('half_taps_num', 15),
            ('out_sym_num', 1),
            ('upsam', 2),
            ('block_size', 1),
            ('data_mode', 'numpy'),
            ('device', self.device),
            ('infor_print', self.infor_print),
        )
        for key, default in obj_defaults:
            self._config_check(key, default, mimo_cfg)
        self.mimo_obj = adaptive_filter_design.Vanilla(**mimo_cfg)

        train_args = mimo_cfg['args']
        in_upsam = self._config_check('upsam', 2, train_args)
        # Total input samples over all frames at in_upsam samples per symbol
        total_sam = self.frame_num * self.frame_sym_num * in_upsam
        train_defaults = (
            ('algo_type', 'cma'),           # cma, mma, dd_lms
            ('cma_pretrain', 1),
            ('pre_train_iter', 2),
            ('train_num', 1024 * 80),
            ('train_epoch', 2),
            ('lr', 5.0e-4),
            ('sam_num', total_sam),
            ('tap_init_mode', 1),
            ('tap_init_value', 1.0),
        )
        for key, default in train_defaults:
            self._config_check(key, default, train_args)
        # initialization
        self.mimo_obj.init(**train_args)

    def __cpe__(self):
        r"""
            CPE initialization function
            check CPE's config parameters

            Checks the entries of ``self.cpe_config`` (with the defaults
            listed below) and builds ``self.cpe_obj`` from the config.
        """
        self._config_check('mode', 'bps', self.cpe_config)   # bps or coarse
        self._config_check('sym_rate', self.sym_rate, self.cpe_config)   
        self._config_check('window_size', 61, self.cpe_config)   
        self._config_check('block_num', 6, self.cpe_config)   
        self._config_check('parallelism', 1, self.cpe_config)   
        self._config_check('infor_print', self.infor_print, self.cpe_config)   
        self._config_check('data_mode', 'tensor', self.cpe_config)   
        self._config_check('device', self.device, self.cpe_config)   
        self.cpe_obj = cpe_design.CPE(**self.cpe_config)
        
        # NOTE(review): cpe_args is read but not used in the visible code.
        # The LPF, CDC, MIMO and synchronization setups each end with
        # obj.init(**args); confirm whether self.cpe_obj.init(**cpe_args)
        # is missing here or whether CPE needs no init step.
        cpe_args = self.cpe_config['args']
 
