import torch
import numpy as np
import IFTS.fiber_simulation as f_simu
from IFTS.fiber_simulation import comm_tools
from IFTS.fiber_simulation.comm_tools.resample.resample_fft import resample as resample_fft
from IFTS.fiber_simulation.comm_tools.normalization import norm_1d

def front_end_processing(rx_sig, para):
    r"""
        Filter, Downsampling, IQ Balancing, and Normalization

        Parameters: rx_sig: list
                        ``[x, y]`` received signal, one complex sequence per
                        polarization (numpy arrays or torch tensors). A numpy
                        input made of real rows that reshape to (4, N), e.g.
                        ``[xi, xq, yi, yq]``, is also accepted.
                    para:
                        Receiver parameters. Reads ``infor_print``, ``lpf``,
                        ``lpf_obj``, ``sam_now``, ``block_upsam``,
                        ``iq_balance``, ``iq_balance_obj``, ``frame_sym_num``
                        and ``frame_num``; no ``para`` attribute is assigned
                        here.
        Return:     rx_sig: list
                        ``[x, y]`` after filtering, downsampling, IQ balancing,
                        truncation to ``frame_sym_num * frame_num`` symbols
                        and normalization.
                    sam_now:
                        Samples per symbol after downsampling
                        (``para.block_upsam``).
    """
    print_flag = para.infor_print

    # Normalization: build a real (4, N) array of I/Q rows [xi, xq, yi, yq]
    # and give every row zero mean and unit variance.
    if isinstance(rx_sig[0], np.ndarray):
        if np.iscomplexobj(rx_sig[0]):
            # Reshaping the complex (2, N) array to (4, -1) would split each
            # polarization into two time halves instead of real/imag parts.
            rx_sig = np.stack((rx_sig[0].real, rx_sig[0].imag,
                               rx_sig[1].real, rx_sig[1].imag), axis=0)
        else:
            # Input already consists of real I/Q rows.
            rx_sig = np.array(rx_sig).reshape((4, -1))
        rx_sig = rx_sig / np.sqrt(np.var(rx_sig, axis=-1, keepdims=True))
        rx_sig = rx_sig - np.mean(rx_sig, axis=-1, keepdims=True)
    else:
        rx_sig = torch.stack((rx_sig[0].real, rx_sig[0].imag,
                              rx_sig[1].real, rx_sig[1].imag), dim=0)
        rx_sig = rx_sig / torch.sqrt(torch.var(rx_sig, dim=-1, keepdim=True))
        rx_sig = rx_sig - torch.mean(rx_sig, dim=-1, keepdim=True)

    # Low-pass filter, applied to the 4 real I/Q rows
    if para.lpf:
        rx_sig = para.lpf_obj(rx_sig)
        if print_flag:
            print(' --LPF finish--')
    rx_sam_x = rx_sig[0] + 1j * rx_sig[1]
    rx_sam_y = rx_sig[2] + 1j * rx_sig[3]

    # Resample from para.sam_now to para.block_upsam samples/symbol
    ratio = para.block_upsam / para.sam_now
    rx_sam_x = resample_fft(rx_sam_x, ratio)
    rx_sam_y = resample_fft(rx_sam_y, ratio)
    sam_now = para.block_upsam
    if print_flag:
        print(' --Downsample to {:.1f} sam/sym--'.format(sam_now))

    # IQ imbalance compensation
    if para.iq_balance:
        rx_sam_x = para.iq_balance_obj(rx_sam_x)
        rx_sam_y = para.iq_balance_obj(rx_sam_y)
        if print_flag:
            print(' --IQ Imbalancing finish--')

    # Check frame: keep only the samples belonging to the expected frames
    rx_sam_num = int(sam_now * para.frame_sym_num * para.frame_num)
    rx_sam_x = rx_sam_x[:rx_sam_num]
    rx_sam_y = rx_sam_y[:rx_sam_num]

    # Normalization
    rx_sam_x = norm_1d(rx_sam_x, 1)
    rx_sam_y = norm_1d(rx_sam_y, 1)

    return [rx_sam_x, rx_sam_y], sam_now

