import numpy as np
from IFTS.simulation_main.modul_para.simulation_para import Simu_Para
from IFTS.fiber_simulation.sig.modulation import demod_design

class Sig_Para(Simu_Para):
    '''
    Signal-parameter container for the fiber transmission simulation.

    Sig_Para inherits from Simu_Para. On top of the generic simulation parameters
    set up by Simu_Para.__init__, it does the following:
    - reads the 'Sig_Para' section of the simulation configuration;
    - derives signal-level quantities (bit/symbol counts, padding, signal power in W,
      bit rate);
    - records the bit-source settings (bit file vs. random generator);
    - configures modulation and demodulation, including building a
      demod_design.QAMDemod object.
    '''
    def __init__(self, rand_seed, simu_configs) -> None:
        '''
        Read the 'Sig_Para' configuration section and store signal transmission
        parameters as attributes of this object.

        Parameters:
        rand_seed: int
        Random seed, forwarded unchanged to Simu_Para.__init__. It is intended to
        seed random effects such as white noise and EDFA ASE noise; that use
        happens outside this class.
        simu_configs: mapping (parsed YAML configuration)
        Full simulation configuration. This method reads simu_configs['Sig_Para'].
        Returns:
        None. Results are stored as attributes: padding_num, bit_num, sym_num,
        sym_num_wo_padding, sym_num_per_pol, sig_power_w, bit_rate, infor_print,
        fig_plot, save_data and sig_data_mode. modu and demod_obj are set only when
        modulation / demodulation are enabled. The Sig_Para section is saved to
        result_path + 'Sig_Para.yaml'.
        '''
        super().__init__(rand_seed, simu_configs)
        configs = simu_configs['Sig_Para']
        # The return value is discarded, yet the next line reads self.front_sym_num,
        # so Simu_Para._check evidently also stores the value as an attribute.
        self._check("front_sym_num",  2048, configs)
        # Padding symbols (2 x front_sym_num); per the author's note they are added
        # at the end of the symbol sequence.
        self.padding_num    = 2 * self.front_sym_num

        # NOTE(review): unlike front_sym_num above, the keys below are looked up
        # WITHOUT passing `configs`: nPol, bit_num_per_pol, bits_per_sym,
        # sig_power_dbm, sym_rate, the *_arr flags and data_mode. Which section
        # Simu_Para._check reads them from in that case is not visible here;
        # confirm they are meant to come from the Sig_Para section.
        # Total number of bits over all polarizations.
        self.bit_num        = self._check("nPol") * self._check("bit_num_per_pol")
        # Total number of symbols over all polarizations; int() truncates if
        # bit_num is not a multiple of bits_per_sym.
        self.sym_num        = int(self.bit_num / self._check("bits_per_sym"))
        # Symbols per polarization, excluding the padding symbols.
        self.sym_num_wo_padding = int(self.sym_num / self._check("nPol")) - self.padding_num 
        # Symbols per polarization, including padding.
        self.sym_num_per_pol    =  int(self.sym_num / self._check("nPol"))
        # dBm -> W conversion: P[W] = 10 ** ((P[dBm] - 30) / 10).
        self.sig_power_w    = 10 ** ((self._check("sig_power_dbm") - 30 ) / 10)
        # NOTE(review): no nPol factor here. Since sym_num above counts symbols
        # across all polarizations, bits_per_sym is per polarization and this is
        # the per-polarization bit rate. Confirm whether the aggregate rate
        # (x nPol) was intended.
        self.bit_rate       = self._check("sym_rate") * self._check("bits_per_sym")
        # Only element [0] of each flag array is used here. The defaults are
        # 3-element arrays: print and plot on, save off. The other entries
        # presumably belong to other simulation stages (TODO confirm).
        self.infor_print    = self._check('infor_print_arr', np.ones(3))[0]
        self.fig_plot       = self._check('fig_plot_arr', np.ones(3))[0]
        self.save_data      = self._check('save_data_arr', np.zeros(3))[0]

        # In 'fine_tune' mode (also the default when data_mode is absent), signal
        # data is kept as numpy. Otherwise it follows self.data_mode, which is
        # presumably stored by the _check call above.
        if self._check('data_mode', 'fine_tune') == 'fine_tune':
            self.sig_data_mode = 'numpy'
        else:
            self.sig_data_mode = self.data_mode  # numpy, tensor

        # Bit source selection.
        # If bit_load is truthy: record which pre-generated sequence type to use
        # (random_type) and the file to read it from (bit_load_path).
        # Otherwise: record the generator name (default 'MT19937') and its seed.
        # The bits themselves are produced elsewhere.
        if self._check("bit_load",  0, configs):
            """
            bit_num = 2^27 = 134217728
            Merbit: Mersenne twister
            Phibit: philox
            Thrbit: threefry
            Combit: combRecursive
            HybridBit: mix
            """
            self._check("random_type", 'Merbit', configs)
            # NOTE(review): machine-specific absolute default path; set
            # bit_load_path explicitly in the config on other machines.
            self._check("bit_load_path",  '/home/ospan/code/Niu/data/random_num.mat', configs)
        else:
            self._check("random_type", 'MT19937', configs)
            # NOTE(review): if _check stores its result as self.rand_seed, this
            # may overwrite the rand_seed passed to __init__ with -1 when the
            # Sig_Para section has no rand_seed key. Confirm in Simu_Para.
            self._check("rand_seed", -1, configs)

        # NOTE(review): __coding__ is not defined in this class. It is presumably
        # inherited from Simu_Para; if not, enabling coding raises AttributeError.
        if self._check('coding', 0, configs):
            self.__coding__()

        # QAM is the only modulation format assigned here.
        if self._check('modulation', 1, configs):
            self.modu = 'qam'
            self.__modulation__()

        # demod_config is given no explicit default here.
        # __demodulation__ reads self.demod_config, so it must be present in the
        # config, or be set by _check.
        if self._check('demodulation', 1, configs):
            self._check('demod_config', configs = configs)
            self.__demodulation__()


        # Persist the Sig_Para section (result_path is set outside this class).
        self.save_configs(self.result_path + 'Sig_Para.yaml', configs)

    # NOTE(review): __modulation__ / __demodulation__ use dunder-style names,
    # which Python reserves for special methods. They are kept unchanged because
    # subclasses or callers may refer to them.
    def __modulation__(self) -> None:
        '''
        Configure signal-modulation parameters.

        Geometric shaping is currently always disabled (geometric_shaping = 0).
        This version has no configuration switch for it.
        '''
        self.geometric_shaping  = 0 

    def __demodulation__(self) -> None:
        '''
        Fill in defaults for self.demod_config and create the demodulator object
        self.demod_obj = demod_design.QAMDemod(**self.demod_config).

        Defaults passed to _config_check (presumably applied only when the key is
        absent; confirm in Simu_Para):
        mode: 'llr'
        order: self.class_num (set outside this class)
        data_mode: None
        device: self.device (set outside this class)
        '''
        self._config_check('mode', 'llr', self.demod_config)
        self._config_check('order', self.class_num, self.demod_config)
        self._config_check('data_mode', None, self.demod_config)
        self._config_check('device', self.device, self.demod_config)
        self.demod_obj = demod_design.QAMDemod(**self.demod_config)
   