import numpy as np
import scipy.signal as signal
from matplotlib import pyplot as plt
from ...comm_tools.fftconv import conv

def corr_real(tx_sig, rx_sig):
    r"""
        Cross-correlate a real-valued transmit sequence with the received signal.

        The correlation is computed as a "full" convolution of rx_sig with the
        time-reversed tx_sig. One copy of the reversed template is stacked per
        row of rx_sig. No complex conjugate is taken, so this is only the
        conventional cross-correlation for real-valued inputs.

        Parameters: tx_sig: ndarray
                        Transmit sequence. It is reversed along axis 0 and then
                        flattened to a single row (presumably 1-D; confirm
                        with callers).
                    rx_sig: ndarray
                        Received signal. Its first dimension sets how many
                        copies of the reversed template are stacked
                        (presumably shape (num_paths, num_samples); confirm).
        Return:     corr_result: ndarray
                        Output of conv(..., mode="full", axis=-1). Assuming
                        conv is a linear convolution along the last axis, that
                        axis has length rx_sig.shape[-1] + tx_sig.size - 1.
    """
    # Correlation == convolution with the time-reversed template.
    template_row = np.flip(tx_sig, axis=0).reshape((1, -1))
    # Stack one identical row per row of rx_sig so both operands have matching
    # leading dimensions before convolving along the last axis.
    template = np.tile(template_row, (rx_sig.shape[0], 1))
    return conv(rx_sig, template, mode="full", axis=-1)

def do_4x4(tx, rx):
    r"""
        Correlate every transmit path with the receive paths of a 4x4 link.

        Parameters: tx: array_like, shape (4, N_tx)
                        Signal sequences of the 4 transmit paths, one per row.
                    rx: array_like, shape (4, N_rx)
                        Signal sequences of the 4 receive paths, one per row.
        Return:     corr_results: ndarray, shape (4, 4, N_tx + N_rx - 1), float64
                        corr_results[i] holds corr_real(tx[i], rx). Assuming
                        conv operates row-wise along the last axis,
                        corr_results[i, j] is the correlation of tx[i] with
                        rx[j].
        Raise:
                    ValueError: if tx or rx is not 2-D, or
                        'The first dimension of inputs are not 4'
    """
    # Accept any array_like (e.g. nested lists); a no-op for ndarray input.
    tx = np.asarray(tx)
    rx = np.asarray(rx)
    # Validate rank up front. A 1-D input would otherwise pass the shape[0]
    # check and then fail with an opaque IndexError on shape[1]. A 3-D input
    # would be silently flattened by the reshape inside corr_real.
    if tx.ndim != 2 or rx.ndim != 2:
        raise ValueError('Inputs must be 2-D arrays of shape (4, num_samples)')
    if tx.shape[0] != 4 or rx.shape[0] != 4:
        raise ValueError('The first dimension of inputs are not 4')
    # Length of a "full" linear correlation of the two sequences.
    out_sym_num = tx.shape[1] + rx.shape[1] - 1
    # float64 buffer: a complex result from corr_real would lose its
    # imaginary part on assignment (this routine targets real signals).
    corr_results = np.zeros((4, 4, out_sym_num))
    for i, tx_row in enumerate(tx):
        corr_results[i] = corr_real(tx_sig=tx_row, rx_sig=rx)
    return corr_results