def mul_sam_comp(rx_sig, sym_map, para, plot_para):
    r"""
        Multi-sample-per-symbol compensation stage.

        Optionally applies chromatic dispersion compensation
        (``para.cdcom_obj``). Then it either runs the MIMO adaptive filter
        (``para.mimo_obj``) or, when MIMO is disabled, FFT-resamples both
        polarizations down to 1 sample/symbol. ``para.sam_now`` is set to 1
        before returning.

        Raises:     AttributeError if MIMO is requested with a single
                    polarization.
        Return:     (rx_sig, para)
    """
    verbose = para.infor_print
    show_plots = para.fig_plot
    if para.mimo and para.nPol == 1:
        raise AttributeError("MIMO Error: Expect two polarization while you use the single polarization !")

    # Chromatic dispersion compensation
    if para.cdcom:
        rx_sig = para.cdcom_obj(rx_sig)
        if verbose:
            print(' --CDC finish--')
        if show_plots:
            plot_para.scatter_plot_nPol(rx_sig, sam_num=para.sam_now,
                                        name='Constellation_After_CDC', set_c=1, s=4)

    # Without MIMO, simply resample each polarization to 1 sample/symbol
    if not para.mimo:
        ratio = 1 / para.sam_now
        rx_sig = [resample_fft(pol, ratio) for pol in (rx_sig[0], rx_sig[1])]
    else:
        rx_sig = para.mimo_obj(x=rx_sig[0], y=rx_sig[1], sym_map=sym_map)

    para.sam_now = 1
    return rx_sig, para
            
def single_sam_comp(rx_sig, sym_map, tx_sig, para, plot_para):
    r"""
        Single-sample-per-symbol compensation: synchronization and carrier
        phase estimation (CPE).

        Parameters: rx_sig: list
                        Received signal at the current sampling rate.
                    sym_map:
                        Symbol map, forwarded to ``para.cpe_obj``.
                    tx_sig: list
                        Transmitted reference sequence.
                    para:
                        Receiver parameters; ``synchronization`` / ``cpe``
                        enable the stages run by ``synchron_obj`` /
                        ``cpe_obj``.
                    plot_para:
                        Plot parameters (only used when ``para.fig_plot``).
        Return:     (rx_sig, para)
    """
    verbose, show_plots = para.infor_print, para.fig_plot

    # Synchronization against the transmitted reference
    if para.synchronization:
        rx_sig, corr, frame_results = para.synchron_obj(tx_sig, rx_sig)
        if show_plots:
            plot_para.scatter_plot_nPol(rx_sig, sam_num=para.sam_now,
                                        name='Constellation_After_SYM', set_c=0, s=4)
            plot_para.corr_plot(corr, name='synchronization')
        if verbose:
            if para.synchron_obj.mode == '4x4':
                # One correlation value per tributary: xi, xq, yi, yq
                msg = ' --Tx xi_corr:{:.2f} Tx xq_corr:{:.2f} Tx yi_corr:{:.2f} Tx yq_corr:{:.2f}--'.format(
                    frame_results[0], frame_results[1], frame_results[2], frame_results[3])
            else:
                msg = ' --Tx x_corr:{:.2f} Tx y_corr:{:.2f} --'.format(
                    frame_results[0], frame_results[1])
            print(msg)
            print(' --synchronization finished--')

    # Carrier phase estimation (the estimated phase is not used further here)
    if para.cpe:
        rx_sig, _phase_hat = para.cpe_obj(rx_sig, tx_sig, sym_map=sym_map)
        if show_plots:
            if verbose:
                print(' --CPE plot--')
            plot_para.scatter_plot_nPol(rx_sig, sam_num=para.sam_now,
                                        name='Constellation_After_CPE', set_c=0, s=4)
        if verbose:
            print(' --CPE finished--')

    return rx_sig, para

def _save_pol_data(sig, para, prefix):
    """Save each polarization ``sig[i]`` (``i < para.nPol``) via ``para.save_data_func``.

    The names used are ``prefix + 'x'`` and ``prefix + 'y'``.
    """
    pol_name = ['x', 'y']
    for i_p in range(para.nPol):
        para.save_data_func(sig[i_p], prefix + pol_name[i_p])


