from pyexpat import model
import torch
import numpy as np
from scipy.stats import maxwell
from ......fiber_simulation.utils.show_progress import progress_info_return
from .....base.base_optics import Optics_Base_Module
from .....utils.define_freq import calcu_f
from ......fiber_simulation.comm_tools.normalization import norm_dual_pol
import sys
from ......fiber_simulation.channel.channel_trans.fiber.nn.nn_model.model_BiLSTM import BiLSTM
from ......fiber_simulation.channel.channel_trans.fiber.nn.nn_model.model_GAN import Generator
class NN(Optics_Base_Module):
    """Neural-network fiber channel model.

    A trained network (a ``'BiLSTM'`` model or a ``'GAN'`` generator) is
    loaded from ``model_path`` and applied to a dual-polarization signal.
    For the BiLSTM model, a linear step (GVD, optionally PMD) is applied to
    the network output afterwards. The GAN output is used directly.
    """

    # Maximum number of output symbols fed to the network in one call.
    # Longer signals are processed block by block to bound memory use.
    _MAX_BLOCK_SYMBOLS = 2 ** 13

    # Directories prepended to sys.path before the checkpoint is unpickled.
    # Presumably this is where the training code that defines the pickled
    # `args` object lives. The last entry ends up first on sys.path.
    _TRAIN_CODE_PATHS = (
        '/home/Shiminghui/code/new_train_code',
        '/home/Zengchuyan/code/NN/GAN/new_train_code',
    )

    def __init__(self, len_arr, beta0, beta1, beta2, beta3, pmd, alpha_loss, gamma,
                 fft_num, sam_rate, nn_model, model_path, infor_print=1,
                 data_mode='numpy', *args, **kwargs):
        """Store the fiber and NN-model configuration.

        Args:
            len_arr: Stored as ``self.len_arr``; not used by this class.
            beta0, beta1, beta2, beta3: Polynomial coefficients of the
                frequency-domain phase factor built in :meth:`init`:
                ``beta0 + beta1*w + beta2*w**2/2 + beta3*w**3/6``.
            pmd (bool): Whether the linear step after the BiLSTM model also
                applies polarization-mode dispersion.
            alpha_loss, gamma: Stored; not used by this class.
            fft_num (int): Number of FFT points of the frequency grid.
            sam_rate: Sampling rate. It is scaled by 1e-3 before being passed
                to ``calcu_f``.
            nn_model (str): ``'BiLSTM'`` or ``'GAN'``.
            model_path (str): Path prefix. ``'args.pth'`` and
                ``'checkpoint.pth'`` are appended to it directly, so it should
                end with a path separator.
            infor_print: Stored verbosity flag; not used by this class.
            data_mode (str): ``'numpy'`` or ``'tensor'``.
            **kwargs:
                device: Torch device for the network. Defaults to ``'cpu'``.
                Constant_pi: Value used for pi. Defaults to ``np.pi``.
                cd_length: Length passed to the linear (GVD/PMD) step after
                    the BiLSTM model. Defaults to 80, the previously
                    hard-coded value. Presumably it uses the same unit as
                    ``span_len``; TODO confirm.
        """
        super().__init__()
        self.len_arr = len_arr
        self.beta0 = beta0
        self.beta1 = beta1
        self.beta2 = beta2
        self.beta3 = beta3
        self.pmd = pmd
        self.alpha_loss = alpha_loss
        self.gamma = gamma
        self.sam_rate = sam_rate
        self.fft_num = fft_num
        self.nn_model = nn_model
        self.model_path = model_path
        self.infor_print = infor_print
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.Constant_pi = kwargs.get('Constant_pi', np.pi)
        self.cd_length = kwargs.get('cd_length', 80)

    def init(self, pmd_config, *args, **kwargs):
        """Build the frequency-domain phase factor and load the PMD settings.

        Args:
            pmd_config (dict): PMD configuration. Read only when ``self.pmd``
                is true, in which case every key is copied onto the instance.
            **kwargs: Extra attributes copied onto the instance.
        """
        self.w = 2 * self.Constant_pi * calcu_f(self.fft_num, self.sam_rate * 1e-3)
        phase_factor_freq = (self.beta0
                             + self.beta1 * self.w
                             + self.beta2 * (self.w ** 2) / 2
                             + self.beta3 * (self.w ** 3) / 6)
        self.phase_factor_freq = self.data_mode_convert(phase_factor_freq)
        if self.pmd:
            self.pmd_config = pmd_config
            self.dgd_manual = self._config_check('dgd_manual', 1, pmd_config)
            self.psp_manual = self._config_check('psp_manual', 1, pmd_config)
            self.pmd_coeff_random = self._config_check('pmd_coeff_random', 0, pmd_config)
            self.pmd_dz_random = self._config_check('pmd_dz_random', 0, pmd_config)
            for key in pmd_config:
                self.__dict__[key] = pmd_config[key]

            if self.dgd_manual:
                # Total DGD is given directly. Keep the (possibly defaulted)
                # value, because __pmd_init__ reads self.dgd_total.
                self.dgd_total = self._config_check('dgd_total', 0.2, pmd_config)
            else:
                self.dgd_rms = np.sqrt(3 * self.Constant_pi / 8 * self.pmd_coeff**2)

            if self.psp_manual:
                # Fixed PSP orientation angle.
                self.phi = self._config_check('phi', self.Constant_pi / 4, pmd_config)

        for key in kwargs:
            self.__dict__[key] = kwargs[key]

    def forward_pass(self, x, span_len):
        """Propagate a signal through the NN channel model for one span.

        Args:
            x (list): One (single-polarization) or two (dual-polarization)
                complex signals.
            span_len: Span length. Stored as ``self.span_len`` and used by
                the PMD / linear steps.

        Returns:
            list: The propagated signal(s). ``x`` is returned unchanged when
            it has neither 1 nor 2 entries.
        """
        self.span_len = span_len
        self.nPol = len(x)
        if self.nPol == 1:
            # NOTE(review): __nn_scalar__ is not defined in this class;
            # presumably it is provided by the base class. Confirm.
            x = self.__nn_scalar__(x)
        elif self.nPol == 2:
            x = self.__nn_matrix__(x)
        return x

    def __com_to_real__(self, x):
        """Split each complex signal into a two-column real array.

        Args:
            x (list): Complex signals (numpy arrays or tensors, according to
                ``self.data_mode``).

        Returns:
            list: One ``[n, 2]`` array/tensor per input signal, with column 0
            holding the real part and column 1 the imaginary part. The list is
            empty if ``self.data_mode`` is neither ``'numpy'`` nor ``'tensor'``.
        """
        out = []
        if self.data_mode == 'numpy':
            for sig in x:
                # `ndarray.real` / `.imag` are attributes, not methods.
                real = np.real(sig).reshape(-1, 1)
                imag = np.imag(sig).reshape(-1, 1)
                out.append(np.concatenate((real, imag), 1))
        elif self.data_mode == 'tensor':
            for sig in x:
                real = torch.real(sig).reshape(-1, 1)
                imag = torch.imag(sig).reshape(-1, 1)
                out.append(torch.cat((real, imag), 1))
        return out

    def __real_to_com__(self, x):
        """Recombine two-column real signals into complex signals.

        Args:
            x (list): ``[n, 2]`` arrays/tensors, with the real part in column
                0 and the imaginary part in column 1.

        Returns:
            list: One flat complex signal of length ``n`` per input.
        """
        return [sig[:, 0].reshape(-1) + 1j * sig[:, 1].reshape(-1) for sig in x]

    def __sliding_window__(self, inputx, inputy, length, step):
        """Cut the (dual-polarization) signal into overlapping windows.

        Args:
            inputx (tensor): x-polarization samples, one row per sample
                (``[N, 2]`` I/Q columns in this class).
            inputy (tensor): y-polarization samples shaped like ``inputx``.
                If it is empty, only ``inputx`` is windowed.
            length (int): Window length in samples.
            step (int): Hop between the starts of consecutive windows.

        Returns:
            tensor: ``[num_windows, length, C]`` in the default floating
            dtype, where ``C`` is the number of columns after x and y are
            concatenated (4 for dual-polarization I/Q). Only complete windows
            are returned.
        """
        frames = inputx if len(inputy) == 0 else torch.cat([inputx, inputy], 1)
        if frames.shape[0] < length:
            return torch.zeros((0, length, frames.shape[1]), device=inputx.device)
        # unfold -> [num_windows, C, length]; move the column axis last so
        # window j, row t equals frames[j * step + t].
        windows = frames.unfold(0, length, step).permute(0, 2, 1)
        return windows.to(torch.get_default_dtype()).contiguous()

    def __data_process__(self, Inputx, Inputy):
        """Convert a signal segment into the input layout the network expects.

        Args:
            Inputx (tensor): x-polarization segment (``[N, 2]`` real).
            Inputy (tensor): y-polarization segment (``[N, 2]`` real).

        Returns:
            tensor: Windows reshaped to ``[-1, time_step, indim // time_step]``
            for the BiLSTM model, or ``[-1, indim]`` for the GAN model.
        """
        condition_data = self.__sliding_window__(Inputx, Inputy, self.inSamNum, self.outSamNum)
        if self.nn_model == 'BiLSTM':
            condition_data = condition_data.reshape(-1, self.time_step, int(self.indim / self.time_step))
        elif self.nn_model == 'GAN':
            condition_data = condition_data.reshape(-1, self.indim)
        return condition_data

    def __trans_nn_matrix__(self, Input, model):
        """Run a dual-polarization real signal through the network.

        The input is power-scaled, zero-padded by ``aheadSam`` / ``behindSam``
        samples, and fed to the network. Signals longer than
        ``_MAX_BLOCK_SYMBOLS * outSamNum`` samples are processed in blocks to
        bound memory use. The last block may be shorter, so no samples are
        dropped.

        Args:
            Input (list): ``[x, y]`` real tensors, each ``[N, 2]`` (I, Q).
            model: The loaded network instance.

        Returns:
            list: ``[x, y]`` real tensors, each with two columns (I, Q).

        Raises:
            ValueError: If ``self.nn_model`` is not ``'BiLSTM'`` or ``'GAN'``.
        """
        # Scale to the power level the network was trained on.
        inputx = self.power_scale * Input[0]
        inputy = self.power_scale * Input[1]
        # Zero context before and after, so edge samples get full windows.
        pad_ahead = torch.zeros(self.aheadSam, 2).to(self.device)
        pad_behind = torch.zeros(self.behindSam, 2).to(self.device)
        testx = torch.cat((pad_ahead, inputx, pad_behind), 0)
        testy = torch.cat((pad_ahead, inputy, pad_behind), 0)
        context = self.aheadSam + self.behindSam

        num = inputx.shape[0]
        if num == 0:
            empty = torch.zeros((0, 2), device=self.device)
            return [empty, empty.clone()]

        max_block_samples = self._MAX_BLOCK_SYMBOLS * self.outSamNum
        block_samples = max_block_samples if num > max_block_samples else num

        outputs = []
        with torch.no_grad():
            for start in range(0, num, block_samples):
                stop = min(start + block_samples, num)
                segment = slice(start, stop + context)
                condition_data = self.__data_process__(testx[segment, :], testy[segment, :])
                if self.nn_model == 'GAN':
                    noise = torch.randn(len(condition_data), self.noise_dim).to(self.device)
                    generated = model(condition_data, noise)
                elif self.nn_model == 'BiLSTM':
                    generated = model(condition_data)
                else:
                    raise ValueError(f"nn_model must be 'BiLSTM' or 'GAN', got {self.nn_model!r}")
                outputs.append(generated.reshape(-1, self.outdim))

        # Each output row is a run of (xI, xQ, yI, yQ) quadruples.
        gen_data = torch.cat(outputs, 0).to(torch.get_default_dtype()).reshape(-1, 4)
        rx_sigx = gen_data[:, 0:2].contiguous()
        rx_sigy = gen_data[:, 2:4].contiguous()
        return [rx_sigx, rx_sigy]

    @staticmethod
    def _psp_rotation_matrix(theta, phi):
        """Return the 2x2 complex change-of-basis matrix onto the PSPs.

        Args:
            theta: Rotation angle (radians).
            phi: Ellipticity angle (radians).

        Returns:
            numpy.ndarray: ``[2, 2]`` complex matrix.
        """
        sig0 = np.eye(2)
        sig2 = np.array([[0, 1], [1, 0]])
        sig3i = np.array([[0, 1], [-1, 0]])  # = -j*sig3 = j * [0, -j; j, 0]
        matr_th = np.cos(theta) * sig0 - np.sin(theta) * sig3i + 0.0j
        matr_epsilon = np.cos(phi) * sig0 + 1j * np.sin(phi) * sig2
        return matr_th @ matr_epsilon

    def __cdpmdcm__(self, sigin, length, phase_factor_freq, pmd, data_mode = 'numpy', **kwargs):
        """Apply the linear fiber response (GVD, optionally PMD) in one step.

        This uses the same linear operator as a split-step Fourier step, applied
        once over ``length``. The whole span is modelled as a single PMD
        section. The numpy and tensor paths share one implementation.

        Args:
            sigin (list): ``[x, y]`` complex signals (numpy or tensor,
                according to ``data_mode``).
            length: Propagation length used for the dispersion phase.
            phase_factor_freq: Frequency-domain phase factor per unit length.
            pmd (bool): Whether to apply PMD in addition to GVD.
            data_mode (str): ``'tensor'`` or ``'numpy'``.
            **kwargs: PMD parameters ``psp_theta``, ``psp_phi`` and
                ``pmd_arr`` (indexable; only entry 0 is used). They are
                required when ``pmd`` is true.

        Returns:
            list: ``[x, y]`` complex output signals.

        Raises:
            ValueError: If ``data_mode`` is neither ``'tensor'`` nor ``'numpy'``.
        """
        self.l_corr = self.span_len
        if data_mode == 'tensor':
            device = sigin[0].device
            fft, ifft = torch.fft.fft, torch.fft.ifft
            fftshift, ifftshift = torch.fft.fftshift, torch.fft.ifftshift
            exp, conj = torch.exp, torch.conj

            def to_backend(mat):
                return torch.from_numpy(mat).to(device)
        elif data_mode == 'numpy':
            fft, ifft = np.fft.fft, np.fft.ifft
            fftshift, ifftshift = np.fft.fftshift, np.fft.ifftshift
            exp, conj = np.exp, np.conj

            def to_backend(mat):
                return mat
        else:
            raise ValueError(f"data_mode must be 'tensor' or 'numpy', got {data_mode!r}")

        # The phase factor is presumably ordered like an fftshift-ed spectrum.
        # Shift before multiplying and undo the shift before the inverse FFT.
        sig_fft_x = fftshift(fft(sigin[0]))
        sig_fft_y = fftshift(fft(sigin[1]))
        if pmd:
            psp_theta = kwargs.get('psp_theta')
            psp_phi = kwargs.get('psp_phi')
            pmd_arr = kwargs.get('pmd_arr')
            trunk_list = [self.span_len]  # one PMD section per span
            for i in range(len(trunk_list)):
                # With A = [x; y], the PMD step is matR * D * matR' * A, where
                # D is diagonal in the PSP basis.
                mat_rot = to_backend(self._psp_rotation_matrix(psp_theta[i], psp_phi[i]))
                mat_rot_conj = conj(mat_rot)
                # 1> Move onto the PSP basis.
                uux = mat_rot_conj[0, 0] * sig_fft_x + mat_rot_conj[1, 0] * sig_fft_y
                uuy = mat_rot_conj[0, 1] * sig_fft_x + mat_rot_conj[1, 1] * sig_fft_y
                # 2> Apply GVD (common phase) and DGD (differential phase).
                gvd_beta = phase_factor_freq * length
                pmd_beta = 0.5 * (pmd_arr[i]) * np.sqrt(length)
                uux = exp(- 1j * (gvd_beta + pmd_beta)) * uux
                uuy = exp(- 1j * (gvd_beta - pmd_beta)) * uuy
                # 3> Return to the original basis.
                sig_fft_x = mat_rot[0, 0] * uux + mat_rot[0, 1] * uuy
                sig_fft_y = mat_rot[1, 0] * uux + mat_rot[1, 1] * uuy
        else:
            # Dispersion only, no PMD.
            dispersion = exp(- 1j * phase_factor_freq * length)
            sig_fft_x = dispersion * sig_fft_x
            sig_fft_y = dispersion * sig_fft_y
        sig_x = ifft(ifftshift(sig_fft_x))
        sig_y = ifft(ifftshift(sig_fft_y))
        return [sig_x, sig_y]

    def __nn_matrix__(self, x):
        """Propagate a dual-polarization signal through the NN channel model.

        Args:
            x (list): ``[x, y]`` complex signals. The caller's list is not
                modified.

        Returns:
            list: The ``[x, y]`` complex output, normalized by ``norm_dual_pol``
            to ``self.sig_power_w_all``.
        """
        x = [self.data_mode_convert(pol, self.data_mode) for pol in x]
        x = self.__com_to_real__(x)
        self.__nn_model_init__()
        x = self.__trans_nn_matrix__(x, self.model)
        if self.nn_model == 'BiLSTM':
            # The BiLSTM output is followed by an explicit linear step.
            if self.pmd:
                self.__pmd_init__()
            self.__lin_func_args_init__()
            x = self.__real_to_com__(x)
            x = self.__cdpmdcm__(x, self.cd_length, **self.lin_func_para)
        elif self.nn_model == 'GAN':
            x = self.__real_to_com__(x)
        return norm_dual_pol(x[0], x[1], self.sig_power_w_all)

    def __pmd_init__(self):
        """Draw the per-section PMD parameters for the current span.

        Sets ``pmd_dz_arr``, ``pmd_arr``, ``psp_theta`` and ``psp_phi``, one
        entry per PMD section (``self.pmd_trunk_num`` sections).
        """
        if self.dgd_manual:
            pmd_coeff = self.dgd_total / np.sqrt(self.span_len)
        elif self.pmd_coeff_random:
            # Sample the PMD coefficient from a Maxwellian distribution (the
            # norm of three i.i.d. zero-mean Gaussians).
            scale = np.sqrt(self.dgd_rms**2 / 3)
            vx = np.random.normal(loc=0, scale=scale)
            vy = np.random.normal(loc=0, scale=scale)
            vz = np.random.normal(loc=0, scale=scale)
            pmd_coeff = np.sqrt(vx**2 + vy**2 + vz**2)
        else:
            pmd_coeff = self.pmd_coeff
        l_corr = self.span_len / self.pmd_trunk_num
        pmd_per_trunk = pmd_coeff / np.sqrt(self.pmd_trunk_num)  # PMD coefficient per section
        if self.pmd_dz_random:
            self.pmd_dz_arr = np.random.normal(loc=l_corr, scale=l_corr / 5, size=(self.pmd_trunk_num))
        else:
            self.pmd_dz_arr = np.ones((self.pmd_trunk_num)) * l_corr
        self.pmd_arr = pmd_per_trunk * np.ones(self.pmd_trunk_num)
        if self.psp_manual:
            self.psp_theta = self.phi * np.ones(self.pmd_trunk_num)
            self.psp_phi = self.phi * np.ones(self.pmd_trunk_num)
        else:
            # Azimuth uniform on [-pi, pi).
            self.psp_theta = np.random.rand(self.pmd_trunk_num) * 2 * np.pi - np.pi
            # Uniform R.V. over the Poincare sphere.
            self.psp_phi = 0.5 * np.arcsin(np.random.rand(self.pmd_trunk_num) * 2 - 1)

    def __lin_func_args_init__(self):
        """Collect the keyword arguments for ``__cdpmdcm__`` in ``self.lin_func_para``.

        The PMD entries are ``None`` when PMD parameters have not been drawn
        (for example when ``self.pmd`` is false). In that case ``__cdpmdcm__``
        does not read them.
        """
        self.lin_func_para = {
            'phase_factor_freq': self.phase_factor_freq,
            'pmd': self.pmd,
            'pmd_arr': getattr(self, 'pmd_arr', None),
            'psp_theta': getattr(self, 'psp_theta', None),
            'psp_phi': getattr(self, 'psp_phi', None),
            'data_mode': self.data_mode,
        }

    def __nn_model_init__(self):
        """Load the network and its hyper-parameters from ``model_path``.

        Reads ``model_path + 'args.pth'`` (training arguments) and
        ``model_path + 'checkpoint.pth'`` (weights). It then builds the model
        selected by ``self.nn_model``, switches it to eval mode, and copies the
        scaling and windowing parameters onto the instance. Loading is skipped
        when the same model has already been loaded on the same device.

        Raises:
            ValueError: If ``self.nn_model`` is not ``'BiLSTM'`` or ``'GAN'``.
        """
        model_classes = {'BiLSTM': BiLSTM, 'GAN': Generator}
        if self.nn_model not in model_classes:
            raise ValueError(f"nn_model must be 'BiLSTM' or 'GAN', got {self.nn_model!r}")
        cache_key = (self.model_path, self.nn_model, str(self.device))
        if getattr(self, '_loaded_model_key', None) == cache_key:
            return

        # Insert each path only once so sys.path does not grow on every call.
        for extra_path in self._TRAIN_CODE_PATHS:
            if extra_path not in sys.path:
                sys.path.insert(0, extra_path)
        # SECURITY: torch.load unpickles arbitrary objects, so only point
        # model_path at trusted checkpoints.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which may reject the pickled `args` object. Confirm against the
        # installed version.
        dataPara = torch.load(self.model_path + 'args.pth', map_location=self.device)
        datapt = torch.load(self.model_path + 'checkpoint.pth', map_location=self.device)
        args = dataPara['args']
        args.model_para = datapt[args.model_name]
        self.model = model_classes[self.nn_model](args).to(self.device)
        self.model.load_state_dict(args.model_para)
        self.model.eval()

        # Power normalization factor applied before the network.
        self.power_scale = args.power_scale
        # Target signal power passed to norm_dual_pol after the network.
        self.sig_power_w_all = 1 / np.square(self.power_scale)
        # Windowing / shape parameters of the network.
        self.aheadSam = args.aheadSam
        self.behindSam = args.behindSam
        self.indim = args.indim
        self.outdim = args.outdim
        self.outSamNum = args.outSamNum
        self.inSamNum = args.inSamNum
        self.time_step = args.time_step
        if self.nn_model == 'GAN':
            self.noise_dim = args.noise_dim
        self._loaded_model_key = cache_key



