import numpy as np
import torch
from . import cma
from ... base.base_dsp import DSP_Base_Module

class Vanilla(DSP_Base_Module):

    def __init__(self, half_taps_num, upsam = 2, out_sym_num = 1,\
         block_size = 1, mode = 'TD_2x2', lr_optim = 'constant',\
           data_mode = 'tensor', infor_print = 1, *args, **kwargs):
        r"""
            Initialization function of the adaptive filter design class.
            This function initializes the adaptive filter parameters.
            Parameters: half_taps_num: int
                            Half the number of the filter taps.
                        upsam: int
                            upsample rate, defaults to 2.
                        out_sym_num: int
                            Number of output symbols, defaults to 1.
                        block_size: int
                            Filter block size, defaults to 1.
                            It is generally processed in parallel in the actual filtering.
                        mode: str
                            Algorithm mode, defaults to 'TD_2x2',
                            Which means Time domain calculation, dual input and dual output.
                        lr_optim: str
                            Learning rate optimization strategy. defaults to 'constant'.
                        data_mode: str
                            Data mode. Defaults to 'tensor'.
                        infor_print: int
                            The information print parameter, defaults to 1.
        """
        super().__init__()
        self.lr_optim = lr_optim
        self.upsam = upsam
        self.out_sym_num = out_sym_num
        self.mode = mode
        self.domain, self.fir_type  = mode.split("_")    # TD_2x2
        in_arr_dim, out_arr_dim = self.fir_type.split("x")
        self.in_arr_dim, self.out_arr_dim = int(in_arr_dim), int(out_arr_dim)
        self.half_taps_num = half_taps_num
        self.block_size = block_size
        self.infor_print = infor_print
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')   
        self.iter_idx = 0
    
    def init(self, sam_num, algo_type, lr, train_num, h_ortho = 0,\
        train_epoch = 1, cma_pretrain = 1, pre_train_iter =2, tap_init_mode = 1,\
             tap_init_value = 1.0, radius_idx=-1, *args, **kwargs):
        r"""
            This function loads important algorithm parameters.
            Parameters: sam_num：int
                            The sample number.
                        algo_type：str
                            The type of adaptive filtering algorithm,
                            which determines how the filter parameters are updated.
                        lr：float
                            Learning rate, which determines the rate of gradient descent.
                        train_num：int
                            The data number for the filter training.
                        h_ortho：int
                            Orthogonalization parameter,
                            which decides whether to perform filter orthogonalization
                            after CMA pre-training.
                        train_epoch：int
                            The number of traversing the entire dataset in training,
                            defaults to 1.
                        cma_pretrain：int
                            Whether to use the CMA algorithm for pre-training,
                            defaults to 1, which means true.
                        pre_train_iter：int
                            The iteration numbers of pre-training, defaults to 2.
                        tap_init_mode：int
                            Tap initialization mode, defaults to 1.
                        tap_init_value：float
                            tap initialization value, defaults to 10.0
            Raises:
                RuntimeError: 'The output symbol number is not suitable for the upsample rate'
        """

        #loading the para
        self.sam_num = sam_num
        self.train_num = train_num
        self.train_epoch = train_epoch
        self.middle_sam_num = int(self.out_sym_num * self.upsam)
        if self.middle_sam_num / self.out_sym_num != self.upsam:
            raise RuntimeError('The output symbol number is not suitable for the upsample rate')
        self.taps_num = self.half_taps_num * 2 + self.out_sym_num
        self.lr = lr
        self.h_ortho = h_ortho
        self.algo_type = algo_type
        if self.algo_type == 'mma':
            self.cma_pretrain = cma_pretrain            
        else:
            self.cma_pretrain = 0
        self.pre_train_iter = pre_train_iter
        self.tap_init_mode = tap_init_mode
        self.tap_init_value = tap_init_value
        self.radius_idx = int(radius_idx)

        # Iteration initialization
        self.__iter_init__()
        if self.domain == 'TD':
            # Filter FIR sequence initialization
            self.__fir_init__()
        for key in kwargs:
            self.__dict__[key] = kwargs[key]

    def forward_pass(self, x, y, sym_map):
        r"""
            This function executes the adaptive filter module.
            The input signal is filtered in the time domain.
            Parameters: x,y: tensor
                            Signal sequence of two polarization states.
                        sym_map: ndarray
                            16QAM symbol map.
            Return:     [x, y]: list
                            Adaptive filtered signal sequence.
            Raises:     RuntimeError: 'Algorithm is not supported'
                            If the input algorithm mode is not cma or mma.
        """
        self.sym_map = sym_map
        if not hasattr(self, 'adaptive_lr'):
            self.adaptive_lr = 0

        # Three radius lengths for 16QAM
        if self.algo_type == 'cma' or self.algo_type == 'mma':
            self.cma_radius, self.cma_radius_max, self.cma_radius_min \
                = cma.cma_radius(self.sym_map, self.data_mode, device = self.device)
            if self.radius_idx > len(self.cma_radius):
                self.radius_idx = -1
            # execute the adaptive filter in cma/mma algorithm
            x, y = cma.cma_main(x, y, self)
            return [x, y]
        else:
            raise RuntimeError('Algorithm is not supported')

    def step(self, xin, xout, error):    # update parameters
        # update the filter taps
        self.__lms_update__(xin, xout, error)
        # update the lr
        self.__lr_step__()
        self.iter_idx += 1    # The number of training iterations

    def apply_filter(self, x):
        r"""
            This function executes the signal filtering.
            Parameters: x: list
                            Chunked signal before filtering in two polarization states.
            Return:     out：ndarray
                            Chunked signal after filtering.
                        x：ndarray
                            Chunked signal before filtering,
                            the two polarized signals are stitched together
                            into a high-dimensional array for subsequent error calculations.
            Raises:     RuntimeError: 'Inf values'
        """
        r"""
            Apply matrix operation (M, N, 2)^T * (M, N, 1) = (M, Out, 1)
            I: input arr dimension
            O: output arr dimension
            M: taps number
            N: out symbol num
            B: block_size
            Apply matrix operation [O, I, M, N].conj * (O, B, M) = (I, B, N)
        """
        if type(x[0]) == torch.Tensor:
            if self.in_arr_dim == 2:
                x = torch.stack((x[0], x[1]), dim=0)
                out = torch.einsum('iomn,ibm->obn', self.h.conj(), x)
            elif self.in_arr_dim == 4 and self.out_arr_dim == 2:
                x = torch.stack((x[0].real, 1j * x[0], x[1].real, 1j * x[1]), dim=0)
                out = torch.einsum('iomn,ibm->obn', self.h.conj(), x)
            elif self.in_arr_dim == 4 and self.out_arr_dim == 4:
                x = torch.stack((x[0].real, x[0], x[1].real, x[1]), dim=0)
                out = torch.einsum('iomn,ibm->obn', self.h, x)
            if torch.isnan(out).sum() != 0:
                raise RuntimeError('Nan values')

        elif type(x[0]) == np.ndarray:
            if self.in_arr_dim == 2:
                xin = x[0][np.newaxis]
                yin = x[1][np.newaxis]
                x = np.concatenate((xin, yin), axis=0)
            elif self.in_arr_dim == 4 and self.out_arr_dim == 2:
                xi_in = x[0].real[np.newaxis]
                xq_in = 1j * x[0].imag[np.newaxis]
                yi_in = x[1].real[np.newaxis]
                yq_in = 1j * x[1].imag[np.newaxis]
                x = np.concatenate((xi_in, xq_in, yi_in, yq_in), axis=0)
            elif self.in_arr_dim == 4 and self.out_arr_dim == 4:
                xi_in = x[0].real[np.newaxis]
                xq_in = x[0].imag[np.newaxis]
                yi_in = x[1].real[np.newaxis]
                yq_in = x[1].imag[np.newaxis]
                x = np.concatenate((xi_in, xq_in, yi_in, yq_in), axis=0)
            out = np.einsum('iomn,ibm->obn', np.conj(self.h), x)
            if len(np.where(np.abs(out) ** 2 == np.inf)[0]) != 0:
                raise RuntimeError('Inf values')
        return out, x

    def padding(self, x):
        r"""
            This function executes the signal padding in signal head and tail.
            Parameters: x: ndarray
                            Input signal sequence.
            Return:     x：ndarray
                            Signal sequence after padding.
        """
        if type(x) == torch.Tensor:
            x = torch.cat((x[-self.padding_head_num:], x, x[0:self.padding_tail_num]), dim=0)
        else:
            x = np.concatenate((x[-self.padding_head_num:], x, x[0:self.padding_tail_num]), axis=0)
        return x

    def __iter_init__(self):
        r"""
            This function initializes the iteration parameters of the adaptive filtering.
            Args: sam_num_in_block：int
                      The number of sample points per signal block.
                  block_idx: ndarray
                      Index of signal matrix.
                  train_block_num: int
                      The number of training signal blocks.
                  track_block_num: int
                      The number of tracking signal blocks.
        """
        self.sam_num_in_block = self.block_size * self.middle_sam_num

        # Define the block idx
        self.block_idx = np.zeros((self.block_size, self.taps_num))
        self.tx_block_idx = np.zeros((self.block_size, self.middle_sam_num))
        for i in range(self.block_size):
            self.block_idx[i] = np.arange(self.taps_num) + self.middle_sam_num * i
            self.tx_block_idx[i] = np.arange(self.middle_sam_num) + self.middle_sam_num * i
        self.block_idx = self.block_idx.astype(np.int32)
        self.tx_block_idx = self.tx_block_idx.astype(np.int32)

        # Define the MIMO train number and track number
        self.train_block_num = int(self.train_num / self.sam_num_in_block)    
        self.train_tot_num = self.train_block_num * self.train_epoch
        self.track_num = self.sam_num
        self.track_block_num = int(np.ceil(self.track_num / self.sam_num_in_block))
        self.out_sym_tot_num = int(self.sam_num / self.upsam)

        # Define the padding number
        self.padding_tail_num = int(self.track_block_num * self.sam_num_in_block - self.sam_num + self.half_taps_num)
        self.padding_head_num = self.half_taps_num
    
    def __fir_init__(self):
        r"""
            This function initializes the FIR and error matrix of the adaptive filtering.
            MIMO array dim: [In_arr_dim, Out_arr_dim, Sam_in_dim, Sym_out_dim]
            Err dim: [Out_arr_dim, Sam_in_dim, Sym_out_dim]
            Raise: RuntimeError: 'MIMO taps number must be an odd number'
                   NotImplementedError: 'Initialization of MIMO type is not implemented'
        """

        if self.taps_num % 2 != 1:
            raise RuntimeError('MIMO taps number must be an odd number')
        # Define the FIR and error matrix
        if self.data_mode == 'tensor':
            self.h = torch.zeros((self.in_arr_dim, self.out_arr_dim, self.taps_num, self.out_sym_num), device=self.device) + 0j
            self.err = torch.zeros((self.train_tot_num, self.out_arr_dim, self.out_sym_num), device=self.device)
            self.err_swa = torch.zeros((self.train_num, self.out_arr_dim, self.out_sym_num), device=self.device)
        elif self.data_mode == 'numpy':
            self.h = np.zeros((self.in_arr_dim, self.out_arr_dim, self.taps_num, self.out_sym_num))
            self.err = np.zeros((self.train_tot_num, self.out_arr_dim, self.out_sym_num))
            self.err_swa = np.zeros((self.train_num, self.out_arr_dim, self.out_sym_num))
        # taps initialization
        for i in range(self.out_sym_num):
            if self.tap_init_mode == 1:
                if self.in_arr_dim == 2 and self.out_arr_dim == 2:
                    self.h = self.h + 0j
                    self.h[0, 0, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[1, 1, self.half_taps_num + i, i] = self.tap_init_value
                elif self.in_arr_dim == 4 and self.out_arr_dim == 2:
                    self.h = self.h + 0j
                    self.h[0, 0, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[1, 0, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[2, 1, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[3, 1, self.half_taps_num + i, i] = self.tap_init_value
                elif self.in_arr_dim == 4 and self.out_arr_dim == 4:
                    self.h[0, 0, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[1, 1, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[2, 2, self.half_taps_num + i, i] = self.tap_init_value
                    self.h[3, 3, self.half_taps_num + i, i] = self.tap_init_value
                else:
                    raise NotImplementedError('Initialization of MIMO type is not implemented')

    def __lr_step__(self):    # update learning rate
        if self.lr_optim == 'constant':
            self.lr = self.lr
        else:
            pass

    def __lms_update__(self, xin, xout, error):
        r"""
            This function updates filter tap coefficients
            by gradient descent using least mean error.
            err: (O, B, N)
            xin: (I, B, M)
            h:   (I, O, M, N)
            Raise: RuntimeError: 'Nan values'
        """

        err = error * xout
        if type(xin) == torch.Tensor:
            gradients = torch.einsum('ibm,obn->iomn', xin, torch.conj(err))
            self.h = self.h + self.lr * gradients
        else:
            gradients = np.einsum('ibm,obn->iomn', xin, np.conj(err))
            self.h = self.h + self.lr * gradients
            if np.sum(np.isnan(self.h)) != 0:
                raise RuntimeError('Nan values')
        
