from IFTS.fiber_simulation import comm_tools as comm
import numpy as np
import torch
from IFTS.fiber_simulation.utils.show_progress import progress_info, progress_info_return
import matplotlib.pyplot as plt

def cma_radius(sym_map, data_mode = 'tensor', **kwargs):
    r"""
        Collect the distinct radii (moduli) of a normalized symbol map.
        The map is first passed through comm.normalization.norm_1d, then the
        unique absolute values are extracted with np.unique (which already
        yields them in ascending order).
        Parameters: sym_map:ndarray
                        Symbol map.
                    data_mode:str
                        'tensor' converts the radii to a torch tensor; any other
                        value keeps them as a numpy array. Defaults to 'tensor'.
                    **kwargs:
                        device: target device for the tensor branch.
                        Defaults to 'cpu'.
        Return:     sorted radii:
                        numpy branch: ndarray from np.sort.
                        tensor branch: the result of Tensor.sort().
                        NOTE(review): Tensor.sort() returns a (values, indices)
                        pair, not a bare tensor - confirm callers expect this.
                    max radius:
                        Maximum radius length of symbol map.
                    min radius:
                        Minimum radius length of symbol map.
    """
    # Normalize, then reduce the constellation to its set of distinct moduli.
    radii = np.unique(np.abs(comm.normalization.norm_1d(sym_map)))
    if data_mode != 'tensor':
        return np.sort(radii), np.max(radii), np.min(radii)
    target = kwargs.get('device', 'cpu')
    radii_t = torch.from_numpy(radii).to(target)
    return radii_t.sort(), radii_t.max(), radii_t.min()

def cma_error(x, r):
    r"""
        Constant-modulus error of a signal with respect to a reference radius.
        Computes err = r**2 - |x|**2 element-wise and the mean of |err| along
        axis 1, so x must have at least two dimensions.
        Used directly for CMA, and by mma_error with per-sample radii.
        Parameters: x:ndarray or torch.Tensor
                        Input signal sequence. Tensor subclasses are handled
                        by the torch branch; everything else by numpy.
                    r: float, ndarray or torch.Tensor
                        Reference radius. Either a single radius or an array
                        that broadcasts against x (one radius per sample).
        Return:     err: same container type as x
                        Error sequence, same shape as x.
                    err_mean: same container type as x
                        Mean absolute error taken over axis 1.
    """
    # isinstance (rather than an exact type comparison) keeps tensor
    # subclasses such as torch.nn.Parameter on the torch code path.
    if isinstance(x, torch.Tensor):
        err = r ** 2 - torch.abs(x) ** 2
        err_mean = torch.mean(err.abs(), dim = 1)
    else:
        err = r ** 2 - np.abs(x) ** 2
        err_mean = np.mean(np.abs(err), axis = 1)
    return err, err_mean

def mma_error(x, r):
    r"""
        Multi-modulus error of a signal.
        For every sample of x the radius in r closest to |x| is selected, and
        the constant-modulus error is then evaluated against that radius by
        cma_error.
        Parameters: x:ndarray or torch.Tensor
                        Input signal sequence.
                    r: ndarray or torch.Tensor
                        1-D array of radius lengths of the symbol map; it is
                        indexed with the per-sample nearest-radius indices.
        Return:     err:
                        Error sequence as returned by cma_error.
                    err_mean:
                        Mean absolute error as returned by cma_error.
    """
    # A trailing axis is appended to |x| so that it broadcasts against every
    # candidate radius; argmin over that axis picks the nearest one.
    if isinstance(x, torch.Tensor):
        distance = torch.abs(r - torch.abs(x.unsqueeze(dim=-1)))
        idx = torch.argmin(distance, dim = -1)
    else:
        distance = np.abs(r - np.abs(x[..., np.newaxis]))
        idx = np.argmin(distance, axis = -1)
    r_used = r[idx]
    return cma_error(x, r_used)    # Call cma_error function to calculate loss

def orthogonalizetaps(h):
    r"""
        Return taps orthogonal to the input taps.
        This only works for dual-polarization signals and
        follows the technique described in _[1] to avoid the CMA pol-demux singularity.
        The taps are overwritten IN PLACE: column 1 of h is rebuilt from
        column 0, and the same object is returned. If h.shape[0] is neither
        2 nor 4 the input is returned untouched.
        Parameters: h:ndarray or torch.Tensor
                        Filter taps, indexed as h[row, col, ...]. ndarray
                        (including subclasses) uses np.conj; anything else
                        is handled with torch.conj.
        Return:     h:
                        The same object that was passed in.
        References:
            ..[1] L. Liu, et al. "Initial Tap Setup of Constant Modulus Algorithm for Polarization De-Multiplexing in
            Optical Coherent Receivers," in Optical Fiber Communication Conference and National Fiber Optic Engineers Conference
            (2009), paper OMT2, 2009, p. OMT2.
    """
    # Select the conjugate routine once so both tap layouts share one set of
    # assignments. isinstance keeps ndarray subclasses on the numpy path.
    conj = np.conj if isinstance(h, np.ndarray) else torch.conj
    n_rows = h.shape[0]
    if n_rows == 2:
        # 2x2 layout:
        #   [hxx, hyx,
        #    hxy, hyy]
        # hyy = hxx*
        h[1, 1] = conj(h[0, 0])
        # hyx = - hxy*
        h[0, 1] = -conj(h[1, 0])
    elif n_rows == 4:
        # hy_yi = hx_xi*
        h[2, 1] = conj(h[0, 0])
        # hy_yq = hx_xq*
        h[3, 1] = conj(h[1, 0])
        # hy_xi = -hx_yi*
        h[0, 1] = -conj(h[2, 0])
        # hy_xq = -hx_yq*
        h[1, 1] = -conj(h[3, 0])
    return h

