import numpy as np
import torch
import warnings
from ..modul_para.signal_para import Sig_Para
from ...fiber_simulation.comm_tools.filter_design import filter_design

class Tx_Para(Sig_Para):
    """Transmitter-side parameter set layered on top of ``Sig_Para``.

    Construction reads the ``'Tx_Para'`` section of ``simu_configs``,
    derives the sampling interval from the sampling rate, selects the
    transmitter data mode, optionally prepares a pulse-shaping filter and
    finally passes the section to ``save_configs``.

    NOTE(review): ``_check``, ``_config_check``, ``save_configs`` and the
    attributes ``result_path``, ``sym_rate``, ``device`` and ``data_mode``
    are not defined in this class; presumably they are provided by
    ``Sig_Para`` -- confirm there.
    """

    def __init__(self, rand_seed, simu_configs):
        """Populate the transmitter parameters.

        Args:
            rand_seed: Forwarded unchanged to ``Sig_Para.__init__``.
            simu_configs: Mapping that must contain a ``'Tx_Para'`` entry;
                also forwarded unchanged to ``Sig_Para.__init__``.
        """
        super().__init__(rand_seed, simu_configs)
        tx_cfg = simu_configs['Tx_Para']

        self.sam_rate = self._check("tx_sam_rate")  # GHz
        self.dt = 1000 / self.sam_rate  # ps
        self.fft_num = self._check('awg_memory_depth', 2**18, tx_cfg)

        # Each flag array has three entries by default; this class reads the
        # one at index 1 (presumably the transmitter's slot -- TODO confirm).
        slot = 1
        self.infor_print = self._check('infor_print_arr', np.ones(3))[slot]
        self.fig_plot = self._check('fig_plot_arr', np.ones(3))[slot]
        self.save_data = self._check('save_data_arr', np.zeros(3))[slot]

        # 'fine_tune' forces tensor data; any other mode is taken over as-is.
        # Possible values: numpy, tensor
        mode = self._check('data_mode', 'fine_tune')
        self.tx_data_mode = 'tensor' if mode == 'fine_tune' else self.data_mode

        shaping_enabled = self._check('pulse_shaping', 1, tx_cfg)
        if not shaping_enabled:
            warnings.warn("No pulse shaping !")
        else:
            # The return value is not used; __pulse_shaping__ reads
            # self.pulse_shaping_config, which presumably gets set by this
            # call -- verify in Sig_Para._check.
            self._check('pulse_shaping_config', configs=tx_cfg)
            self.__pulse_shaping__()

        self.save_configs(self.result_path + 'Tx_Para.yaml', tx_cfg)

    def __pulse_shaping__(self):
        """Create and initialise the pulse-shaping filter object.

        Builds ``self.ps_obj`` from ``self.pulse_shaping_config``, runs the
        filter arguments through ``_config_check`` with their fall-back
        values, and for ``'rrc'`` / ``'rc'`` filters also stores
        ``self.signal_bw``. ``self.signal_bw`` is not assigned here for any
        other filter type.
        """
        # Filter types listed by the original author:
        # brickwall, bessel, butter, gaussian, rc, rrc
        shaping_cfg = self.pulse_shaping_config
        self.ps_obj = filter_design.Digital_filter_para_freq(
            filter_type=shaping_cfg['type'],
            data_mode=self.tx_data_mode,
            device=self.device,
        )

        filter_args = shaping_cfg['args']
        # default 2 up-sample rate
        upsample = self._config_check('upsam', 2, filter_args)
        self._config_check('sam_rate', self.sym_rate * upsample, filter_args)
        self._config_check('cut_off', self.sym_rate, filter_args)

        if self.ps_obj.filter_type in ('rrc', 'rc'):
            self._config_check('beta', 0.1, filter_args)
            # Read beta back from the arguments after the check above.
            roll_off = filter_args['beta']
            self.signal_bw = self.sym_rate * (1 + roll_off) / 2

        self.ps_obj.init(**filter_args)