import numpy as np
from IFTS.fiber_simulation.comm_tools import calculation as calcu
from IFTS.fiber_simulation.comm_tools import base_conversion as bc

def mutual_information_mc(x, y, x_int, constellations, n0, probability, ** kwargs):
    r'''
    Estimate the symbol-wise memoryless mutual information (MI) by Monte Carlo.

    A circularly symmetric Gaussian auxiliary channel
    q(y|c) = 1/(pi*N0) * exp(-|y - c|^2 / N0) is assumed. If the true channel
    is not such a Gaussian channel, the result is a lower bound on the MI,
    i.e. an achievable information rate. For a reliable estimate, the noise
    variance N0 should first be estimated from the samples; the MI of other
    symbol sequences can then be estimated with that variance.

    All densities are evaluated in the log domain (log-sum-exp), so high-SNR
    inputs do not underflow to 0/0.

    Parameters:
    x: array_like
        Transmitted complex symbol sequence, length N.
    y: array_like
        Received complex symbol sequence, length N.
    x_int: array_like
        Integer index (into ``constellations``) of each transmitted symbol,
        length N.
    constellations: array_like
        Signal constellation, M complex points.
    n0: float or array_like
        Noise variance. Either a scalar shared by all points, or an array of
        length M giving the variance for each constellation point. In the
        array case, q(y|c_j) uses n0[j] for every point j, both in the
        numerator and in q(y) = sum_j p_j q(y|c_j). Symbols that never occur
        in ``x_int`` are skipped.
    probability: array_like
        Prior probability of each of the M constellation points.

    Returns:
    mi: float
        Estimated mutual information in bits per symbol.
    '''
    x = np.asarray(x)
    y = np.asarray(y)
    x_int = np.asarray(x_int)
    constellations = np.asarray(constellations)
    probability = np.asarray(probability, dtype=np.float64)
    # Zero-probability points give log(0) = -inf; they then drop out of the
    # log-sum-exp below, which is the intended behaviour.
    with np.errstate(divide='ignore'):
        log_prob = np.log(probability)

    if np.ndim(n0) > 0:
        n0 = np.asarray(n0, dtype=np.float64)
        # log(p_j / (pi * n0_j)) for every constellation point, hoisted out of the loop.
        log_weight = log_prob - np.log(np.pi * n0)
        mi = 0.0
        for i in range(constellations.shape[0]):
            idx = np.flatnonzero(x_int == i)
            if idx.size == 0:
                # No samples for this symbol: the mean would be nan, so skip it.
                continue
            log_q_y_given_x = (-np.log(np.pi * n0[i])
                               - np.abs(y[idx] - x[idx]) ** 2 / n0[i])
            log_q_y = np.logaddexp.reduce(
                log_weight[None] - np.abs(y[idx, None] - constellations) ** 2 / n0[None],
                axis=-1)
            mi += np.mean(log_q_y_given_x - log_q_y) * probability[i]
    else:
        # With a shared variance the 1/(pi*n0) normalisation cancels between
        # q(y|x) and q(y), leaving only the exponents.
        d_tx = np.abs(y - x) ** 2 / n0
        d_all = np.abs(y[:, None] - constellations) ** 2 / n0
        mi = np.mean(-d_tx - np.logaddexp.reduce(log_prob[None] - d_all, axis=-1))
    # Natural log -> bits.
    return mi / np.log(2)

def generalized_mutual_information_mc(llr, tx_bit, sym_map, probability, ** kwargs):
    r"""
    Estimate the generalized mutual information (GMI) by Monte Carlo.

    GMI is the sum of the bitwise memoryless mutual informations, computed
    from log-likelihood ratios under circularly symmetric Gaussian noise
    statistics. Unlike the symbol-based MI, it is the achievable rate of a
    receiver with binary decoding and no iteration between demapper and
    decoder.

    Each bit contributes -mean(log2(1 + exp((2*b - 1) * llr))). From that
    formula, a positive LLR favours bit value 0, i.e. llr = ln(P(b=0)/P(b=1)).
    The per-bit term is evaluated as a numerically stable softplus, so very
    large LLR magnitudes do not overflow.

    Parameters:
    llr: array_like
        Log-likelihood ratios, shape (N, n_bits) or wider. Only the first
        n_bits columns are used.
    tx_bit: array_like
        Transmitted bits (0/1), same layout as ``llr``.
    sym_map: array_like
        Constellation / symbol map. Its length M sets n_bits = log2(M).
    probability: array_like
        Probability distribution of the transmitted symbols, used for the
        source entropy (presumably in bits; computed by calcu.calcu_entropy).

    Returns:
    gmi: float
        Estimated GMI in bits per symbol, clamped below at 0.
    """
    n_bits = int(np.log2(np.asarray(sym_map).shape[0]))
    llr = np.asarray(llr, dtype=np.float64)[:, :n_bits]
    tx_bit = np.asarray(tx_bit, dtype=np.float64)[:, :n_bits]
    # log(1 + exp(z)) == logaddexp(0, z); this avoids exp overflow for large z.
    signed_llr = (2 * tx_bit - 1) * llr
    mi_per_bit = -np.mean(np.logaddexp(0, signed_llr), axis=0) / np.log(2)
    tx_entropy = calcu.calcu_entropy(probability)
    gmi = tx_entropy + np.sum(mi_per_bit)
    # A negative estimate is not a meaningful rate; report 0 instead.
    gmi = max(gmi, 0)

    return gmi




    