from pyexpat import model
import torch
import numpy as np
from scipy.stats import maxwell
from IFTS.fiber_simulation.utils.show_progress import progress_info_return
from IFTS.fiber_simulation.base.base_optics import Optics_Base_Module
from IFTS.fiber_simulation.utils.define_freq import calcu_f
from IFTS.fiber_simulation.comm_tools.normalization import norm_dual_pol
import sys
from IFTS.fiber_simulation.channel.channel_trans.fiber.nn.models.model_GAN import Generator
class NN(Optics_Base_Module):
    """
    Fiber-channel transmission simulated with a neural network (NN).

    Replacing the traditional split-step Fourier method (SSFM) with a neural
    network accelerates the simulation of the fiber channel. The methods of
    this class correspond to the individual steps of NN-based channel
    simulation: preprocessing the NN input, running the NN for nonlinear
    modeling, and post-processing its output. Each call to ``forward_pass``
    transmits the signal over one span.
    """

    # Largest number of output symbols fed to the NN in one forward call;
    # longer signals are split into blocks of this size to bound memory use.
    _MAX_BLOCK_SYMBOLS = 2 ** 13
    # NN model categories this class knows how to build and run.
    _SUPPORTED_MODELS = ('GAN',)

    def __init__(self, len_arr, beta0, beta1, beta2, beta3, pmd, alpha_loss, gamma,
                 fft_num, sam_rate, nn_model, model_path, infor_print=1, data_mode='numpy',
                 *args, **kwargs):
        """
        Store the basic parameters and load the NN model.

        Initialization happens in two steps: this method stores basic
        parameters (sample rate, GVD coefficients, ...) and loads the NN model;
        ``init`` then completes the dispersion / PMD configuration.

        Parameters
        ----------
        len_arr : ndarray
            The length of each span (km).
        beta0, beta1, beta2, beta3 : float
            Dispersion coefficients of order zero to three; ``beta_k`` multiplies
            ``w**k / k!`` in the phase factor built by ``init``
            (units presumably ps^k/km).
        pmd : int, {0, 1}
            Whether to take PMD into account.
        alpha_loss : float
            The fiber loss parameter (Np/km).
        gamma : float
            The nonlinear coefficient (km^-1.W^-1).
        fft_num : int
            Number of FFT points, normally the length of the input signal.
        sam_rate : int
            The channel sample rate, which equals channel numbers * symbol rate * 4.
        nn_model : str
            The NN model category used for modeling; only 'GAN' is supported.
        model_path : str
            Path prefix of the stored model; 'args.pth' and 'checkpoint.pth'
            are appended to it verbatim.
        infor_print : int, {0, 1}, optional
            Whether to print algorithm progress. Default: 1.
        data_mode : str, {'numpy', 'tensor'}, optional
            The data type used in operation. Default: 'numpy'.
        **kwargs : dict
            Optional keys read here: 'device' (default 'cpu'),
            'Constant_pi' (default np.pi), 'package_type' (default 0) and
            'package_num' (default 0).

        Raises
        ------
        ValueError
            If ``nn_model`` is not a supported model category.
        """
        super().__init__()
        self.len_arr = len_arr
        self.beta0 = beta0
        self.beta1 = beta1
        self.beta2 = beta2
        self.beta3 = beta3
        self.pmd = pmd
        self.alpha_loss = alpha_loss
        self.gamma = gamma
        self.sam_rate = sam_rate
        self.fft_num = fft_num
        self.nn_model = nn_model
        self.model_path = model_path
        self.infor_print = infor_print
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.Constant_pi = kwargs.get('Constant_pi', np.pi)
        self.package_type = kwargs.get('package_type', 0)
        self.package_num = kwargs.get('package_num', 0)
        self.dtype = torch.float32
        self.__nn_model_init__()

    def init(self, pmd_config, *args, **kwargs):
        """
        Complete the configuration: frequency grid, dispersion phase factor and PMD.

        Parameters
        ----------
        pmd_config : dict
            PMD settings; only used when ``self.pmd`` is truthy. Every key is
            also copied onto the object as an attribute.
        **kwargs : dict
            Extra attributes copied verbatim onto the object.
        """
        # Angular frequency grid (sam_rate is scaled by 1e-3 before calcu_f).
        self.w = 2 * self.Constant_pi * calcu_f(self.fft_num, self.sam_rate * 1e-3)
        # Taylor expansion of the propagation constant: sum_k beta_k * w**k / k!
        phase_factor_freq = (self.beta0
                             + self.beta1 * self.w
                             + self.beta2 * self.w ** 2 / 2
                             + self.beta3 * self.w ** 3 / 6)
        self.phase_factor_freq = self.data_mode_convert(phase_factor_freq)

        if self.pmd:
            self.pmd_config = pmd_config
            self.dgd_manual = self._config_check('dgd_manual', 1, pmd_config)
            self.psp_manual = self._config_check('psp_manual', 1, pmd_config)
            self.pmd_coeff_random = self._config_check('pmd_coeff_random', 0, pmd_config)
            self.pmd_dz_random = self._config_check('pmd_dz_random', 0, pmd_config)
            # Copy every PMD setting (e.g. pmd_coeff) onto the object.
            self.__dict__.update(pmd_config)

            if self.dgd_manual:
                # Random SOP angle and phase changes with a manually set total DGD.
                # Store the result like the sibling _config_check calls above;
                # previously it was discarded, leaving dgd_total unset whenever
                # the key was absent from pmd_config.
                self.dgd_total = self._config_check('dgd_total', 0.2, pmd_config)
            else:
                # RMS DGD derived from the PMD coefficient.
                self.dgd_rms = np.sqrt(3 * self.Constant_pi / 8 * self.pmd_coeff ** 2)

            if self.psp_manual:
                # Fixed SOP angle / phase change.
                self.phi = self._config_check('phi', self.Constant_pi / 4, pmd_config)

        self.__dict__.update(kwargs)

    def forward_pass(self, x, span_len):
        """
        Transmit the signal over one span.

        Parameters
        ----------
        x : list
            The signal, one entry per polarization (one or two entries).
        span_len : float
            Length of this span; stored on ``self.span_len``.

        Returns
        -------
        The transmitted signal.

        Raises
        ------
        ValueError
            If ``x`` does not hold exactly one or two polarizations (previously
            the input was silently returned untransmitted).
        """
        self.span_len = span_len
        self.nPol = len(x)
        if self.nPol == 1:
            # Single polarization.
            # NOTE(review): __nn_scalar__ is not defined in this class; it must be
            # provided by the base class, otherwise this path raises AttributeError.
            return self.__nn_scalar__(x)
        if self.nPol == 2:
            # Dual polarization.
            return self.__nn_matrix__(x)
        raise ValueError(
            'NN.forward_pass expects 1 or 2 polarizations, got {}'.format(self.nPol))

    def __com_to_real__(self, x):
        '''
        Convert complex signals to real two-column signals.

        Each complex signal of n points becomes an (n, 2) array whose column 0
        holds the real part and column 1 the imaginary part.

        Parameters
        ----------
        x : list of array_like
            Complex signals (ndarrays in 'numpy' mode, tensors in 'tensor' mode).

        Returns
        -------
        out : list
            One (n, 2) real array/tensor per input signal.

        Raises
        ------
        ValueError
            If ``self.data_mode`` is neither 'numpy' nor 'tensor'.
        '''
        if self.data_mode == 'numpy':
            # `.real` / `.imag` are attributes of ndarrays, not methods.
            return [np.concatenate((np.real(sig).reshape(-1, 1),
                                    np.imag(sig).reshape(-1, 1)), 1)
                    for sig in x]
        if self.data_mode == 'tensor':
            return [torch.cat((torch.real(sig).reshape(-1, 1),
                               torch.imag(sig).reshape(-1, 1)), 1)
                    for sig in x]
        raise ValueError("Unsupported data_mode {!r}; expected 'numpy' or 'tensor'".format(self.data_mode))

    def __real_to_com__(self, x):
        '''
        Convert real two-column signals back to complex signals.

        Parameters
        ----------
        x : list of array_like
            Real (n, 2) signals; column 0 is the real part, column 1 the imaginary part.

        Returns
        -------
        out : list
            One complex signal of n points per input signal.
        '''
        return [sig[:, 0].reshape(-1) + 1j * sig[:, 1].reshape(-1) for sig in x]

    def __data_process__(self, Inputx, Inputy):
        '''
        Gather the sliding windows of both polarizations into the NN input.

        Parameters
        ----------
        Inputx : tensor
            Padded x-polarization block of shape (samples, 2).
        Inputy : tensor
            Padded y-polarization block of shape (samples, 2).

        Returns
        -------
        condition_data : tensor
            For the GAN model, a 2-D tensor reshaped to (-1, self.indim);
            otherwise the windowed data of shape (n_windows, inSamNum, 4).
        '''
        # Indexing with the (n_windows, inSamNum) index matrix gives
        # (n_windows, inSamNum, 2) per polarization; join x and y on the last axis.
        condition_data = torch.cat([Inputx[self.sliding_window, :],
                                    Inputy[self.sliding_window, :]], 2)
        if self.nn_model == 'GAN':
            condition_data = condition_data.reshape(-1, self.indim)
        return condition_data

    def _nn_block_layout(self, num):
        '''
        Split a signal of ``num`` samples into NN input blocks.

        Returns
        -------
        tuple
            (output symbols per block, samples per block, number of blocks,
            input samples per block including the ahead/behind context).
        '''
        if num > self._MAX_BLOCK_SYMBOLS * self.outSamNum:
            # Input in blocks to avoid too much memory due to a too long signal.
            gen_out = self._MAX_BLOCK_SYMBOLS                 # unit: symbols
            gen_sam = gen_out * self.outSamNum                # unit: samples
            # NOTE(review): int() floors, so samples beyond batch_num * gen_sam
            # are not transmitted when num is not a multiple of gen_sam.
            batch_num = int(num / gen_sam)
        else:
            # Short signal: feed it to the NN in a single block.
            gen_out = int(num / self.outSamNum)
            gen_sam = num
            batch_num = 1
        in_dim = gen_sam + self.aheadSam + self.behindSam
        return gen_out, gen_sam, batch_num, in_dim

    def __trans_nn_matrix__(self, Input, model):
        '''
        Transmit a dual-polarization real signal through the NN model.

        The input is two real (N, 2) signals. After zero-padding with
        ``aheadSam`` / ``behindSam`` samples, the signal is fed to the model.
        Signals longer than ``_MAX_BLOCK_SYMBOLS`` symbols are fed in blocks to
        limit memory use. The output has the same format as the input: two
        real (M, 2) signals, one per polarization.

        Parameters
        ----------
        Input : list
            The [x, y] polarization signals; real tensors of shape (N, 2).
        model : callable
            The NN model, called as ``model(condition_data, noise)``.

        Returns
        -------
        out : list
            [x, y] real tensors of shape (M, 2) with columns (I, Q).

        Raises
        ------
        ValueError
            If ``self.nn_model`` is not a supported model category.
        '''
        if self.nn_model not in self._SUPPORTED_MODELS:
            raise ValueError('Unsupported nn_model {!r}; supported: {}'.format(
                self.nn_model, self._SUPPORTED_MODELS))

        # Signal power normalization.
        Inputx = self.power_scale * Input[0]
        Inputy = self.power_scale * Input[1]
        # Zero-pad so the first and last symbols have full ahead/behind context.
        zeros_ahead = torch.zeros(self.aheadSam, 2, dtype=self.dtype, device=self.device)
        zeros_behind = torch.zeros(self.behindSam, 2, dtype=self.dtype, device=self.device)
        testx = torch.cat((zeros_ahead, Inputx, zeros_behind), 0)
        testy = torch.cat((zeros_ahead, Inputy, zeros_behind), 0)

        gen_out, gen_sam, batch_num, in_dim = self._nn_block_layout(Inputx.shape[0])

        # Collect block outputs and concatenate once (avoids repeated torch.cat).
        # NOTE(review): self.noise is drawn once at construction, so the same
        # noise realization is reused on every call — confirm this is intended.
        blocks = []
        with torch.no_grad():
            for i in range(batch_num):
                start = i * gen_sam
                condition_data = self.__data_process__(testx[start:start + in_dim, :],
                                                       testy[start:start + in_dim, :])
                blocks.append(model(condition_data, self.noise[i]).detach())
        gen_data = torch.cat(blocks, 0)

        # Rearrange the output into rows of [x_I, x_Q, y_I, y_Q].
        gen_data = gen_data.reshape(int(batch_num * gen_out), self.outdim)
        gen_data = gen_data.reshape(int(batch_num * gen_out * self.outdim / 4), 4)
        rx_sigx = gen_data[:, 0:2].contiguous()
        rx_sigy = gen_data[:, 2:4].contiguous()
        return [rx_sigx, rx_sigy]

    def __nn_matrix__(self, x):
        '''
        NN modeling of a dual-polarization signal.

        Parameters
        ----------
        x : list
            The [x, y] complex polarization signals. The list is not modified.

        Returns
        -------
        sigout
            The output of ``norm_dual_pol`` applied to the two complex output
            polarizations with power ``self.sig_power_w_all``.
        '''
        # NOTE(review): `.type(torch.complex64)` and the torch ops downstream
        # require tensors; confirm data_mode_convert returns a tensor in 'numpy' mode too.
        sig = [self.data_mode_convert(pol, self.data_mode).type(torch.complex64) for pol in x]
        sig = self.__com_to_real__(sig)
        sig = self.__trans_nn_matrix__(sig, self.model)
        if self.nn_model == 'GAN':
            sig = self.__real_to_com__(sig)
        return norm_dual_pol(sig[0], sig[1], self.sig_power_w_all)

    def __data_process_init__(self):
        '''
        Precompute the block layout, sliding-window indices and GAN noise.

        Sets ``genOutNumperBlock``, ``genSamNumperBlock``, ``batch_num``,
        ``inDimperBlock``, ``sliding_window`` (integer index tensor of shape
        (n_windows, inSamNum)) and, for the GAN model, ``noise`` of shape
        (batch_num, genOutNumperBlock, noise_dim). Returns nothing.
        '''
        (self.genOutNumperBlock, self.genSamNumperBlock,
         self.batch_num, self.inDimperBlock) = self._nn_block_layout(self.fft_num)

        # Window k covers samples [k * outSamNum, k * outSamNum + inSamNum).
        window = torch.arange(self.inSamNum, device=self.device)
        first_index = torch.arange(0, self.inDimperBlock - self.inSamNum + self.outSamNum,
                                   self.outSamNum, device=self.device).reshape(-1, 1)
        self.sliding_window = first_index + window  # broadcast to (n_windows, inSamNum)

        if self.nn_model == 'GAN':
            # Drawn on the default generator and then moved, as before.
            self.noise = torch.randn(self.batch_num, self.genOutNumperBlock, self.noise_dim,
                                     dtype=self.dtype).to(self.device)

    def __nn_model_init__(self):
        '''
        Load the trained NN model and its hyper-parameters.

        Reads ``model_path + 'args.pth'`` (hyper-parameters) and
        ``model_path + 'checkpoint.pth'`` (weights), builds the model in eval
        mode, copies the model hyper-parameters onto the object and
        precomputes the data-processing layout.

        Raises
        ------
        ValueError
            If ``self.nn_model`` is not a supported model category.
        '''
        if self.nn_model not in self._SUPPORTED_MODELS:
            raise ValueError('Unsupported nn_model {!r}; supported: {}'.format(
                self.nn_model, self._SUPPORTED_MODELS))

        if self.package_type:
            # NOTE(review): package_num indexes sys.path, whose entries shift if
            # other code (including earlier NN instances) inserted at position 0.
            nn_dir = sys.path[self.package_num] + '/IFTS/fiber_simulation/channel/channel_trans/fiber/nn'
        else:
            nn_dir = './IFTS/fiber_simulation/channel/channel_trans/fiber/nn'
        # Presumably needed so unpickling the stored args can import the nn
        # model modules — confirm. Insert once so repeated instantiation does not
        # keep growing sys.path.
        if nn_dir not in sys.path:
            sys.path.insert(0, nn_dir)

        # NOTE(security): torch.load unpickles these files and can execute
        # arbitrary code; only load checkpoints from trusted sources. Recent
        # torch versions default to weights_only=True, which may reject `args`.
        dataPara = torch.load(self.model_path + 'args.pth', map_location=self.device)
        datapt = torch.load(self.model_path + 'checkpoint.pth', map_location=self.device)
        args = dataPara['args']
        args.model_para = datapt[args.model_name]
        # GAN is the only supported model (validated above).
        self.model = Generator(args).to(self.device).type(self.dtype)
        self.model.load_state_dict(args.model_para)
        self.model.eval()

        # Power normalization factor and the corresponding WDM signal power.
        self.power_scale = args.power_scale
        self.sig_power_w_all = 1 / np.square(self.power_scale)
        # NN model hyper-parameters.
        self.aheadSam = args.aheadSam
        self.behindSam = args.behindSam
        self.indim = args.indim
        self.outdim = args.outdim
        self.outSamNum = args.outSamNum
        self.inSamNum = args.inSamNum
        self.time_step = args.time_step
        self.noise_dim = args.noise_dim
        self.__data_process_init__()





