import torch
import numpy as np
from ....fiber_simulation.comm_tools import calculation as calcu
from scipy import signal
from scipy.special import erfc

def bessel_filter(fft_num, cut_off, sam_rate, order, *args, **kwargs):
    r"""
    Bessel filter
    Builds an analog low-pass Bessel filter and returns the magnitude of its
    frequency response sampled on the frequency grid of calcu.digital_freq.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    cut_off: float
    Cut off frequency of Bessel filter (same units as the frequency grid)
    sam_rate: int
    Ideal low pass filter sampling rate
    order: int
    The parameter of Bessel filter, filter order

    Returns:
    h: array_like
    |H(f)| of the Bessel filter at every grid frequency (not peak-normalised)
    """
    freq_grid = calcu.digital_freq(fft_num, sam_rate)
    # 'phase' normalisation: the phase response reaches its midpoint at cut_off
    # (not the -3 dB point). freqs() treats worN in the same angular units as
    # Wn, so cut_off and the grid only need to share units.
    num, den = signal.bessel(order, cut_off, 'low', analog=True, norm='phase')
    response = signal.freqs(num, den, worN=freq_grid)[1]
    return np.abs(response)

def brickwall_filter(fft_num, cut_off, sam_rate, *args, **kwargs):
    r"""
    Brickwall filter
    Ideal rectangular low-pass response: 1 inside [-cut_off, cut_off], 0 elsewhere,
    sampled on the frequency grid of calcu.digital_freq.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    cut_off: float
    Cut off frequency of Brickwall filter
    sam_rate: int
    Ideal low pass filter sampling rate

    Returns:
    h: array_like
    Frequency response samples of the Brickwall filter
    """
    freq_grid = calcu.digital_freq(fft_num, sam_rate)
    # Zero only the samples strictly outside the band; the edges themselves pass.
    outside_band = (freq_grid > cut_off) | (freq_grid < -cut_off)
    return np.where(outside_band, 0.0, np.ones(fft_num))

def butter_filter(fft_num, cut_off, stop_band, sam_rate, gpass=5, gstop=40, *args, **kwargs):
    r"""
    Butterworth filter
    Designs the lowest-order analog Butterworth low-pass filter meeting the
    pass/stop-band specification and returns its magnitude response on the
    frequency grid of calcu.digital_freq, scaled so its maximum is 1.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    cut_off: float
    Pass-band edge frequency of Butterworth filter
    stop_band: float
    Stopband boundary frequency of Butterworth filter
    sam_rate: int
    Ideal low pass filter sampling rate
    gpass: float, optional
    Maximum passband attenuation in dB
    gstop: float, optional
    Minimum stopband attenuation in dB

    Returns:
    h: array_like
    Peak-normalised magnitude response of the Butterworth filter
    """
    freq_grid = calcu.digital_freq(fft_num, sam_rate)
    order, natural_freq = signal.buttord(
        wp=cut_off, ws=stop_band, gpass=gpass, gstop=gstop, analog=True
    )
    num, den = signal.butter(order, natural_freq, 'low', analog=True)
    magnitude = np.abs(signal.freqs(num, den, worN=freq_grid)[1])
    return magnitude / np.max(magnitude)

def gaussian_filter(fft_num, cut_off, sam_rate, *args, **kwargs):
    r"""
    Gaussian filter
    Gaussian-shaped response whose full width at half maximum is 2 * cut_off,
    built with scipy.signal.windows.gaussian and scaled to a peak of 1.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    cut_off: float
    Cut off frequency of Gaussian filter (half-maximum point)
    sam_rate: int
    Ideal low pass filter sampling rate

    Returns:
    h: array_like
    Frequency response samples of the Gaussian filter, peak at the middle index
    """
    # NOTE(review): the peak sits at the centre of the array, whereas the other
    # filters here are evaluated on calcu.digital_freq -- confirm both orderings agree.
    full_width = 2 * cut_off
    fwhm_to_std = 2 * np.sqrt(2 * np.log(2))
    # Width expressed in window samples, assuming adjacent samples are
    # sam_rate / (fft_num - 1) apart in frequency.
    std_samples = (fft_num - 1) * full_width / (sam_rate * fwhm_to_std)
    window = signal.windows.gaussian(fft_num, std=std_samples)
    return window / np.max(window)

