import torch
import numpy as np
from fiber_simulation.utils.show_progress import progress_info_return
from ...base.base_optics import Optics_Base_Module
from ...comm_tools.filter_design import filter_func as f
from ...comm_tools.normalization import norm_1d

"""
Contains the WSS class, which implements 'wdm_multiplexing' and 'wdm_demultiplexing'.
"""
class WSS(Optics_Base_Module):
    """Wavelength selective switch used for WDM multiplexing and demultiplexing.

    One frequency-domain filter (built from ``filter_config``) is used both to
    band-limit every channel before multiplexing and to extract the selected
    channel (``cut_idx``) during demultiplexing.
    """

    # Default filter arguments per supported filter type. Values supplied by
    # the caller (extra keys in ``args``) always take precedence.
    # NOTE(review): a 'beta' default of 10 is unusual if beta is a roll-off
    # factor (normally within 0..1) -- confirm against comm_tools.filter_design.
    _FILTER_DEFAULTS = {
        'butter_filter': {'gpass': 5, 'gstop': 40},
        'bessel_filter': {'order': 10},
        'rc_filter': {'beta': 10},
        'rrc_filter': {'beta': 10},
        'brickwall_filter': {},
        'gaussian_filter': {},
        'wss_filter': {},
    }
    # Filter arguments that have no default and must be supplied by the caller.
    # NOTE(review): the original code hinted at a steepness default of
    # "0.1 * something" that could never take effect; confirm the intended
    # default before making 'steepness' optional.
    _FILTER_REQUIRED = {
        'wss_filter': ('bandwidth', 'steepness'),
    }

    def __init__(self, filter_type, args, wdm_data_mode='numpy',
                 dwdm_data_mode='tensor', **kwargs):
        """Create a WSS for wavelength-division multiplexing and demultiplexing.

        WSS is used to implement wavelength-division multiplexing (WDM) and
        demultiplexing in an optical fiber communication system. Its methods
        configure the filter and perform the signal operations.

        Parameters
        ----------
        filter_type : str
            The type of the filter. Supported options:
            {'brickwall_filter', 'gaussian_filter', 'wss_filter',
             'butter_filter', 'bessel_filter', 'rc_filter', 'rrc_filter'}
            The filter configuration is stored in ``self.filter_config`` as
            ``{'type': filter_type, 'args': {...}}``; refer to
            comm_tools.filter_design for the keys each type expects in 'args'.
        args : dict
            Keyword arguments forwarded to :meth:`init` (sig_p, channel_num,
            cut_idx, channel_space, fft_num, sam_rate, optional infor_print,
            plus any filter-specific arguments).
        wdm_data_mode : {'numpy', 'tensor'}, optional, default='numpy'
            The data type used in wdm_multiplexing.
        dwdm_data_mode : {'numpy', 'tensor'}, optional, default='tensor'
            The data type used in wdm_demultiplexing.
        **kwargs : dict
            Other arguments. Only 'device' is read here (default 'cpu').

        Raises
        ------
        AttributeError
            If ``filter_type`` is not supported.
        KeyError
            If a filter argument required by ``filter_type`` is missing.
        """
        self.filter_config = {'type': filter_type, 'args': {}}
        self.data_type = {'numpy': np.ndarray, 'tensor': torch.Tensor}
        self.data_can_reuse = 0
        self.filter_type = filter_type
        self.wdm_data_mode = wdm_data_mode
        self.dwdm_data_mode = dwdm_data_mode
        self.device = kwargs.get('device', 'cpu')
        self.init(**args)

    def init(self, sig_p, channel_num, cut_idx, channel_space, fft_num,
             sam_rate, infor_print=1, **kwargs):
        """Configure the system and filter arguments used by the WDM process.

        Called from ``__init__``.

        Parameters
        ----------
        sig_p : float
            The set power (dBm). The power of the filtered signal is
            normalized to this value.
        channel_num : int
            The number of channels in the WDM system.
        cut_idx : int
            Index (into the channel frequency grid) of the channel extracted by
            wdm_demultiplexing. Required. The center channel (zero frequency
            offset) has index ``(channel_num - 1) // 2``.
        channel_space : float
            The channel spacing / bandwidth of each channel (GHz). Also used as
            the filter cut-off.
        fft_num : int
            The number of points of the discrete Fourier transform, which
            equals the number of signal samples.
        sam_rate : float
            The sample rate (GHz); the time axis is ``arange(fft_num) / sam_rate``.
        infor_print : {0, 1}, optional
            Whether to print the progress of the WDM process. Defaults to 1.
        **kwargs : dict
            Filter-specific arguments, stored in ``filter_config['args']``.

        Raises
        ------
        AttributeError
            If the set filter_type is not supported.
        KeyError
            If a filter argument required by the filter type is missing.
        """
        self.channel_num = channel_num
        self.cut_idx = cut_idx
        self.fft_num = fft_num
        self.cut_off = channel_space
        self.sam_rate = sam_rate
        self.channel_space = channel_space
        self.sig_p = sig_p  # power unit: dBm
        self.simu_time = np.arange(0, fft_num) / sam_rate
        self.infor_print = infor_print

        para_args = self.filter_config['args']
        para_args.update(kwargs)
        if self.filter_type not in self._FILTER_DEFAULTS:
            raise AttributeError('Such filter is not supported')
        # Fill in defaults without overriding caller-supplied values.
        for key, default in self._FILTER_DEFAULTS[self.filter_type].items():
            para_args.setdefault(key, default)
        for key in self._FILTER_REQUIRED.get(self.filter_type, ()):
            if key not in para_args:
                raise KeyError(
                    f"'{self.filter_type}' requires the filter argument '{key}'")

    def get_transfer_func(self, *args, **kwargs):
        """Build the filter transfer-function factory into ``self.filter_func``."""
        self.filter_func = self.init_func(self.filter_config, f, *args, **kwargs)

    def filter_in_freq(self, x, h):
        """Filter a signal in the frequency domain.

        The spectrum of ``x`` is fft-shifted (zero frequency in the middle),
        multiplied by ``h``, and transformed back.

        Parameters
        ----------
        x : ndarray or torch.Tensor
            The signal to be filtered. Tensors use torch.fft, anything else
            uses numpy.fft.
        h : ndarray or torch.Tensor
            The transfer function of the filter, in fft-shifted order.

        Returns
        -------
        y : ndarray or torch.Tensor
            Filtered signal samples in the time domain.
        """
        if isinstance(x, torch.Tensor):
            fft_mod = torch.fft
        else:
            fft_mod = np.fft
        y_fft = fft_mod.fftshift(fft_mod.fft(x)) * h
        return fft_mod.ifft(fft_mod.ifftshift(y_fft))

    def _compute_df_matrix(self):
        """Return the per-channel frequency offsets as a numpy array."""
        n = self.channel_num
        if n % 2 == 0:
            # Even count: grid is asymmetric, index n/2 - 1 sits at 0 offset.
            grid = np.linspace(-(n / 2 - 1), n / 2, n)
        else:
            grid = np.linspace(-(n - 1) / 2, (n - 1) / 2, n)
        return self.channel_space * grid

    def _build_filter(self):
        """(Re)build the filter transfer function into ``self.h``."""
        self.get_transfer_func(fft_num=self.fft_num, sam_rate=self.sam_rate,
                               cut_off=self.cut_off)
        self.h = self.filter_func()

    def _convert_shared_arrays(self, mode):
        """Convert the time axis, frequency grid and filter to ``mode``."""
        self.simu_time = self.data_mode_convert(self.simu_time, mode)
        self.df_matrix = self.data_mode_convert(self.df_matrix, mode)
        self.h = self.data_mode_convert(self.h, mode)

    def _phase_array(self, freq, mode, sign=1):
        """Return exp(sign * 1j * 2*pi * freq * t) in the given data mode."""
        phase = sign * 2 * np.pi * freq * self.simu_time
        exp = torch.exp if mode == 'tensor' else np.exp
        return exp(1j * phase)

    def wdm_multiplexing(self, sigin):
        """Implement the wavelength division multiplexing process.

        Each channel is filtered, normalized to the set power (split evenly
        between polarizations), shifted to its channel frequency by a complex
        carrier, and summed.

        Parameters
        ----------
        sigin : list
            ``sigin[i_c][i_p]`` is polarization ``i_p`` of channel ``i_c``.
            All channels must have the same number of polarizations.

        Returns
        -------
        sigout : list
            One multiplexed signal per polarization (``[sig]`` for a single
            polarization, ``[sig_x, sig_y]`` for dual polarization). An empty
            input yields an empty list.
        """
        sig_power_w = 10 ** ((self.sig_p - 30) / 10)  # dBm -> W
        self.df_matrix = self._compute_df_matrix()
        self._build_filter()
        self._convert_shared_arrays(self.wdm_data_mode)

        def wdm(x, channel_num, infor_print):
            @progress_info_return(total_num=channel_num,
                                  infor_print=infor_print, discription=' WDM')
            def main(sigin, **kwargs):
                pbar = kwargs.get('pbar', 0)
                if not sigin:
                    return []
                n_pol = len(sigin[0])
                # The total set power is shared equally by the polarizations.
                pol_power_w = sig_power_w / n_pol
                # Accumulators must start at zero (an earlier version used 1j).
                wdm_sigs = [0.0 + 0.0j] * n_pol
                for i_c, channel in enumerate(sigin):
                    phase_array = self._phase_array(
                        self.df_matrix[i_c], self.wdm_data_mode)
                    for i_p in range(n_pol):
                        sig = self.data_mode_convert(channel[i_p], self.wdm_data_mode)
                        sig = self.filter_in_freq(sig, self.h)  # WSS filter
                        sig = norm_1d(sig, pol_power_w)         # amplifier
                        wdm_sigs[i_p] += sig * phase_array      # WDM
                    # The decorator injects a progress bar; 0 means "none".
                    if not isinstance(pbar, int):
                        pbar.update(1)
                return wdm_sigs
            return main(x)

        return wdm(sigin, len(sigin), self.infor_print)

    def wdm_demultiplexing(self, sigin, para_update, *args, **kwargs):
        """Implement the demultiplexing process.

        The selected channel (``self.cut_idx``) is shifted down to zero
        frequency, filtered, and normalized to the set power (split evenly
        between polarizations).

        Parameters
        ----------
        sigin : list
            The multiplexed signal, one entry per polarization (the output of
            wdm_multiplexing).
        para_update : int, {0, 1}
            Whether to update attributes from ``kwargs`` and rebuild the
            frequency grid, time axis and filter. When 0, the filter from
            wdm_multiplexing is reused (it is built here if it does not exist
            yet).
        *args
            Unused; accepted for interface compatibility.
        **kwargs : dict
            Attributes to overwrite on this object when ``para_update`` is
            true (e.g. cut_idx, channel_num, fft_num, sam_rate, cut_off).

        Returns
        -------
        sigout : list
            The demultiplexed signal, one entry per polarization. Only the
            signal from the selected channel is preserved.
        """
        if para_update:
            self.__dict__.update(kwargs)
            # fft_num / sam_rate may have changed: rebuild the time axis.
            if 'simu_time' not in kwargs:
                self.simu_time = np.arange(0, self.fft_num) / self.sam_rate
        if para_update or getattr(self, 'h', None) is None:
            self.df_matrix = self._compute_df_matrix()
            self._build_filter()

        mode = self.dwdm_data_mode
        self._convert_shared_arrays(mode)
        # Negative-frequency carrier moves the selected channel to baseband.
        phase_array = self._phase_array(self.df_matrix[self.cut_idx], mode, sign=-1)

        sig_power_w = 10 ** ((self.sig_p - 30) / 10)  # dBm -> W
        # Same per-polarization power split as wdm_multiplexing.
        pol_power_w = sig_power_w / max(len(sigin), 1)
        sigout = []
        for pol_sig in sigin:
            x = self.data_mode_convert(pol_sig, mode)
            x_filtered = self.filter_in_freq(phase_array * x, self.h)
            sigout.append(norm_1d(x_filtered, pol_power_w))
        return sigout