import numpy as np
import torch
from ....fiber_simulation.comm_tools import calculation as calcu
from ....fiber_simulation.base.base_dsp import DSP_Base_Module

class CDC(DSP_Base_Module):
    r"""
        Frequency-domain chromatic dispersion compensation (CDC).

        Typical use: construct the object, call ``init`` once to build the
        compensation taps for a given fiber length / channel, then call
        ``forward_pass`` on a list of signals (e.g. one per polarization).
        ``data_mode_convert`` is provided by the base class (not visible here).
    """

    def __init__(self, mode, beta, sam_per_sym, sam_rate, fft_num, data_mode = 'numpy', *args, **kwargs):
        r"""
            Initialization function of the CDC class.
            Stores the compensation parameters and precomputes the radian
            frequency grid later used to build the compensation taps.
            Parameters: mode: str
                            Compensation algorithm, written as '<domain>_<tap type>'.
                            Supported values: 'FD_1x1' and 'FD_2x2'
                            (the tap type is matched case-insensitively).
                        beta: list
                            Dispersion parameters, i.e. the Taylor coefficients of
                            the propagation constant β at the center frequency.
                            beta[0], beta[1] and beta[2] are indexed by channel
                            (``cut_idx`` in ``init``); beta[3] is used as-is.
                        sam_per_sym: int
                            Number of samples per symbol (stored; not used by the
                            code in this class).
                        sam_rate: float
                            Sample rate. NOTE(review): the inline note below says GHz,
                            and it is scaled by 1e-3 before being passed to
                            ``calcu.digital_freq`` — confirm the expected units.
                        fft_num: int
                            Fourier-transform length used to build the frequency grid.
                        data_mode: str
                            'numpy' (default) selects NumPy; any other value
                            selects the torch implementation.
            Keyword args:
                        device: str, default 'cpu'.
                        Constant_pi: float, default ``np.pi``.
        """
        super().__init__()
        self.mode = mode
        self.beta = beta
        self.sam_rate = sam_rate    # GHz (per the original note; see docstring)
        self.sam_per_sym = sam_per_sym
        self.fft_num = fft_num
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.Constant_pi = kwargs.get('Constant_pi', np.pi)
        # Radian frequency grid 2*pi*f. ``__cdc_in_freq__`` fftshifts the
        # spectrum before multiplying, so this grid is presumably centered
        # (negative to positive frequencies) — TODO confirm in calcu.digital_freq.
        self.w = 2 * self.Constant_pi * \
            calcu.digital_freq(self.fft_num,\
                self.sam_rate * 1e-3)    # radian frequency

    def init(self, cdc_length, cut_idx, *args, **kwargs):
        r"""
            Build the frequency-domain compensation taps ``self.cdc_tap``.

            The phase factor is the third-order Taylor expansion of β(ω):
                β0 + β1·ω + β2·ω²/2 + β3·ω³/6
            and the taps are exp(1j · phase_factor · cdc_length).
            Parameters: cdc_length: float
                            Length of fiber whose dispersion is compensated,
                            e.g. cdc_length = total_length - pre_length.
                        cut_idx: int
                            Index of the channel to compensate. For a WDM signal
                            with three channels the center channel is cut_idx = 1.
                        **kwargs:
                            Extra attributes stored verbatim on the instance.
            Raises:
                ValueError: if ``mode`` is not of the form '<domain>_<tap type>'.
                AttributeError: 'Such mode is not supported'
                    if the domain part of ``mode`` is not 'FD'.
        """
        self.cdc_length = cdc_length
        self.cut_idx = cut_idx
        # Extra keyword arguments become plain instance attributes.
        self.__dict__.update(kwargs)
        domain, tap_type = self.mode.split('_')
        self.domain = domain
        # Accept 'FD_1X1' (as spelled in older docs) as well as 'FD_1x1'.
        self.tap_type = tap_type.lower()

        if self.domain != 'FD':
            raise AttributeError('Such mode is not supported')

        w = self.w
        # NOTE(review): beta[3] is not indexed by cut_idx while beta[0..2] are;
        # presumably beta[3] is a single value shared by all channels — confirm.
        # NOTE(review): the sign of the exponent decides whether dispersion is
        # removed or added; confirm it matches the fiber model's convention.
        phase_factor_freq = (self.beta[0][self.cut_idx]
                             + self.beta[1][self.cut_idx] * w
                             + self.beta[2][self.cut_idx] * (w ** 2) / 2
                             + self.beta[3] * (w ** 3) / 6)
        phase_factor_freq = self.data_mode_convert(phase_factor_freq)

        # Phase shift per frequency bin.
        if self.data_mode == 'numpy':
            self.cdc_tap = np.exp(1j * phase_factor_freq * cdc_length)
        else:
            self.cdc_tap = torch.exp(1j * phase_factor_freq * cdc_length)

    def forward_pass(self, sig_in):
        r"""
            Apply chromatic dispersion compensation to every signal in ``sig_in``.

            '2x2' filters each complex signal directly. '1x1' filters the real
            (I) and imaginary (Q) rails separately and recombines them as I + 1j*Q.
            ``__cdc_in_freq__`` is linear and always uses a full complex FFT,
            so '1x1' matches '2x2' up to floating-point rounding.
            Requires ``init`` to have been called first.
            Parameters: sig_in: list
                            Input signals, compensated independently
                            (e.g. one entry per polarization).
            Return:     sig_out: list
                            Compensated signals, in the same order as ``sig_in``.
            Raises:
                RuntimeError: 'Such mode is not supported'
                    if the tap type is not '1x1' / '2x2' or the domain is not 'FD'.
        """
        sig_out = []
        for x in sig_in:
            # Explicit check: previously an unknown domain left ``y`` unbound
            # (or silently reused the previous iteration's result).
            if self.domain != 'FD':
                raise RuntimeError('Such mode is not supported')
            if self.tap_type == '1x1':
                yi = self.__cdc_in_freq__(x.real)
                yq = self.__cdc_in_freq__(x.imag)
                y = yi + 1j * yq
            elif self.tap_type == '2x2':
                y = self.__cdc_in_freq__(x)
            else:
                raise RuntimeError('Such mode is not supported')
            sig_out.append(y)
        return sig_out

    def __cdc_in_freq__(self, x):
        r"""
            Multiply the spectrum of ``x`` by the compensation taps ``cdc_tap``.

            The spectrum is a full complex FFT along the last axis. It is
            centered with fftshift (last axis only) to line up with ``cdc_tap``,
            multiplied, un-shifted and inverse-transformed. Real inputs take the
            same path, so the result is always complex. A real FFT is not used
            because it yields only n//2 + 1 one-sided bins, which cannot be
            matched against the full-band taps.
            Parameters: x: numpy.ndarray or torch.Tensor
                            Input signal, real or complex. Presumably its last axis
                            has the same length as ``cdc_tap`` (``fft_num``) —
                            TODO confirm at the call sites.
            Return:     y: numpy.ndarray or torch.Tensor
                            Dispersion-compensated signal, same library as ``cdc_tap``.
        """
        # Same library choice as in ``init`` so ``cdc_tap`` matches the FFT output.
        if self.data_mode == 'numpy':
            x_fft = np.fft.fftshift(np.fft.fft(x), axes=-1)
            return np.fft.ifft(np.fft.ifftshift(x_fft * self.cdc_tap, axes=-1))
        x_fft = torch.fft.fftshift(torch.fft.fft(x), dim=-1)
        return torch.fft.ifft(torch.fft.ifftshift(x_fft * self.cdc_tap, dim=-1))
        
    