import numpy as np
import torch
from ... base.base_dsp import DSP_Base_Module
from . import correlation as corr

class synchron(DSP_Base_Module):
    
    def __init__(self, corr_num, frame_num, frame_size, mode = '4x4', data_mode = 'tensor', *args, **kwargs):
        r"""
            Initialization function of the synchronization moduLe.
            This function initializes the synchronization parameters.
            Parameters: corr_num：int
                            Correlation number，which is maximum delay allowed when synchronizing.
                        frame_num：int
                            Frame numbers, defaults to 1.
                        mode：str
                            Algorithm mode, defaults to '4x4'.
                        data_mode：str
                            Data mode. Defaults to numpy.
        """
        super().__init__()
        self.mode = mode
        self.corr_num = corr_num
        self.frame_size = frame_size
        self.frame_num = frame_num
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')

    def init(self, *args, **kwargs):
        # income parameters
        for key in kwargs:
            self.__dict__[key] = kwargs[key]
        
    def forward_pass(self, tx_sig, rx_sig):
        r"""
            This function executes the synchronization module.
            Parameters: rx_sig：list
                            Signal sequence of receiver.
                        tx_sig：list
                            Signal sequence of transmitter.
            Return:     rx_sig: list
                            The signal sequence of receiver after synchronization.
                        corr_results：ndarray
                            Correlation tensors formed by 4 path correlation of the transceiver.
                        frame_corr_abs：ndarray
                            Maximum correlation absolute value per frame.
        """
        if self.mode == '4x4':
            tx = self.__4x4_init__(tx_sig[0][0:self.corr_num], tx_sig[1][0:self.corr_num])
            rx = self.__4x4_init__(rx_sig[0], rx_sig[1])
            rx_sig, corr_results, frame_corr_abs = self.__4x4_main__(tx, rx)
        return rx_sig, corr_results, frame_corr_abs

    def __4x4_init__(self, x, y):
        r"""
            This function segregates the real and imaginary parts of input two polarized signals,
            concatenates them into 4-dimensional signal tensors.
            Parameters: x,y：list
                            Input two polarized signals.
            Return:     sigout: ndarray
                            Signal after processing.
        """
        x = self.data_mode_convert(x).reshape((1, -1))
        y = self.data_mode_convert(y).reshape((1, -1))
        if self.data_mode == 'numpy':
            sigout = np.concatenate((x.real, x.imag, y.real, y.imag), axis = 0)
        else:
            sigout = torch.cat((x.real, x.imag, y.real, y.imag), dim = 0)
        return sigout

    def __4x4_main__(self, tx, rx):
        r"""
            The main function of 4x4 correlation and synchronization.

            Correlates the transmitter reference with the receiver through corr.do_4x4,
            selects for each of the four transmitter rows the receiver row (and sign) with
            the strongest correlation, then circularly shifts the selected receiver rows to
            the chosen correlation index and crops them to self.frame_size samples.

            Parameters: tx: ndarray
                            Signal sequence of transmitter.
                        rx: ndarray
                            Signal sequence of receiver. Indexed by row with values 0..3.
            Return:     [out_x, out_y]: list
                            The signal sequence of receiver after synchronization, as two
                            complex sequences of at most self.frame_size samples each.
                        corr_results: ndarray
                            Correlation tensors formed by 4 path correlation of the transceiver
                            (the unmodified output of corr.do_4x4).
                        frame_corr_abs: ndarray
                            Absolute correlation values at the four selected points.
        """
        corr_results = corr.do_4x4(tx = tx, rx = rx)
        corr4x4_abs = np.abs(corr_results)
        # Reduced correlation tables, filled in the loop below; last axis keeps the
        # same length as the last axis of the 4x4 correlation.
        corr2x4_abs = np.zeros((2, 4, corr4x4_abs.shape[-1]))
        corr2x2_abs = np.zeros((2, 2, corr4x4_abs.shape[-1]))
        # Calculate the 2x4 and 2x2 correlations.
        # NOTE(review): the table below labels rows as tx and columns as rx, while its
        # caption says "row: rx, column: tx". The rest of this function uses the second
        # index to select rx rows, which matches the table labels -- assumes corr.do_4x4
        # returns [tx row, rx row, index]; TODO confirm.
        # In corr2x4_abs, columns 0/1 combine the first rx pair (same I/Q order, swapped
        # I/Q order) and columns 2/3 do the same for the second rx pair.
        """
            corr4x4_abs size = (4, 4, N):
                    row: rx, column: tx:
                                rx_xi   rx_xq   rx_yi   rx_yq
                        tx_xi     a       b       c       d
                        tx_xq     e       f       g       h
                        tx_yi     i       j       k       l
                        tx_yq     m       n       o       p
            corr2x4_abs size = (2, 4, N)
                                rx_1   rx_2   rx_yi   rx_yq
                        tx_x     a+f    b+e    c+h     d+g
                        tx_y     i+n    j+m    k+p     l+o
        """
        for i in range(2):
            for j in range(2):
                # Same I/Q order: diagonal of the 2x2 sub-block.
                corr2x4_abs[i, 2*j] = corr4x4_abs[2*i, 2*j] + corr4x4_abs[2*i+1, 2*j+1]
                # Swapped I/Q order: anti-diagonal of the 2x2 sub-block.
                corr2x4_abs[i, 2*j+1] = corr4x4_abs[2*i, 2*j+1] + corr4x4_abs[2*i+1, 2*j]
                # Total magnitude of the sub-block, regardless of I/Q order.
                corr2x2_abs[i, j] = corr2x4_abs[i, 2*j] + corr2x4_abs[i, 2*j + 1]
        # Indices of the self.frame_num largest values along the last axis.
        max_idx = np.argsort(corr2x2_abs, axis = -1)[..., -self.frame_num:]
        if self.frame_num > 3:
            # Order the candidates by index and drop the smallest and the largest index;
            # presumably this discards incomplete frames at the edges -- TODO confirm.
            max_idx = np.sort(max_idx)[..., 1:self.frame_num-1]
        # Number of candidates actually kept (may be smaller than self.frame_num).
        frame_num = max_idx.shape[-1]
        # Find the x and y axis: for each tx pair, pick the rx pair whose candidate
        # peaks have the largest mean magnitude.
        symch_xy = np.mean(np.take_along_axis(corr2x2_abs, max_idx, axis = -1), axis = -1)
        symch_xy = np.argmax(symch_xy, axis = -1)
        xy_idx = symch_xy
        # Update the 2x2 correlation: keep the candidate indices of the chosen rx pair and
        # overwrite corr2x2_abs with the two I/Q orderings of that pair.
        max_xy_idx = np.array([max_idx[0,symch_xy[0]], max_idx[1,symch_xy[1]]])
        corr2x2_abs[0] = corr2x4_abs[0, symch_xy[0]*2: symch_xy[0]*2+2]
        corr2x2_abs[1] = corr2x4_abs[1, symch_xy[1]*2: symch_xy[1]*2+2]
        # Find the i and q axis: read both I/Q orderings at every candidate index and take
        # the overall maximum per tx pair.
        symch_iq = np.take_along_axis(corr2x2_abs, np.repeat(max_xy_idx[:,None], 2, axis = 1), axis = -1)
        peak = np.argmax(symch_iq.reshape((2, -1)), axis = -1)
        # The flattened index is (ordering * frame_num + candidate); the quotient is the
        # I/Q ordering (0 = same order, 1 = swapped).
        iq_idx = (peak/frame_num).astype(np.int64)
        # Find the frame axis: the remainder selects the winning candidate index.
        data_idx = max_xy_idx[((0,1), peak % frame_num)]
        # One row per tx row: [tx row, selected rx row, correlation index].
        corr_idx = np.zeros((4, 3)).astype(np.int64)
        corr_idx[0] = np.array([0, xy_idx[0]*2+iq_idx[0], data_idx[0]])
        # c + 1 - 2*(c % 2) maps 0<->1 and 2<->3, i.e. the other row of the same rx pair.
        corr_idx[1] = np.array([1, corr_idx[0,1]+1-2*(corr_idx[0,1]%2), data_idx[0]])
        corr_idx[2] = np.array([2, xy_idx[1]*2+iq_idx[1], data_idx[1]])
        corr_idx[3] = np.array([3, corr_idx[2,1]+1-2*(corr_idx[2,1]%2), data_idx[1]])
        # Correlation values at the four selected points (despite the name, the absolute
        # value is only taken at the end of the function).
        frame_corr_abs = np.array([corr_results[corr_idx[i,0], corr_idx[i,1], corr_idx[i,2]] for i in range(4)])
        # Find the max correlation point and roll the signal sequence.
        # Reorder the rx rows to follow the tx rows and flip each by the sign of its
        # correlation value.
        sign = np.sign(frame_corr_abs)
        rx = rx[corr_idx[:, 1]] * sign[:,None]
        # Convert the correlation index into a shift; presumably index corr_num - 1
        # corresponds to zero delay in the output of corr.do_4x4 -- TODO confirm.
        roll_num = corr_idx[:, 2] - self.corr_num + 1
        # Shift each row to the start of the frame, crop to frame_size and rebuild the
        # two complex sequences from their real and imaginary rows.
        out_x = np.roll(rx[0], -roll_num[0])[0:self.frame_size] + 1j*np.roll(rx[1], -roll_num[1])[0:self.frame_size] 
        out_y = np.roll(rx[2], -roll_num[2])[0:self.frame_size] + 1j*np.roll(rx[3], -roll_num[3])[0:self.frame_size]
        # NOTE(review): the four values are re-ordered by the selected rx row index before
        # taking the absolute value -- confirm that this ordering is intended.
        frame_corr_abs = np.abs(frame_corr_abs[corr_idx[:, 1]])
        return [out_x, out_y], corr_results, frame_corr_abs