import numpy as np
import torch
import matplotlib.pyplot as plt
from ...comm_tools.filter_design import filter_func as f
from ... base.base_dsp import DSP_Base_Module

class Digital_filter_para_freq(DSP_Base_Module):
    '''
    Frequency-domain digital filter module (subclass of DSP_Base_Module).

    Builds the frequency response ``h`` of a filter on an FFT grid of the
    input length (via ``comm_tools.filter_design.filter_func``) and filters a
    signal by multiplying its fftshift-ed spectrum with ``h``.

    Supported ``filter_type`` values: 'brickwall', 'butter', 'bessel',
    'gaussian', 'rc', 'rrc'.
    '''

    def __init__(self, filter_type, data_mode = 'numpy', *args, **kwargs):
        """
        Store the filter type, data mode and device.

        Filter-specific parameters are NOT set here; call ``init(**params)``
        afterwards to set them and to validate ``filter_type``.

        Parameters:
            filter_type (str): Filter type (required, no default).
                brickwall   : Brickwall filter
                butter      : Butterworth filter
                bessel      : Bessel filter
                gaussian    : Gaussian filter
                rc          : Raised cosine filter
                rrc         : Root raised cosine filter
            data_mode (str): 'numpy' or 'tensor'; selects the numpy or torch
                code path. Defaults to 'numpy'.
            **kwargs:
                device (str): Device forwarded to the filter_func builders.
                    Defaults to 'cpu'.
        """
        super().__init__()
        self.filter_type = filter_type
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')

    def init(self, *args, **kwargs):
        '''
        Store keyword parameters on the instance and resolve filter defaults.

        Every keyword argument is copied onto the instance as an attribute
        (e.g. ``cut_off``, ``sam_rate``, ``upsam``, ``beta``), so that
        ``__get_freq__`` can pick it up later. The parameters needed by the
        selected filter type are then resolved with ``self.__getarr__``
        (defined in the base class; presumably "existing attribute, else the
        given default" - verify):

            butter    : gpass - maximum passband loss in dB (default 5)
                        gstop - minimum stopband attenuation in dB (default 40)
            bessel    : order - filter order (default 10)
            rc / rrc  : beta  - roll-off factor (default 0.1)
            brickwall, gaussian : no extra parameters

        Returns:
            None

        Raises:
            AttributeError: 'Such filter is not supported' if ``filter_type``
                is not one of the supported types.
        '''
        for key in kwargs:
            self.__dict__[key] = kwargs[key]
        if self.filter_type == 'butter':
            self.gpass = self.__getarr__('gpass', 5)
            self.gstop = self.__getarr__('gstop', 40)
        elif self.filter_type == 'bessel':
            self.order = self.__getarr__('order', 10)
        elif self.filter_type in ('rc', 'rrc'):
            self.beta = self.__getarr__('beta', 0.1)
        elif self.filter_type in ('brickwall', 'gaussian'):
            pass
        else:
            raise AttributeError('Such filter is not supported')

    def forward_pass(self, x):
        '''
        Filter ``x`` in the frequency domain.

        The FFT length is taken from ``x.shape[0]``. The frequency response is
        rebuilt for that length on every call and stored in ``self.h``.
        NOTE(review): the FFT in ``__filter_in_freq__`` runs along the default
        (last) axis, so length and axis agree only for 1-D input - confirm
        the expected input shape.

        Parameters:
        x: ndarray or tensor
            Input signal sequence.

        Returns:
        y: ndarray or tensor (per ``data_mode``)
            Filtered signal sequence (complex-valued, output of an inverse FFT).
        '''
        fft_num = x.shape[0]
        x = self.data_mode_convert(x)
        self.__get_freq__(fft_num)
        y = self.__filter_in_freq__(x, self.h)
        return y

    def plot(self, fft_num, sam_rate, h, name = 'filter.png'):
        '''
        Plot a frequency response and save the figure.

        Left subplot: linear magnitude |h|. Right subplot: 20*log10|h| in dB,
        with a red reference line at -3 dB.

        The frequency axis spacing ``sam_rate / (fft_num - 1)`` is kept as in
        the original code. Presumably it matches the grid used by filter_func;
        TODO confirm.

        Parameters:
        fft_num: int
            FFT length (number of points in ``h``).
        sam_rate: int or float
            Sampling rate used to scale the frequency axis.
        h: array_like or torch.Tensor
            Frequency response of the designed filter (as returned by
            ``get_freq``). Tensors are moved to CPU/numpy before plotting.
        name: str, optional
            File name of the saved image. Defaults to 'filter.png'.

        Returns:
        None
        '''
        # In 'tensor' mode get_freq returns a torch tensor, possibly on a GPU,
        # which numpy/matplotlib cannot consume directly.
        if isinstance(h, torch.Tensor):
            h = h.detach().cpu().numpy()
        # Local name ``freq`` avoids shadowing the module-level ``f`` alias.
        if fft_num % 2 == 0:
            num = fft_num // 2
            freq = np.arange(-num + 1, num + 1) * sam_rate / (fft_num - 1)
        else:
            num = (fft_num - 1) // 2
            freq = np.arange(-num, num + 1) * sam_rate / (fft_num - 1)
        mag = np.abs(h)
        plt.figure()
        plt.subplot(1, 2, 1)
        plt.plot(freq, mag)
        plt.grid()
        plt.subplot(1, 2, 2)
        # The response may contain exact zeros (e.g. a brickwall stopband).
        # log10(0) = -inf is simply not drawn, so silence that warning.
        with np.errstate(divide='ignore'):
            mag_db = 20 * np.log10(mag)
        plt.plot(freq, mag_db)
        plt.grid()
        # This subplot is in dB, so the 3 dB reference sits at -3, not 0.5.
        plt.axhline(-3, color='red')
        plt.savefig(name, dpi = 600)
        plt.show()

    def get_freq(self, fft_num):
        '''
        Public entry point for building the filter frequency response.

        Parameters:
        fft_num: int
            FFT length.

        Returns:
        The result of ``__get_freq__`` (``self.h`` after ``data_mode_convert``).
        '''
        return self.__get_freq__(fft_num)

    def filter_in_freq(self, x, h):
        '''
        Public entry point for filtering ``x`` with a given response ``h``.

        Both inputs are passed through ``data_mode_convert`` first.

        Parameters:
        x: array_like
            Input signal sequence.
        h: array_like
            Frequency response of the designed filter.

        Returns:
        The result of ``__filter_in_freq__``.
        '''
        x = self.data_mode_convert(x)
        h = self.data_mode_convert(h)
        return self.__filter_in_freq__(x, h)

    def __filter_in_freq__(self, x, h):
        '''
        Filter ``x`` by multiplying its centred spectrum with ``h``.

        Computes ifft(ifftshift(fftshift(fft(x)) * h)) with torch when
        ``data_mode == 'tensor'`` and numpy otherwise. ``h`` is therefore
        expected on the fftshift-ed frequency grid (zero frequency in the
        middle).

        Parameters:
        x: array_like
            Input signal sequence.
        h: array_like
            Frequency response of the designed filter.

        Returns:
        y: ndarray or tensor
            Filtered signal sequence (complex-valued).
        '''
        if self.data_mode == 'tensor':
            x_fft = torch.fft.fftshift(torch.fft.fft(x))
            y_fft = x_fft * h
            y = torch.fft.ifft(torch.fft.ifftshift(y_fft))
        else:
            x_fft = np.fft.fftshift(np.fft.fft(x))
            y_fft = x_fft * h
            y = np.fft.ifft(np.fft.ifftshift(y_fft))
        return y

    def __get_freq__(self, fft_num):
        r"""
        Build the frequency response for ``fft_num`` points into ``self.h``.

        The design inputs are read from instance attributes through
        ``self.__getarr__`` (not from arguments):

        filter_type == rc or rrc:
            upsam (float)   : Up-sample rate = sample rate / symbol rate.
                              Default 2.
        other filter types:
            cut_off (float) : 3 dB bandwidth (GHz). Default 50.
            sam_rate (float): Sampling rate (GHz). Default 100.

        ``data_mode`` and ``device`` are taken from the instance.

        Parameters:
            fft_num (int): Number of points in the FFT.

        Returns:
            self.h after ``data_mode_convert``.

        Raises:
            AttributeError: If ``filter_type`` is not supported.
        """
        if self.filter_type in ('rc', 'rrc'):
            upsam = self.__getarr__('upsam', 2)
            self.__get_freq_rcfilter__(fft_num, upsam)
        else:
            sam_rate = self.__getarr__('sam_rate', 100)
            cut_off = self.__getarr__('cut_off', 50)
            self.__get_freq_lpf__(fft_num, cut_off, sam_rate)
        self.h = self.data_mode_convert(self.h)
        return self.h

    def __get_freq_lpf__(self, fft_num, cut_off, sam_rate, **kwargs):
        '''
        Build the brickwall, Butterworth, Bessel or Gaussian low-pass response.

        Delegates to ``filter_func`` and stores the result in ``self.h``.

        Parameters:
        fft_num: int
            FFT length.
        cut_off: float
            Cut-off frequency.
        sam_rate: int or float
            Sampling rate.
        **kwargs:
            stop_band (float): Butterworth stop-band edge. Defaults to
                1.1 * cut_off. Note that ``__get_freq__`` does not pass it.

        Returns:
        None

        Raises:
        AttributeError: If ``filter_type`` is not one of the low-pass types.
            Raising here stops a stale ``self.h`` from a previous design being
            reused silently.
        '''
        if self.filter_type == 'brickwall':
            # 'brikwall' spelling matches the filter_func API being called.
            self.h = f.brikwall_filter(fft_num, cut_off,
                sam_rate, self.data_mode, device = self.device)
        elif self.filter_type == 'butter':
            stop_band = kwargs.get('stop_band', cut_off * 1.1)
            self.h = f.butter_filter(fft_num, cut_off, stop_band,
                sam_rate, self.gpass, self.gstop, self.data_mode, device = self.device)
        elif self.filter_type == 'bessel':
            self.h = f.bessel_filter(fft_num, cut_off,
                sam_rate, self.order, self.data_mode, device = self.device)
        elif self.filter_type == 'gaussian':
            self.h = f.gaussian_filter(fft_num, cut_off,
                sam_rate, self.data_mode, device = self.device)
        else:
            raise AttributeError('Such filter is not supported')

    def __get_freq_rcfilter__(self, fft_num, upsam):
        '''
        Build the raised-cosine or root-raised-cosine response.

        Delegates to ``filter_func`` using ``self.beta`` and stores the result
        in ``self.h``.

        Parameters:
        fft_num: int
            FFT length.
        upsam: float
            Up-sampling rate (sampling rate divided by symbol rate).

        Returns:
        None

        Raises:
        AttributeError: If ``filter_type`` is neither 'rc' nor 'rrc'.
        '''
        if self.filter_type == 'rc':
            self.h = f.rc_filter(fft_num, self.beta, upsam,
                self.data_mode, device = self.device)
        elif self.filter_type == 'rrc':
            self.h = f.rrc_filter(fft_num, self.beta, upsam,
                self.data_mode, device = self.device)
        else:
            raise AttributeError('Such filter is not supported')