import numpy as np
from ...fiber_simulation.comm_tools import calculation as calcu
from ...fiber_simulation.sig.data_gen import sym_gen
from ...fiber_simulation.sig.perf_calculation import ber, mi, q_factor, ser
from ...fiber_simulation.sig.modulation.modulation import qam
from ...fiber_simulation.comm_tools import base_conversion as bc

def sig_tx(para, seed = -1):
    """Generate a transmit bit sequence, its symbols, and the constellation.

    The constellation is built with ``qam(para.bits_per_sym,
    gs=para.geometric_shaping)``. It is then rescaled so that its average
    energy, weighted by the empirical symbol distribution of the generated
    sequence, equals 1.

    Args:
        para: Parameter object. Reads ``modulation``, ``bits_per_sym``,
            ``geometric_shaping``, ``bit_num_per_pol``, ``bit_load``,
            ``random_type`` and ``class_num``.
        seed: Seed forwarded unchanged to ``sym_gen``.

    Returns:
        tuple: ``(x, tx_sym, sym_map, sym)``:
            - ``x``: the bits returned by ``sym_gen``.
            - ``tx_sym``: the constellation points for each symbol.
            - ``sym_map``: the normalized constellation.
            - ``sym``: the flattened integer symbol indices.

    Raises:
        ValueError: If ``para.modulation`` is falsy. No constellation is
            built in that case, so no symbols can be mapped.
    """
    if not para.modulation:
        raise ValueError(
            "sig_tx: para.modulation is falsy, so no constellation (sym_map) "
            "can be built")
    sym_map = qam(para.bits_per_sym, gs=para.geometric_shaping)

    x = sym_gen(para.bit_num_per_pol, seed=seed, bit_load=para.bit_load,
                random_class=para.random_type)
    sym = bc.bit2sym(x, para.bits_per_sym).reshape(-1)

    # Scale to unit average energy under the empirical symbol probabilities:
    # after this, sum(emp_prob * |sym_map|^2) == 1.
    emp_prob = calcu.empirical_distribution(sym, para.class_num)
    sym_map = sym_map / (np.sum(emp_prob * np.abs(sym_map) ** 2)) ** 0.5

    tx_sym = sym_map[sym]
    return x, tx_sym, sym_map, sym

