import numpy as np
import torch
from IFTS.fiber_simulation.base.base_para import Base_Para

class Simu_Para(Base_Para):
    """Parameter container for fiber experiments / simulations.

    All settings are read from ``simu_configs['Simu_Para']`` through
    ``Base_Para._check(name, default, configs)``, which stores each value as
    the attribute ``self.<name>`` (later calls read e.g. ``self.channel_num``
    and ``self.sym_rate`` right after checking them). Derived quantities are
    computed from attributes checked earlier, so the order of the ``_check``
    calls in ``__init__`` matters.
    """

    # Physical / mathematical constants
    constant_C = 299792458                # speed of light (m/s)
    constant_h = 6.626068e-34             # Planck's constant (J*s)
    constant_pi = np.pi

    def __init__(self, rand_seed, simu_configs):
        """Load/derive all parameters and write the resolved configuration.

        Side effects: sets torch's global default floating dtype to float64
        and calls ``self.save_configs`` with
        ``self.result_path + 'Simulation_Para.yaml'``.

        Args:
            rand_seed: random seed, stored unchanged as ``self.rand_seed``.
            simu_configs: mapping whose ``'Simu_Para'`` entry holds the
                configuration for this class.
        """
        configs = simu_configs['Simu_Para']
        self.rand_seed = rand_seed
        # Path (no default given; presumably expected in configs)
        self._check("code_path", configs=configs)
        self._check("tx_data_path", configs=configs)
        self._check("rx_data_path", configs=configs)
        # AWG and oscilloscope parameters
        self._check("awg_memory_depth", 2**18, configs)    # 2**18 for Fujitsu
        self._check("tx_sam_rate", 80, configs)            # 80 Gsam/sec for Fujitsu
        self._check("rx_sam_rate", 100, configs)           # 100 Gsam/sec for Tec
        # data load information
        self._check("tx_load_data", 0, configs)
        self._check("rx_load_data", 0, configs)
        # experiment information
        self._check("sym_rate", 40, configs)   # GHz
        self._check("bits_per_sym", 4, configs)
        self._check("nPol", 2, configs)
        self._check("sig_power_dbm", 0, configs)
        self._check("channel_num", 0, configs)
        # control variables for sig, tx, rx
        # sig
        self._check("modulation", 1, configs)
        self._check("demodulation", 1, configs)
        # parameters for tx
        self._check("pulse_shaping", 1, configs)
        # parameters for channel
        self._check("dac", 0, configs)
        self._check("transmitter", 0, configs)
        self._check("wss", 1, configs)
        self._check("channel_type", 0, configs)  # 0: Experiments, 1: Simu_fiber, 2: Simu_AWGN
        self._check("span_num", 0, configs)
        self._check("channel_space", 50, configs)
        self._check("load_len", 1, configs)
        self._check("span_len", 80, configs)
        self._check("total_len", 800, configs)
        # Index of the channel under test (middle channel by default).
        # int() truncates toward zero, so channel_num=0 gives 0, not -1.
        self._check("cut_idx", int((self.channel_num - 1) / 2), configs)
        self._check("receiver", 1, configs)
        self._check("adc", 0, configs)
        # parameters for rx
        self._check("lpf", 1, configs)
        self._check("cdcom", 1, configs)
        self._check("mimo", 1, configs)
        self._check("synchronization", 1, configs)
        self._check("cpe", 1, configs)
        # path
        self._check("result_path", configs=configs)
        self._check("figure_path", configs=configs)
        self._check("save_data_path", configs=configs)

        # Parameters which need to be calculated
        self._check("ch_sam_rate", self.channel_num * self.sym_rate * 4, configs)
        self._check("ch_fft_num", int(self.ch_sam_rate * self.awg_memory_depth / self.tx_sam_rate), configs)
        self._check("rx_fft_num", int(self.rx_sam_rate * self.awg_memory_depth / self.tx_sam_rate), configs)
        self._check("class_num", 2 ** self.bits_per_sym, configs)
        self._check("up_sam", self.tx_sam_rate / self.sym_rate, configs)
        # The longest sequence the AWG sends is 2^18 samples; the actual number
        # of symbols is 2^15 at 10G, 2^16 at 20G and 2^17 at 40G.
        bit_num_per_pol = (int(np.ceil(self.awg_memory_depth / self.up_sam))) * self.bits_per_sym
        self._check("bit_num_per_pol", bit_num_per_pol, configs)
        self._check("modulation_list", configs=configs)
        # modulation_list is indexed by the string form of bits_per_sym.
        self._check("modulation_type", self.modulation_list[str(self.bits_per_sym)], configs)

        # parameters for simulation
        self._check("data_mode", 'fine_tune', configs)  # numpy, tensor, fine_tune
        self._check("caclu_with_gpu", 1, configs)      # > 0: use the GPU when CUDA is available
        # "device" is only checked when the GPU path is taken; this relies on
        # _check returning the resolved value (NOTE(review): confirm in Base_Para).
        self.device = torch.device(self._check("device", "cuda:3", configs) if (torch.cuda.is_available() and self.caclu_with_gpu > 0) else "cpu")
        # NOTE(review): 4 flags each, but only tx / ch / rx are named; confirm the 4th.
        self._check("fig_plot_arr", [1, 0, 1, 1], configs)     # tx ch rx
        self._check("infor_print_arr", [1, 0, 1, 1], configs)  # tx ch rx
        self._check("save_data_arr", [0, 0, 0, 0], configs)    # tx ch rx
        # One slot per polarisation plus one extra.
        # NOTE(review): unlike every other _check call, configs is not passed
        # here; confirm whether q_factor is meant to be configurable.
        self._check("q_factor", np.zeros(self.nPol + 1))
        # Equivalent to the deprecated torch.set_default_tensor_type(torch.DoubleTensor):
        # make float64 the global default floating dtype.
        torch.set_default_dtype(torch.float64)
        self.save_configs(self.result_path + 'Simulation_Para.yaml', configs)

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array if it is a torch tensor, else unchanged.

        ``detach()`` is needed because ``Tensor.numpy()`` raises for tensors
        that require grad; ``cpu()`` copies device tensors to host memory.
        """
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def save_data_func(self, x, name, **kwargs):
        """Save ``x`` to ``<path><name>.npz`` under the key ``data``.

        Args:
            x: array-like or torch tensor (any subclass, on any device).
            name: file stem appended to the directory path.
            path (kwarg): directory prefix; defaults to ``self.save_data_path``.
        """
        path = kwargs.get('path', self.save_data_path)
        np.savez_compressed(path + name + '.npz', data=self._to_numpy(x))

    def save_data_func_npol(self, x, name, **kwargs):
        """Save the two polarisations ``x[0]`` / ``x[1]`` to ``<path><name>.npz``.

        Each entry is converted independently, so a mix of NumPy arrays and
        (possibly GPU / grad-requiring) tensors is handled.

        Args:
            x: indexable whose first two items are the X and Y data.
            name: file stem appended to the directory path.
            path (kwarg): directory prefix; defaults to ``self.save_data_path``.
        """
        path = kwargs.get('path', self.save_data_path)
        np.savez_compressed(path + name + '.npz',
                            data_x=self._to_numpy(x[0]),
                            data_y=self._to_numpy(x[1]))

    def print_para(self, information, **kwargs):
        """Print every instance attribute as ``name:value`` to ``information``.

        Args:
            information: file-like object passed as ``print(file=...)``.
        """
        print('Data Parameters\n-----------------------------------------------', file=information)
        print('\n'.join(['%s:%s' % item for item in self.__dict__.items()]), file=information)

    @staticmethod
    def _csv_channel_names(base_name):
        """Return ``(base_name + '_Ch1.csv', ..., base_name + '_Ch4.csv')``."""
        return tuple(base_name + '_Ch' + str(ch) + '.csv' for ch in range(1, 5))

    def get_data_name(self, csv=False, idx=None):
        """Build data file names; also stores them in ``self.data_name`` / ``self.tx_data_name``.

        NOTE(review): ``is_ase``, ``is_b2b`` and ``is_bidi`` are not set in
        ``__init__``; presumably assigned by the caller or Base_Para. Confirm.

        Args:
            csv: if true, return the four per-channel CSV names
                (``_Ch1`` .. ``_Ch4``) instead of a single ``.npz`` name.
            idx: optional index appended to the data name as ``_<idx>``
                (ignored in the ASE case).

        Returns:
            A tuple whose content depends on the branch:
            - ASE, csv: the 4 CSV names.
            - ASE, not csv: ``('ASE.npz', <tx_data_name>.mat)``.
            - otherwise, csv: the 4 CSV names, ``<tx_data_name>.mat`` and
              ``tx_data_path + modulation_type + '_map.mat'``.
            - otherwise, not csv, idx None: ``<data_name>.npz``,
              ``<tx_data_name>.mat``, ``modulation_type + '.mat'``.
            - otherwise, not csv, idx given: ``<data_name>_<idx>.npz``,
              ``tx_data_path + 'tx_data.npz'``,
              ``tx_data_path + modulation_type + '_map.npz'``.
        """
        # Set before branching: the ASE branch also returns it, and previously
        # raised AttributeError if called before any non-ASE call.
        self.tx_data_name = self.tx_data_path + '/tx_information/' + self.modulation_type
        if self.is_ase:
            self.data_name = 'ASE'
            if csv:
                return self._csv_channel_names(self.data_name)
            return self.data_name + '.npz', self.tx_data_name + '.mat'

        self.data_name = str(self.sym_rate) + 'G' + self.modulation_type
        if self.is_b2b == 0:
            self.data_name = self.data_name + '_span' + str(self.span_num)
        if self.is_bidi:
            self.data_name = self.data_name + '_BIDI'

        if idx is None:
            if csv:
                return self._csv_channel_names(self.data_name) + (
                    self.tx_data_name + '.mat',
                    self.tx_data_path + self.modulation_type + '_map.mat')
            return (self.data_name + '.npz', self.tx_data_name + '.mat',
                    self.modulation_type + '.mat')

        if csv:
            self.data_name = self.data_name + '_' + str(idx)
            return self._csv_channel_names(self.data_name) + (
                self.tx_data_name + '.mat',
                self.tx_data_path + self.modulation_type + '_map.mat')
        # In this branch the stored data_name keeps its '.npz' suffix.
        self.data_name = self.data_name + '_' + str(idx) + '.npz'
        return (self.data_name, self.tx_data_path + 'tx_data.npz',
                self.tx_data_path + self.modulation_type + '_map.npz')