def rc_filter(fft_num, beta, upsam, *args, **kwargs):
    '''
    The generating function of the raised cosine filter, returns the frequency
    response of the corresponding filter.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    beta: float
    Roll off factor of the raised cosine filter
    upsam: int
    Upsampling rate, which is equal to the sampling rate divided by the symbol rate

    Returns:
    h: array_like
    Raised cosine response sampled on calcu.digital_freq(fft_num, 1), i.e. a
    unit sampling rate with symbol rate 1/upsam: 1 in the flat pass band, a
    cosine roll-off between (1-beta)/(2*upsam) and (1+beta)/(2*upsam), 0 beyond.
    Equals rrc_filter(fft_num, beta, upsam) ** 2.
    '''
    # Symbol rate relative to a unit sampling rate.
    sym_rate = 1 / upsam
    # The band edges below are expressed relative to a unit sampling rate, so the
    # frequency grid must use sampling rate 1 as well (same grid as rrc_filter).
    f = np.asarray(calcu.digital_freq(fft_num, 1))
    inner = (1 - beta) * sym_rate / 2  # end of the flat pass band
    outer = (1 + beta) * sym_rate / 2  # start of the stop band

    passband = (f > -inner) & (f <= inner)
    rolloff = ((f > -outer) & (f <= -inner)) | ((f > inner) & (f <= outer))

    h = np.zeros(fft_num)  # stop band stays 0
    if beta == 0:
        # The roll-off region is empty when beta == 0; this branch only avoids
        # dividing by beta below.
        h[rolloff] = 1 / 2
    else:
        h[rolloff] = 1 / 2 * (1 + np.cos(np.pi / sym_rate / beta
                                         * (np.abs(f[rolloff]) - inner)))
    h[passband] = 1.0
    return h

def rrc_filter(fft_num, beta, upsam, *args, **kwargs):
    '''
    The generating function of the root raised cosine filter, returns the
    frequency response of the corresponding filter.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    beta: float
    Roll off factor of the root raised cosine filter
    upsam: int
    Upsampling rate, which is equal to the sampling rate divided by the symbol rate

    Returns:
    h: array_like
    Root raised cosine response sampled on calcu.digital_freq(fft_num, 1)
    (unit sampling rate, symbol rate 1/upsam): 1 in the flat pass band, the
    square root of a cosine roll-off in the transition band, 0 beyond it.
    '''
    symbol_rate = 1 / upsam
    freqs = calcu.digital_freq(fft_num, 1)
    flat_edge = (1 - beta) * symbol_rate / 2   # end of the flat pass band
    stop_edge = (1 + beta) * symbol_rate / 2   # start of the stop band

    response = np.zeros(fft_num)  # stop band stays 0

    in_rolloff = (((freqs > -stop_edge) & (freqs <= -flat_edge))
                  | ((freqs > flat_edge) & (freqs <= stop_edge)))
    if beta != 0:
        response[in_rolloff] = np.sqrt(1/2 * (1 + np.cos(
            np.pi/symbol_rate/beta * (np.abs(freqs[in_rolloff]) - flat_edge))))
    else:
        # Roll-off region is empty for beta == 0; avoids dividing by beta.
        response[in_rolloff] = np.sqrt(1/2)

    in_passband = (freqs > -flat_edge) & (freqs <= flat_edge)
    response[in_passband] = 1.0
    return response

def wss_filter(fft_num, sam_rate, bandwidth, steepness, *args, **kwargs):
    r"""
    The generating function of the WSS filter: a rectangle of width `bandwidth`
    convolved with a Gaussian whose FWHM is `steepness`, sampled on the
    frequency grid of calcu.digital_freq.
    Parameters:
    fft_num: int
    The length of FFT(Fast Fourier Transform)
    sam_rate: int
    Ideal low pass filter sampling rate
    bandwidth: float
    Bandwidth of WSS filter
    steepness: float
    Steepness of WSS filter (FWHM of the Gaussian edge)

    Returns:
    h: array_like
    Non-negative frequency response
        0.5 * sigma * sqrt(2*pi) * [erf((B/2 - f)/(sqrt(2)*sigma)) - erf((-B/2 - f)/(sqrt(2)*sigma))].
    It is not normalised: the peak approaches sigma * sqrt(2*pi) when
    bandwidth >> steepness.

    References:
    Zhai, Zhiqun, et al. "An Interpretable Mapping from a Communication System to a
    Neural Network for Optimal Transceiver-Joint Equalization." Journal of Lightwave Technology (2021).
    """
    f = calcu.digital_freq(fft_num, sam_rate)
    # Convert the FWHM `steepness` into a Gaussian standard deviation.
    sigma = steepness / (2 * np.sqrt(2 * np.log(2)))
    scale = np.sqrt(2) * sigma
    upper = (bandwidth / 2 - f) / scale
    lower = (-bandwidth / 2 - f) / scale
    # erf(upper) - erf(lower) == erfc(lower) - erfc(upper). The operands are
    # ordered so that the response is non-negative.
    h = 0.5 * sigma * np.sqrt(2 * np.pi) * (erfc(lower) - erfc(upper))
    return h