import numpy as np
import torch
import IFTS.fiber_simulation.comm_tools.calculation as calcu
from IFTS.fiber_simulation.comm_tools.filter_design import filter_func as f
from IFTS.fiber_simulation.base.base_dsp import DSP_Base_Module

class LPF(DSP_Base_Module):
    r"""
        Low-pass filter applied in the frequency domain.

        Supported modes: 'butter', 'bessel', 'rc', 'rrc', 'brickwall' and
        'gaussian'. Call ``init`` to build the filter frequency response
        ``self.h``, then ``forward_pass`` to filter a signal with it.
    """
    def __init__(self, mode, sam_rate, fft_num, data_mode = 'numpy', *args, **kwargs):
        r"""
            Initialization function of the LPF class.
            This function stores the low-pass filter configuration; the
            filter response itself is built later by ``init``.
            Parameters: mode:str
                            Type of low-pass filter: 'butter', 'bessel',
                            'rc', 'rrc', 'brickwall' or 'gaussian'.
                        sam_rate:float
                            Sample rate (GHz).
                        fft_num:int
                            Length of the Fourier transform.
                        data_mode:str
                            Data mode, 'numpy' or 'tensor'. Defaults to 'numpy'.
                        device:str (optional keyword)
                            Device used for tensor data. Defaults to 'cpu'.
        """
        super().__init__()
        self.mode = mode
        self.sam_rate = sam_rate
        self.fft_num = fft_num
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')

    def init(self, *args, **kwargs):
        r"""
            This function generates the filter frequency response ``self.h``.
            Keyword arguments are stored as instance attributes BEFORE the
            response is computed, so they can override filter parameters
            (e.g. gpass, gstop, order, beta, upsam, cut_off, stop_band,
            fft_num, device) and the resulting ``self.h`` matches them.
            Raises:
                AttributeError: 'Such filter is not supported'
                If an unsupported filter type is inputted.
        """
        # Apply overrides first. Previously they were applied only after
        # self.h had been computed, so values such as `upsam` or `cut_off`
        # passed here never reached the filter design. (__getarr__ reads
        # instance attributes: it is used for 'sam_rate', which is only
        # ever stored as self.sam_rate.)
        for key, value in kwargs.items():
            self.__dict__[key] = value

        if self.mode == 'butter':
            self.gpass = self.__getarr__('gpass', 5)    # Maximum passband loss (dB)
            self.gstop = self.__getarr__('gstop', 40)   # Minimum stopband attenuation (dB)
        elif self.mode == 'bessel':
            self.order = self.__getarr__('order', 10)   # Order of the filter
        elif self.mode in ('rc', 'rrc'):
            self.beta = self.__getarr__('beta', 0.1)    # Roll-off factor (roll-off zone width)
        elif self.mode in ('brickwall', 'gaussian'):
            # These modes have no default cut-off: use the one passed in,
            # otherwise keep any previously configured value (None if none).
            self.cut_off = kwargs.get('cut_off', getattr(self, 'cut_off', None))
        else:
            raise AttributeError('Such filter is not supported')

        # Build the frequency response with the now-complete configuration.
        self.h = self.__get_freq__(self.fft_num)

    def forward_pass(self, x):
        r"""
            This function executes the low-pass filter module.
            The input signal is filtered in the frequency domain: its FFT is
            fftshift-ed (zero frequency centred), multiplied by ``self.h``
            and transformed back.
            Parameters: x: ndarray or tensor
                            Input signal sequence; the FFT is taken along the
                            last axis (presumably of length fft_num so that it
                            matches self.h -- confirm with callers).
            Return:     y: ndarray or tensor
                            Filtered (complex) signal sequence, in data_mode.
        """
        # convert data mode
        x = self.data_mode_convert(x)
        self.h = self.data_mode_convert(self.h)

        # filter input signal in the frequency domain
        if self.data_mode == 'tensor':
            x_fft = torch.fft.fftshift(torch.fft.fft(x, dim=-1), dim=-1)
            y_fft = x_fft * self.h
            y = torch.fft.ifft(torch.fft.ifftshift(y_fft, dim=-1), dim=-1)
        else:
            x_fft = np.fft.fftshift(np.fft.fft(x, axis=-1), axes=-1)
            y_fft = x_fft * self.h
            y = np.fft.ifft(np.fft.ifftshift(y_fft, axes=-1), axis=-1)
        return y

    def __get_freq__(self, fft_num):
        r"""
            Build the filter frequency response for the current mode.
            mode == 'rc' or 'rrc':
                fft_num (int)   : Number of points in the FFT
                upsam (float)   : Up-sample rate = sample rate / symbol rate,
                                  read from the instance via __getarr__
            other modes:
                fft_num (int)   : Number of points in the FFT
                cut_off (float) : 3 dB bandwidth (GHz), read via __getarr__
                sam_rate (float): Sampling rate (GHz), read via __getarr__
            data_mode ('numpy' or 'tensor') and device are taken from the
            instance.
            Return:  self.h : Filter frequency response (also stored on self)
        """
        if self.mode in ('rc', 'rrc'):
            upsam = self.__getarr__('upsam')
            self.__get_freq_rcfilter__(fft_num, upsam)
        else:
            sam_rate = self.__getarr__('sam_rate')
            cut_off = self.__getarr__('cut_off')
            self.__get_freq_lpf__(fft_num, cut_off, sam_rate)
        return self.h

    def __get_freq_lpf__(self, fft_num, cut_off, sam_rate, **kwargs):
        r"""
            Set ``self.h`` for the brickwall, butter, bessel and gaussian
            modes using the project filter functions.
            For 'butter', the stop-band edge is taken from the ``stop_band``
            keyword, else from an instance attribute ``stop_band`` (e.g. set
            through ``init`` kwargs), else defaults to ``cut_off * 1.1``.
        """
        if self.mode == 'brickwall':
            self.h = f.brickwall_filter(fft_num, cut_off, \
                sam_rate, self.data_mode, device = self.device)
        elif self.mode == 'butter':
            stop_band = kwargs.get('stop_band',
                                   getattr(self, 'stop_band', cut_off * 1.1))
            self.h = f.butter_filter(fft_num, cut_off, stop_band,\
                sam_rate, self.gpass, self.gstop, self.data_mode, device = self.device)
        elif self.mode == 'bessel':
            self.h = f.bessel_filter(fft_num, cut_off, \
                sam_rate, self.order, self.data_mode, device = self.device)
        elif self.mode == 'gaussian':
            self.h = f.gaussian_filter(fft_num, cut_off, \
                sam_rate, self.data_mode, device = self.device)

    def __get_freq_rcfilter__(self, fft_num, upsam):
        r"""
            Set ``self.h`` for the raised-cosine ('rc') or root-raised-cosine
            ('rrc') modes, using roll-off ``self.beta`` and up-sample rate
            ``upsam``.
        """
        if self.mode == 'rc':
            self.h = f.rc_filter(fft_num, self.beta, upsam, \
                self.data_mode, device = self.device)
        elif self.mode == 'rrc':
            self.h = f.rrc_filter(fft_num, self.beta, upsam, \
                self.data_mode, device = self.device)

    