def rx(rx_sig, sym_map, tx_sig, para, plot_para):
    r"""
        This function executes signal DSP in receiver:
            1. Frontend processing         (front_end_processing)
            2. Multi-sample compensation   (mul_sam_comp: CDC, MIMO / downsampling)
            3. Single-sample compensation  (single_sam_comp: synchronization, CPE)

        Parameters: rx_sig: list
                        Signal sequence of receiver.
                    sym_map: ndarray
                        Symbol map.
                    tx_sig: list
                        Signal sequence of transmitter.
                    para:
                        Receiver parameters, initialized in rxsignal_para.
                        ``para.sam_now`` is updated in place as the sampling
                        rate changes.
                    plot_para:
                        plot parameters, initialized in sigplot_para.
        Return:     rx_sig_single: list
                        The signal sequence of receiver after DSP.
    """
    print_flag = para.infor_print
    plot_flag = para.fig_plot
    save_flag = para.save_data

    # Get started
    if save_flag:
        _save_pol_data(rx_sig, para, 'before_dsp_')
    if print_flag:
        print("Frontend Start ......")

    # Frontend processing
    rx_sig, para.sam_now = front_end_processing(rx_sig, para)
    if plot_flag:
        plot_para.scatter_plot_nPol(rx_sig, sam_num=para.sam_now,
                                    name='Constellation_After_Frontend', set_c=1, s=4)
        plot_para.sam_rate = para.sam_now * para.sym_rate
        plot_para.psd_nPol(rx_sig, name='PSD_After_Frontend', sign=para.sym_rate*1.1/2)
    if print_flag:
        print("Multi-sample Compensation Start ......")

    # Multi_sample Compensation
    rx_sig_mul, para = mul_sam_comp(rx_sig, sym_map, para, plot_para)
    if plot_flag:
        if print_flag:
            print(' --MIMO plot--')
        plot_para.scatter_plot_nPol(rx_sig_mul, sam_num=para.sam_now,
                                    name='Constellation_After_LC', set_c=1, s=4)
        if para.mimo:
            plot_para.loss_plot_nPol(para.mimo_obj.err[:, 0, 0], para.mimo_obj.err[:, 1, 0], name='MIMO_Loss')
            plot_para.firtap_plot_nPol(para.mimo_obj, name='MIMO_Tap')
    if save_flag:
        # Despite the 'after_cdc_' name, this is saved after the whole
        # multi-sample stage (CDC plus MIMO / downsampling).
        _save_pol_data(rx_sig_mul, para, 'after_cdc_')
    if print_flag:
        print("Single-sample Compensation Start ......")

    # Single_sample Compensation
    rx_sig_single, para = single_sam_comp(rx_sig_mul, sym_map, tx_sig, para, plot_para)
    if save_flag:
        _save_pol_data(rx_sig_single, para, 'after_dsp_')

    return rx_sig_single

def rx_awgn(rx_sig, sym_map, tx_sig, para, plot_para):
    r"""
        Reduced receiver DSP chain. Judging by its name, it is meant for
        AWGN-only simulations (TODO confirm). The steps are:
            1. I/Q normalization.
            2. Optional low-pass filter.
            3. FFT resampling straight to 1 sample/symbol.
            4. The single-sample stage (single_sam_comp: synchronization, CPE).

        Compared with ``rx``, this function has no CDC/MIMO, no IQ
        balancing, no frame truncation and no normalization after
        downsampling. It does not save any data itself.

        Parameters: rx_sig: list
                        ``[x, y]`` received signal, one complex sequence per
                        polarization (numpy arrays or torch tensors).
                    sym_map: ndarray
                        Symbol map.
                    tx_sig: list
                        Signal sequence of transmitter.
                    para:
                        Receiver parameters; ``para.sam_now`` is set to 1.
                    plot_para:
                        plot parameters, forwarded to single_sam_comp.
        Return:     rx_sig_single: list
                        The signal sequence of receiver after DSP.
    """
    print_flag = para.infor_print

    # Normalization: build a real (4, N) array of I/Q rows [xi, xq, yi, yq]
    # and give every row zero mean and unit variance.
    if isinstance(rx_sig[0], np.ndarray):
        if np.iscomplexobj(rx_sig[0]):
            # Reshaping the complex (2, N) array to (4, -1) would split each
            # polarization into two time halves instead of real/imag parts.
            rx_sig = np.stack((rx_sig[0].real, rx_sig[0].imag,
                               rx_sig[1].real, rx_sig[1].imag), axis=0)
        else:
            # Input already consists of real I/Q rows.
            rx_sig = np.array(rx_sig).reshape((4, -1))
        rx_sig = rx_sig / np.sqrt(np.var(rx_sig, axis=-1, keepdims=True))
        rx_sig = rx_sig - np.mean(rx_sig, axis=-1, keepdims=True)
    else:
        rx_sig = torch.stack((rx_sig[0].real, rx_sig[0].imag,
                              rx_sig[1].real, rx_sig[1].imag), dim=0)
        rx_sig = rx_sig / torch.sqrt(torch.var(rx_sig, dim=-1, keepdim=True))
        rx_sig = rx_sig - torch.mean(rx_sig, dim=-1, keepdim=True)

    # Low-pass filter, applied to the 4 real I/Q rows
    if para.lpf:
        rx_sig = para.lpf_obj(rx_sig)
        if print_flag:
            print(' --LPF finish--')
    rx_sam_x = rx_sig[0] + 1j * rx_sig[1]
    rx_sam_y = rx_sig[2] + 1j * rx_sig[3]

    # Resample directly to 1 sample/symbol
    ratio = 1 / para.sam_now
    rx_sig = [resample_fft(rx_sam_x, ratio), resample_fft(rx_sam_y, ratio)]
    para.sam_now = 1
    if print_flag:
        print(' --Downsample to {:.1f} sam/sym--'.format(para.sam_now))

    rx_sig_single, para = single_sam_comp(rx_sig, sym_map, tx_sig, para, plot_para)
    return rx_sig_single