def sig_rx(rx_sig, sym_map, bit_seq, integer_seq, para, **kwargs):
    """Evaluate received symbols per polarization and store the metrics on ``para``.

    Each polarization is processed as follows:
        1. Drop ``para.front_sym_num`` symbols at both ends of the received
           block and cut the transmitted integer and bit sequences to the
           same window.
        2. Undo the estimated gain (model ``y = scale * (x + n)``).
        3. Compute SER, BER, the Q factor derived from BER, the effective SNR
           reported by the noise estimator, MI and GMI.

    Args:
        rx_sig: Indexable by polarization; ``rx_sig[i]`` is an array of
            received symbols.
        sym_map: Constellation points, indexed by integer symbol.
        bit_seq: Indexable by polarization; transmitted bits.
        integer_seq: Indexable by polarization; transmitted integer symbols.
        para: Parameter object. Reads ``nPol``, ``front_sym_num``,
            ``bits_per_sym``, ``class_num``, ``demod_obj``, ``fig_plot`` and
            ``infor_print``. Results are written back onto it.
        **kwargs: ``plot_para``, an object with an ``err_analysis_plot``
            method. It is used only when ``para.fig_plot`` is truthy. If it
            is absent, plotting is disabled.

    Returns:
        tuple: ``(para, sig_noise, sig_int)``:
            - ``sig_noise``: the per-polarization residuals ``y - x`` after
              rescaling.
            - ``sig_int``: the per-polarization transmitted integer symbols
              that were compared.

    Raises:
        ValueError: If ``para.nPol`` is less than 1.

    Attributes written to ``para``:
        - ``ber_array``, ``c2q_array``, ``gmi_array``, ``mi_array``: length
          ``nPol + 1``. The last entry is the mean over polarizations.
        - ``b2q_array``: per-polarization Q factors. Its last entry is the Q
          factor of the mean BER.
        - ``gmi_value`` / ``mi_value``: the values of the LAST polarization,
          kept for backward compatibility.
        - ``sym_err`` and ``sym_err_idx``: per-polarization lists.

    Note that ``para.demod_obj.mode`` is left set to ``'int'``.
    """
    n_pol = para.nPol
    if n_pol < 1:
        raise ValueError("sig_rx: para.nPol must be >= 1, got {}".format(n_pol))
    n_bits = para.bits_per_sym
    front = para.front_sym_num
    pol_name = ["X", "Y"]

    ber_array = np.zeros(n_pol + 1)
    b2q_array = np.zeros(n_pol + 1)
    c2q_array = np.zeros(n_pol + 1)
    gmi_array = np.zeros(n_pol + 1)
    mi_array = np.zeros(n_pol + 1)
    sym_err, sym_err_idx, sig_int, sig_noise = [], [], [], []

    # Plotting requires a plot helper passed through kwargs; disable it otherwise.
    plot_para = kwargs.get('plot_para')
    plot_flag = bool(para.fig_plot) and 'plot_para' in kwargs

    # The symbol -> bit labelling is the same for every polarization.
    bitwise_mapping = bc.sym2bit(np.arange(0, para.class_num), n_bits)

    gmi_value = mi_value = None
    for i_p in range(n_pol):
        # Drop `front` symbols at both ends of the received block and cut the
        # transmitted sequences to the same window. (Per the original note,
        # rx_sig's own CPE padding is removed earlier, in rx_main().)
        rx_sym_num = rx_sig[i_p].shape[0] - 2 * front
        stop = front + rx_sym_num
        y = rx_sig[i_p][front:stop]
        x_int = integer_seq[i_p][front:stop]
        x = sym_map[x_int]
        bit = bit_seq[i_p].reshape(-1)[front * n_bits:stop * n_bits]

        # Received model y = scale * (x + n): estimate the gain and undo it.
        # Gains with 1/scale > 10 are rejected, and the signal is left unscaled.
        scale = calcu.optimze_scale(x, y)
        if 1 / scale > 10:
            scale = 1.0
        y = 1 / scale * y
        sig_noise.append(y - x)

        # Noise variance, effective SNR and empirical input distribution.
        _, var, _, snr, probability = calcu.noise_var_esti_normfit(x_int, y, para.class_num)

        # Bit-wise log-likelihood ratios, one row per symbol.
        c0, c1 = calcu.get_soft_data(y, n_bits, bitwise_mapping)
        llr = calcu.calcu_llr_matrix(y, n_bits, sym_map, c0, c1, var).reshape((-1, n_bits))

        # Hard bit decisions from the LLRs, then BER.
        bit_hat = para.demod_obj.get_bit(llr, is_llr=True)
        _, bit_err_rat = ber.ber(bit, bit_hat)

        # Generalized MI (bit-wise) and MI (symbol-wise).
        gmi_value = mi.generalized_mutual_information_mc(
            llr, bit.reshape((-1, n_bits)), sym_map, probability=probability)
        mi_value = mi.mutual_information_mc(
            x, y, x_int, sym_map, probability=probability, n0=var)

        # Q factor derived from BER, and the estimator's SNR (printed as dB).
        b2q = q_factor.ber2q(bit_err_rat)
        c2q = snr

        # Symbol decisions and SER.
        # NOTE(review): this sets demod_obj.mode = 'int' and never restores it,
        # so later get_bit() calls (next polarization, later runs) see it.
        # Confirm that get_bit() ignores `mode`.
        para.demod_obj.mode = 'int'
        sym_hat = para.demod_obj(y, sym_map=sym_map)
        sym_err_rat, sym_error, sym_error_idx = ser.ser(x_int, sym_hat, y)

        if plot_flag:
            plot_para.err_analysis_plot(y, x_int, sym_error, sym_error_idx, sym_map, i_p)

        gmi_array[i_p] = gmi_value
        mi_array[i_p] = mi_value
        b2q_array[i_p] = b2q
        c2q_array[i_p] = c2q
        ber_array[i_p] = bit_err_rat
        sym_err.append(sym_error)
        sym_err_idx.append(sym_error_idx)
        sig_int.append(x_int)

        if para.infor_print:
            label = pol_name[i_p] if i_p < len(pol_name) else str(i_p)
            print('{}pol: SER {:.6f} BER {:.6f} Qfactor {:.6f} dB ESNR {:.6f} dB MI {:.4f} GMI {:.4f}'.format(
                label, sym_err_rat, bit_err_rat, b2q, c2q, mi_value, gmi_value))

    # Last entry holds the mean over polarizations. The aggregate Q factor is
    # derived from the mean BER rather than averaged directly.
    for arr in (gmi_array, mi_array, c2q_array, ber_array):
        arr[-1] = arr[:n_pol].mean()
    b2q_array[-1] = q_factor.ber2q(ber_array[-1])

    para.ber_array, para.b2q_array, para.c2q_array = ber_array, b2q_array, c2q_array
    # Previously these arrays were computed and silently discarded.
    para.gmi_array, para.mi_array = gmi_array, mi_array
    # Kept for backward compatibility: values of the last polarization only.
    para.gmi_value, para.mi_value = gmi_value, mi_value
    para.sym_err = sym_err
    para.sym_err_idx = sym_err_idx
    return para, sig_noise, sig_int