def cma_main(x, y, para):
    r"""
        Entry point of the adaptive (CMA/MMA) filtering stage.
        Both inputs are converted with para.data_mode_convert, normalized with
        comm.normalization.norm_1d(sig_p = 1, mean = 0) and padded with
        para.padding; then the training pass and the tracking pass are run in
        that order.
        Parameters: x,y:ndarray
                        Input signals of two polarization states.
                    para:
                        Filter state/configuration object. This function reads
                        its settings (algo_type, train_epoch, block indices,
                        radii, ...) and calls apply_filter and step on it;
                        para.err and possibly para.h are written during training.
        Return:     x_out, y_out:ndarray
                        Two signals of two polarization states after adaptive filtering.
    """
    @progress_info(total_num = para.train_tot_num,
        infor_print = para.infor_print, discription = ' --MIMO_'+para.mode+' Training')
    def train_main(**kwargs):
        r"""
            Training pass of the adaptive filter.
            The CMA error is used when para.algo_type is 'cma', or when it is
            'mma' with para.cma_pretrain set, more than one epoch, and the
            epoch index not above para.pre_train_iter; otherwise the MMA error
            is used. With para.h_ortho set and more than one epoch, the taps
            are orthogonalized once at epoch para.pre_train_iter + 1.
            The per-block mean error is stored in para.err.
        """
        pbar = kwargs.get('pbar', 0)
        # The default 0 acts as the "no progress bar" marker.
        has_pbar = type(pbar) != int
        step_count = 0
        for epoch in range(para.train_epoch):
            # Determine whether orthogonalization is performed
            if para.h_ortho and para.train_epoch > 1 \
                    and epoch == para.pre_train_iter + 1:
                para.h = orthogonalizetaps(para.h)

            # Determine whether cma is performed
            pretraining = (para.algo_type == 'mma' and para.cma_pretrain
                           and para.train_epoch > 1
                           and epoch <= para.pre_train_iter)
            use_cma = pretraining or para.algo_type == 'cma'

            for blk in range(para.train_block_num):
                sample_idx = blk * para.sam_num_in_block + para.block_idx
                # apply adaptive filter
                out, xin = para.apply_filter([x[sample_idx], y[sample_idx]])
                # cma/mma error
                if use_cma:
                    err, err_mean = cma_error(out, para.cma_radius[para.radius_idx])
                else:
                    err, err_mean = mma_error(out, para.cma_radius)
                para.err[step_count] = err_mean
                # update
                para.step(xin, out, err)
                step_count += 1
                if has_pbar:
                    pbar.update(1)

    @progress_info_return(total_num = para.track_block_num,
        infor_print = para.infor_print, discription = ' --MIMO_'+para.mode+' Tracking')
    def track_main(**kwargs):
        r"""
            Tracking pass of the adaptive filter.
            The signal is filtered block by block; each filtered block is
            written to the output buffers and the filter keeps being updated
            with the CMA error (algo_type 'cma') or the MMA error (otherwise).
            Return:     x_out, y_out:ndarray
                        Two signals of two polarization states after adaptive filtering.
        """
        pbar = kwargs.get('pbar', 0)
        has_pbar = type(pbar) != int
        # Complex output buffers in the container type selected by data_mode.
        if para.data_mode == 'numpy':
            x_out = np.zeros(para.out_sym_tot_num) + 0j
            y_out = np.zeros(para.out_sym_tot_num) + 0j
        else:
            x_out = torch.zeros(para.out_sym_tot_num, device = para.device) + 0j
            y_out = torch.zeros(para.out_sym_tot_num, device = para.device) + 0j
        for blk in range(para.track_block_num):
            # Tracking signal in batches
            sample_idx = blk * para.sam_num_in_block + para.block_idx
            # apply adaptive filter
            out, xin = para.apply_filter([x[sample_idx], y[sample_idx]])
            lo = blk * para.block_size
            hi = (blk + 1) * para.block_size
            x_out[lo:hi] = out[0].reshape(-1)
            y_out[lo:hi] = out[1].reshape(-1)
            # cma/mma error
            if para.algo_type == 'cma':
                err, _ = cma_error(out, para.cma_radius[para.radius_idx])
            else:
                err, _ = mma_error(out, para.cma_radius)
            # update
            para.step(xin, out, err)
            if has_pbar:
                pbar.update(1)
        return x_out, y_out

    # The nested passes read x and y from this scope, so the pre-processed
    # signals must be bound here before either pass is invoked.
    x = para.data_mode_convert(x)
    y = para.data_mode_convert(y)
    # normalization
    x = comm.normalization.norm_1d(x, sig_p = 1, mean = 0)
    y = comm.normalization.norm_1d(y, sig_p = 1, mean = 0)
    # padding
    x = para.padding(x)
    y = para.padding(y)
    train_main()    # train
    x_out, y_out = track_main()    # track
    return x_out, y_out