import torch
import numpy as np
import scipy
import sys
from scipy.special import erfcinv
from scipy.special import erfc
from scipy.optimize import fmin,fminbound
from sklearn.covariance import EmpiricalCovariance

def papr(x):
    '''
    Calculate the peak-to-average power ratio (PAPR) of a signal, in dB.

    Parameters:
    x: array_like or torch.Tensor
    Input signal sequence (real or complex). Torch tensors are processed with
    torch; every other input (ndarray, ndarray subclass, list, tuple, ...) is
    converted with np.asarray and processed with numpy.

    Returns:
    papr_value: float
    10 * log10(max(|x|^2) / mean(|x|^2)). A numpy float for numpy input and a
    Python float for tensor input.
    '''
    if isinstance(x, torch.Tensor):
        power = torch.abs(x) ** 2
        ratio = torch.max(power) / torch.mean(power)
        return 10 * torch.log10(ratio).item()

    # Anything that is not a tensor is handled by numpy. The previous exact
    # type check sent lists and ndarray subclasses down the torch path.
    power = np.abs(np.asarray(x)) ** 2
    ratio = np.max(power) / np.mean(power)
    return 10 * np.log10(ratio)

def ISI_induced_by_dispersion(sig_bw, sym_time, D, wavelength, length):
    """
    Estimate how many symbols are smeared together by chromatic dispersion.

    The spectral span [f0 - sig_bw, f0 + sig_bw] around the carrier is
    converted to a wavelength span, multiplied by the accumulated dispersion
    to get a time spread, and expressed in symbol periods (rounded up).

    Parameters:
    sig_bw: float
    Signal bandwidth (GHz)
    sym_time: float
    Symbol duration (ps)
    D: float
    Dispersion coefficient (ps/(nm*km))
    wavelength: float
    Signal wavelength (nm)
    length: float
    Signal transmission distance (km)

    Returns:
    num: float
    Number of symbols affected by dispersion-induced inter symbol
    interference (result of np.ceil, i.e. an integer-valued float).
    """
    light_speed = 299792458      # speed of light (m/s)
    # m/s divided by nm gives 1e9/s, i.e. the carrier frequency in GHz,
    # which matches the unit of sig_bw.
    carrier_freq = light_speed / wavelength
    lower_edge = carrier_freq - sig_bw
    upper_edge = carrier_freq + sig_bw
    # Wavelength difference (nm) between the two band edges.
    wavelength_span = light_speed * (1 / lower_edge - 1 / upper_edge)
    time_spread = np.abs(length * wavelength_span * D)
    return np.ceil(time_spread / sym_time)


def digital_freq(fft_num, sam_rate, data_mode = 'numpy', device = 'cpu'):
    """
    Return zero-centred DFT sample frequencies in ascending order.

    Unlike np.fft.fftfreq, the result starts at the most negative frequency
    (the ordering produced by an fftshift of the FFT bins).

    Given a window length `n` and a sampling rate `fs`::
    f = [-n/2, ..., -1, 0, 1, ...,   n/2-1] * fs / n   if n is even
    f = [-(n-1)/2, ..., -1, 0, 1, ..., (n-1)/2] * fs / n   if n is odd

    Parameters:
    fft_num : int
        FFT window length.
    sam_rate : scalar
        Sampling rate
    data_mode: str, optional
        'numpy' (default) returns a numpy array; any other value returns a
        torch tensor.
    device: str, optional
        Device for the torch tensor (ignored in numpy mode). The default is CPU

    Returns:
    f : array_like
        Array of length `fft_num` containing the sample frequencies.
    """
    is_even = fft_num % 2 == 0
    # Bin indices are kept as floats, exactly like the division-based bounds.
    half = fft_num / 2 if is_even else (fft_num - 1) / 2
    stop = half if is_even else half + 1

    if data_mode == 'numpy':
        bins = np.arange(-half, stop)
    else:
        bins = torch.arange(-half, stop, device=device)
    return bins * sam_rate / fft_num

def optimze_scale(x, y):
    """
    Search for the scalar `a` in [0, 2] that minimises mean(|y - a*x|^2),
    i.e. the best scaling under the model y = a(x + n).

    Parameters:
    x: array_like
    Input argument x
    y: array_like
    Input dependent variable y

    Returns:
    scale: float
    Minimiser found by the bounded scalar search (scipy fminbound)
    """
    def mean_squared_residual(candidate):
        return np.mean(np.abs(y - candidate * x) ** 2)

    return fminbound(mean_squared_residual, 0, 2)

def normfit(data, confidence=0.95, ci = 0):
    '''
    Fit a normal distribution to the input samples and return its
    parameter estimates.

    Parameters:
    data: array_like
    Input data
    confidence: float, optional
    Confidence level of the interval estimates, the default value is 95%
    ci: bool, optional
    Determines whether to output interval estimates. The default value is 0
    and not output

    Returns:
    If ci is false: (m, var)
        m: estimated mean value
        var: unbiased sample variance (ddof=1)
    If ci is true: (m, sigma, [m_lower, m_upper], [sigma_lower, sigma_upper])
        sigma: square root of the unbiased sample variance
        [m_lower, m_upper]: two-sided Student-t interval for the mean
        [sigma_lower, sigma_upper]: two-sided chi-square interval for sigma
    Note that the second element is a variance without intervals and a
    standard deviation with intervals.
    '''
    # Imported here because the module only does `import scipy`, which does
    # not by itself guarantee that the `scipy.stats` submodule is loaded.
    from scipy import stats

    n = len(data)
    m = np.mean(data)
    var = np.var(data, ddof=1) # ddof needs to be 1 to match matlab implementaton
    if not ci:
        # Point estimates only: skip the interval machinery entirely.
        return m, var

    dof = n - 1
    alpha = 1 - confidence
    se = stats.sem(data)  # standard error = std / sqrt(n)
    # For small sample populations (N < 100 or so) the quantile is taken from
    # Student's t distribution instead of the normal distribution; the
    # interval is two-sided with equal tails.
    h = se * stats.t.ppf((1 + confidence) / 2., dof)
    # Two-sided chi-square interval for the variance. The tails are not
    # symmetric, and the lower quantile yields the upper bound.
    varCI_upper = var * dof / stats.chi2.ppf(alpha / 2, dof)
    varCI_lower = var * dof / stats.chi2.ppf(1 - alpha / 2, dof)
    sigma = np.sqrt(var)
    return m, sigma, [m - h, m + h], [np.sqrt(varCI_lower), np.sqrt(varCI_upper)]

def qfunc(x):
    '''
    Gaussian Q function.

    Returns, element-wise, the tail probability of the standard normal
    distribution:
    Q(x) = 1/sqrt(2*pi) * integral from x to inf of exp(-t^2/2) dt
    evaluated through the complementary error function as
    Q(x) = 0.5 * erfc(x/sqrt(2))

    Parameters:
    x: array_like
    Real input value(s)

    Returns:
    y: array_like
    Q(x), i.e. one minus the standard normal cumulative distribution function
    '''
    scaled = x / np.sqrt(2)
    return 0.5 * erfc(scaled)

def calcu_entropy(p):
    '''
    Calculate the information (Shannon) entropy, in bits, of a probability
    distribution: h = -sum(p * log2(p)).

    Zero-probability entries contribute nothing (the usual 0*log2(0) = 0
    convention) instead of turning the whole result into NaN.

    Parameters:
    p: array_like
    Input probability distribution

    Returns:
    h: float
    Calculated information entropy
    '''
    p = np.asarray(p)
    # Only exact zeros are dropped; negative (invalid) entries still
    # propagate as NaN so bad input is not silently hidden.
    support = p[p != 0]
    h = - np.sum(support * np.log2(support))
    return h

def _gray_pam_bit_error_sum(levels, q_step, EbNoLin):
    '''
    Sum, over the log2(levels) bit positions of one amplitude axis with
    `levels` points, of the per-bit error probability series
    2/levels * sum_j (+/-1) * weight_j * Q((2j + 1) * q_step).
    The caller divides the result by the number of bits per symbol.
    '''
    total = np.zeros_like(EbNoLin)
    for i in range(1, int(np.log2(levels)) + 1):
        berk = np.zeros_like(EbNoLin)
        for j in range(int((1 - 2 ** (-i)) * levels)):
            sign = (-1) ** np.floor(j * 2 ** (i - 1) / levels)
            weight = 2 ** (i - 1) - np.floor(j * 2 ** (i - 1) / levels + 1 / 2)
            berk = berk + sign * weight * qfunc((2 * j + 1) * q_step)
        total = total + berk * 2 / levels
    return total


def berawgn(EbNo, mode, M):
    '''
    Bit error rate (BER) and symbol error rate (SER) for uncoded PAM or QAM
    over an AWGN channel with coherent demodulation.

    Parameters:
    EbNo: float or ndarray
    bit energy to noise power spectral density ratio (in dB)
    mode: str
    modulation type, either 'pam' or 'qam'
    M: int
    alphabet size, must be a positive integer power of 2

    Returns:
    ber: float or ndarray
    bit error rate
    ser: float or ndarray
    symbol error rate

    Raises:
    ValueError
    If mode is neither 'pam' nor 'qam', or if mode is 'pam' and M < 2.

    Notes:
    The BER series is the one used for Gray-mapped amplitude modulation
    (see Cho & Yoon, "On the general BER expression of one- and
    two-dimensional amplitude modulations", IEEE Trans. Commun., 2002).
    '''
    if M > 1:
        k = np.log2(M)
    else:
        k = 0.5
    EbNoLin = 10 ** (EbNo / 10)

    if mode == 'pam':
        if M < 2:
            raise ValueError("PAM requires an alphabet size M >= 2")
        q_step = np.sqrt(6 * k * EbNoLin / (M ** 2 - 1))
        ser = 2 * (M - 1) / M * qfunc(q_step)
        ber = _gray_pam_bit_error_sum(M, q_step, EbNoLin) / k
        return ber, ser

    if mode != 'qam':
        # Falling off the end used to return None, which only failed later
        # when the caller unpacked the result.
        raise ValueError(f"unsupported mode {mode!r}; expected 'pam' or 'qam'")

    if np.ceil(k / 2) == k / 2:
        # k is even - square QAM
        root_m = np.sqrt(M)
        q_sq = qfunc(np.sqrt(3 * k / (M - 1) * EbNoLin))
        ser = 4 * (root_m - 1) / root_m * q_sq \
            - 4 * ((root_m - 1) / root_m) ** 2 * q_sq ** 2
        if M == 4:
            ber = qfunc(np.sqrt(2 * EbNoLin))
        elif M == 16:
            ber = 3 / 4 * qfunc(np.sqrt(4 / 5 * EbNoLin)) \
                + 1 / 2 * qfunc(3 * np.sqrt(4 / 5 * EbNoLin)) \
                - 1 / 4 * qfunc(5 * np.sqrt(4 / 5 * EbNoLin))
        elif M == 64:
            ber = 7 / 12 * qfunc(np.sqrt(2 / 7 * EbNoLin)) \
                + 1 / 2 * qfunc(3 * np.sqrt(2 / 7 * EbNoLin)) \
                - 1 / 12 * qfunc(5 * np.sqrt(2 / 7 * EbNoLin)) \
                + 1 / 12 * qfunc(9 * np.sqrt(2 / 7 * EbNoLin)) \
                - 1 / 12 * qfunc(13 * np.sqrt(2 / 7 * EbNoLin))
        else:
            # General square QAM: each axis carries sqrt(M) amplitude levels.
            q_step = np.sqrt(6 * k * EbNoLin / (2 * (M - 1)))
            ber = _gray_pam_bit_error_sum(root_m, q_step, EbNoLin) / np.log2(root_m)
    else:
        # k is odd - rectangular QAM with I x J points
        I = 2 ** (np.ceil(np.log2(M) / 2))
        J = 2 ** (np.floor(np.log2(M) / 2))
        q_step = np.sqrt(6 * np.log2(I * J) * EbNoLin / (I ** 2 + J ** 2 - 2))
        if M == 8:
            q8 = qfunc(np.sqrt(k * EbNoLin / 3))
            ser = 5 / 2 * q8 - 3 / 2 * q8 ** 2
        else:
            q_rect = qfunc(q_step)
            ser = (4 * I * J - 2 * I - 2 * J) / M * q_rect \
                - 4 / M * (1 + I * J - I - J) * q_rect ** 2

        berI = _gray_pam_bit_error_sum(I, q_step, EbNoLin)
        berJ = _gray_pam_bit_error_sum(J, q_step, EbNoLin)
        ber = (berI + berJ) / np.log2(I * J)
    return ber, ser

def get_average_power(M):
    '''
    Get the average power of a QAM constellation with minimum distance of 2 between
    points, and modulation order M. M must be an integer power of 2. Note that this
    internal function does not perform any input validation. It is user's responsibility
    to ensure that the above conditions are met.

    Parameters:
    M: int
    Modulation order of QAM modulation format

    Returns:
    averagePower: float
    Average power of QAM constellation
    (e.g. 2 for M=4, 10 for M=16, 20 for M=32, 42 for M=64)
    '''
    n_bits = int(np.log2(M))
    if M == 2 or M == 8:
        # Special-cased small constellations.
        average_power = ((5 * M / 4) - 1) * 2 / 3
    elif n_bits % 2 != 0:
        # Odd number of bits per symbol: cross QAM (M = 32, 128, ...)
        average_power = ((31 * M / 32) - 1) * 2 / 3
    else:
        # Even number of bits per symbol: square QAM (M = 4, 16, 64, ...)
        average_power = (M - 1) * 2 / 3
    return average_power

def noise_var_esti_func(x, y):
    '''
    Estimate the scaling coefficient between the transmitted and received
    signal, then the power of the residual left after undoing that scaling.

    Parameters:
    x: array_like
    Input transmission signal sequence
    y: array_like
    Input received signal sequence

    Returns:
    n0: float
    Noise power estimated, mean(|y/scale - x|^2)
    scale: float
    Scaling factor estimated by optimze_scale
    '''
    scale = optimze_scale(x, y)
    residual = y / scale - x
    return np.mean(np.abs(residual) ** 2), scale

def noise_var_esti(x, y):
    '''
    Directly estimate noise power, received-signal mean, signal power and
    signal-to-noise ratio from a transmitted/received pair, treating y - x
    as the noise.

    Parameters:
    x: array_like
    Input transmission signal sequence
    y: array_like
    Input received signal sequence

    Returns:
    n_p: float
    Noise power estimated, mean(|y - x|^2)
    m: float
    Estimated received signal mean value
    s_p: float
    Signal power estimated, mean(|x|^2)
    snr: float
    Signal to noise ratio in dB
    '''
    def mean_power(samples):
        return np.mean(np.abs(samples) ** 2)

    received_mean = np.mean(y)
    signal_power = mean_power(x)
    noise_power = mean_power(y - x)
    snr_db = 10 * np.log10(signal_power / noise_power)
    return noise_power, received_mean, signal_power, snr_db

def noise_var_esti_normfit(x_int, y, M):
    r"""
    Group the received samples by transmitted symbol index, fit a normal
    distribution to each group with normfit, and combine the per-symbol
    estimates weighted by the empirical symbol probabilities.

    Parameters:
    x_int: array_like
    Integer symbol sequence corresponding to the transmitted signal
    y: array_like
    Received signal sequence
    M: int
    Modulation order of QAM signal

    Returns:
    n0: float
    Noise power estimated (probability-weighted sum of the variances)
    var: array_like
    The estimated noise variance for each symbol index, length M
    sig_p: float
    Signal power estimated (probability-weighted sum of |mean|^2)
    snr: float
    Signal to noise ratio in dB
    probability: array_like
    Statistical signal probability distribution, length M
    """
    var = np.zeros(M)
    centers = np.zeros(M) + 0j
    probability = np.zeros(M)
    total = y.shape[0]

    for sym in range(M):
        rows = np.where(x_int == sym)[0]
        cluster = y[rows]
        probability[sym] = cluster.shape[0] / total
        centers[sym], var[sym] = normfit(cluster)

    n0 = np.sum(var * probability)
    sig_p = np.sum(np.abs(centers)**2 * probability)
    return n0, var, sig_p, 10*np.log10(sig_p / n0), probability

def noise_var_non_circular_gaussian(x_int, y, M):
    r"""
    Fit a bivariate (real/imaginary) Gaussian to the received samples of
    each transmitted symbol, so that non-circular noise is captured by a
    full 2x2 covariance per symbol.

    Parameters:
    x_int: array_like
    Integer symbol sequence corresponding to the transmitted signal
    y: array_like
    Received (complex) signal sequence
    M: int
    Modulation order of QAM signal

    Returns:
    n0: float
    Noise power estimated: probability-weighted sum of the covariance
    traces (real variance + imaginary variance)
    cov_hat: ndarray, shape (M, 2, 2)
    Covariance estimate for each symbol index
    sig_p: float
    Signal power estimated: probability-weighted sum of the squared norms
    of the per-symbol means
    snr: float
    Signal to noise ratio in dB
    probability: ndarray, shape (M,)
    Statistical signal probability distribution
    """
    n0 = 0.0
    sig_p = 0.0
    # One 2x2 covariance and one (real, imag) mean per symbol. The shape must
    # be passed as a tuple; a second positional argument is the dtype.
    cov_hat = np.zeros((M, 2, 2))
    mean_hat = np.zeros((M, 2))
    probability = np.zeros(M)
    num = y.shape[0]
    for i in range(M):
        idx = np.where(x_int == i)
        mth_sym = y[idx[0]]
        probability[i] = mth_sym.shape[0] / num
        if mth_sym.shape[0] == 0:
            # Symbol never transmitted: nothing to fit, and with zero
            # probability it contributes nothing to the weighted sums.
            continue
        x = np.concatenate((mth_sym.real[:,None], mth_sym.imag[:,None]), axis=1)
        cov = EmpiricalCovariance().fit(x)  # maximum-likelihood estimate
        mean_hat[i] = cov.location_
        cov_hat[i] = cov.covariance_
        # Accumulate scalars from this symbol only: total noise power is the
        # trace of its covariance, signal power the squared norm of its mean.
        n0 = n0 + np.trace(cov_hat[i]) * probability[i]
        sig_p = sig_p + np.sum(mean_hat[i] ** 2) * probability[i]

    snr = 10*np.log10(sig_p / n0)
    return n0, cov_hat, sig_p, snr, probability

def empirical_distribution(x_int, M):
    '''
    Compute the empirical probability of each symbol index 0..M-1 as its
    relative frequency in the transmitted symbol sequence.

    Parameters:
    x_int: array_like
    Integer symbol sequence corresponding to the transmitted signal
    M: int
    Modulation order of QAM signal

    Returns:
    probability: array_like
    Statistical empirical probability distribution, length M
    '''
    total = x_int.shape[0]
    probability = np.zeros(M)
    for sym in range(M):
        # Number of positions where this symbol index occurs.
        occurrences = np.where(x_int == sym)[0].size
        probability[sym] = occurrences / total
    return probability
       

def get_soft_data(y, n_bits, bitwise_mapping):
    '''
    Build, for every bit position, the lists of constellation indices whose
    label has a 0 (c0) or a 1 (c1) at that position. These index tables are
    used afterwards to compute log likelihood ratios (llr).

    Parameters:
    y: array_like
    Received signal sequence (not used by this function)
    n_bits: int
    Bits per symbol
    bitwise_mapping: array_like
    Bit data of constellation points, one row per point and one column per bit

    Returns:
    c0: array_like
    int16 array of shape (2**(n_bits-1), n_bits); column i holds the
    indices of the points whose bit i is 0
    c1: array_like
    Same layout for the points whose bit i is 1
    '''
    def index_table(bit_value):
        # Row/column coordinates of every label entry equal to bit_value.
        hits = np.where(bitwise_mapping == bit_value)
        point_idx, bit_pos = hits[0], hits[1]
        table = np.zeros((2**(n_bits-1), n_bits))
        for pos in range(n_bits):
            table[:, pos] = point_idx[np.where(bit_pos == pos)[0]]
        return table

    zeros_table = index_table(0)
    ones_table = index_table(1)
    return zeros_table.astype(np.int16), ones_table.astype(np.int16)

def calcu_llr_approx(y, M, symbolOrder, symbolOrderVector, unitAveragePower, noiseVar):
    '''
    Placeholder, not implemented yet: ignores its arguments and returns None.
    Judging by the name it is presumably meant to compute an approximate
    log likelihood ratio - TODO confirm and implement.
    '''
    return None

def calcu_llr(y, n_bits, sym_map, c0, c1, noise_var, ** kwargs):
    '''
    Calculate the bit-wise log likelihood ratios of the received samples,
    which can be used for the subsequent calculation of generalized mutual
    information. For sample y and bit b:
    llr = log( sum_{s in c0[:, b]} p_s * exp(-|s - y|^2 / noise_var)
             / sum_{s in c1[:, b]} p_s * exp(-|s - y|^2 / noise_var) )

    Parameters:
    y: array_like
    Received signal sequence
    n_bits: int
    Bits per symbol
    sym_map: array_like
    Constellation corresponding to the received signal
    c0: array_like
    Indices of mapping which has 0 at various bit positions
    c1: array_like
    Indices of mapping which has 1 at various bit positions
    noise_var: float
    Noise variance obtained from the received signal
    probability (keyword, optional): array_like
    Prior weight of each constellation point, default all ones. The
    historical misspelling `probablity` is still accepted and takes
    precedence when both are given.

    Returns:
    llr: array_like
    Calculated log likelihood ratio, flat array of length len(y) * n_bits
    ordered sample by sample
    '''
    # Accept both spellings: a correctly spelled `probability=` used to be
    # swallowed by **kwargs and silently replaced by the uniform default.
    prior = kwargs.get("probablity")
    if prior is None:
        prior = kwargs.get("probability")
    if prior is None:
        prior = np.ones(sym_map.shape[0])
    prior = np.asarray(prior)

    tiny = sys.float_info.min
    num = y.shape[0]
    llr = np.zeros(num * n_bits)
    for i_y in range(num):
        for i_b in range(n_bits):
            idx0 = c0[:, i_b]
            idx1 = c1[:, i_b]
            d0 = np.abs(sym_map[idx0] - y[i_y]) ** 2
            d1 = np.abs(sym_map[idx1] - y[i_y]) ** 2
            logits_0 = np.sum(np.exp(-d0 / noise_var) * prior[idx0])
            logits_1 = np.sum(np.exp(-d1 / noise_var) * prior[idx1])
            # Guard against exp() underflowing to exactly zero, which would
            # give log(0) or a division by zero.
            if logits_1 == 0:
                logits_1 = tiny
            if logits_0 == 0:
                logits_0 = tiny
            llr[i_y * n_bits + i_b] = np.log(logits_0 / logits_1)

    return llr

def calcu_llr_matrix(y, n_bits, sym_map, c0, c1, noise_var, ** kwargs):
    '''
    Calculate the bit-wise log likelihood ratios of all received samples at
    once with broadcasting (vectorised counterpart of calcu_llr).

    Parameters:
    y: array_like
    Received signal sequence
    n_bits: int
    Bits per symbol
    sym_map: array_like
    Constellation corresponding to the received signal
    c0: array_like
    Indices of mapping which has 0 at various bit positions
    c1: array_like
    Indices of mapping which has 1 at various bit positions
    noise_var: float or ndarray
    Noise variance obtained from the received signal. An ndarray is
    interpreted as one variance per constellation point and is indexed
    with c0 / c1.
    probability (keyword, optional): array_like
    Prior weight of each constellation point, default all ones. The
    historical misspelling `probablity` is still accepted and takes
    precedence when both are given.

    Returns:
    llr: array_like
    Calculated log likelihood ratio, shape (len(y), number of columns of c0)
    '''
    # Accept both spellings: a correctly spelled `probability=` used to be
    # swallowed by **kwargs and silently replaced by the uniform default.
    prior = kwargs.get("probablity")
    if prior is None:
        prior = kwargs.get("probability")
    if prior is None:
        prior = np.ones(sym_map.shape[0])
    prior = np.asarray(prior)

    num = y.shape[0]
    # (num, 1, 1) against index tables of shape (points, bits) broadcasts to
    # (num, points, bits).
    y = y.reshape((num, 1, 1))
    d0 = np.abs(y - sym_map[c0]) ** 2
    d1 = np.abs(y - sym_map[c1]) ** 2
    p0 = prior[c0].reshape((1, c0.shape[0], -1))
    p1 = prior[c1].reshape((1, c1.shape[0], -1))
    if isinstance(noise_var, np.ndarray) and noise_var.ndim > 0:
        var0 = noise_var[c0]
        var1 = noise_var[c1]
    else:
        var0 = var1 = noise_var
    logits_0 = np.sum(np.exp(-d0 / var0) * p0, axis = 1)
    logits_1 = np.sum(np.exp(-d1 / var1) * p1, axis = 1)
    # Guard against exp() underflowing to exactly zero.
    tiny = sys.float_info.min
    logits_1 = np.where(logits_1==0, tiny, logits_1)
    logits_0 = np.where(logits_0==0, tiny, logits_0)
    llr = np.log(logits_0 / logits_1)

    return llr

def calcu_gmi_gh():
    '''
    Placeholder, not implemented yet: takes no arguments and returns None.
    Presumably reserved for a Gauss-Hermite based GMI calculation, judging
    by the name - TODO confirm and implement.
    '''
    return None

def calcu_gmi_mc():
    '''
    Placeholder, not implemented yet: takes no arguments and returns None.
    Presumably reserved for a Monte-Carlo based GMI calculation, judging by
    the name - TODO confirm and implement.
    '''
    return None

def calcu_gmi_integration():
    '''
    Placeholder, not implemented yet: takes no arguments and returns None.
    Presumably reserved for a numerical-integration based GMI calculation,
    judging by the name - TODO confirm and implement.
    '''
    return None

def calcu_mi_gh(constell, bit_maping, snr, prob = None):
    """
    Evaluation of mutual information using Gauss-Hermite quadrature
    This function evaluates the mutual information (MI) and the generalized
    mutual information (GMI) of the 2D (complex) constellation C over an
    AWGN channel with standard deviation of the complex-valued noise equal
    to sigma_n. Evaluation is performed using the Gauss-Hermite quadrature
    which allows fast numerical integration of the mutual information. For
    GMI calculation the bit mapping B is used.

    taken from A. Alvarado et al., "Achievable Information Rates
    for Fiber Optics: Applications and Computations", J. Lightw. Technol. 36(2) pp. 424-439
    https://dx.doi.org/10.1109/JLT.2017.2786351
    Parameters of Gauss-Hermite taken from: http://keisan.casio.com/exec/system/1281195844

    Parameters:
        constell   :=      Constellation, ndarray with M points (reshaped to [M x 1])
        bit_maping :=      Bit mapping, ndarray [M x log2(M)] of 0/1 labels
        snr        :=      Signal-to-noise ratio (Es/No, dB unit); the result is
                           indexed as a single value, so a scalar is expected
        prob       :=      Probability of each constellation point (optional),
                           any array-like with M entries; uniform if omitted

    Returns:
        MI      :=      Mutual information in bits
        GMI     :=      Generalized mutual information in bits

    February 2018 - Dario Pilori
    """
    M = constell.shape[0]
    constell = constell.reshape(-1,1)
    # `prob == None` is an element-wise comparison for arrays and made the
    # `if` raise as soon as a probability vector was actually supplied.
    if prob is None:
        prob = np.ones((M, 1)) / M  # Uniform distribution
    else:
        prob = np.asarray(prob, dtype=float).reshape(-1, 1)

    # Noise standard deviation from the constellation's average power and the SNR
    sigma_n = np.sqrt(np.sum(prob * np.abs(constell)**2)) * 10 ** (- snr / 20)

    # Nodes (x) and weights (w) of the 10-point Gauss-Hermite quadrature
    x = np.array([
        -3.436159118837737603327,
        -2.532731674232789796409,
        -1.756683649299881773451,
        -1.036610829789513654178,
        -0.3429013272237046087892,
        0.3429013272237046087892,
        1.036610829789513654178,
        1.756683649299881773451,
        2.532731674232789796409,
        3.436159118837737603327]).reshape(-1,1)
    w = np.array([
        7.64043285523262062916*(10**-6),
        0.001343645746781232692202,
        0.0338743944554810631362,
        0.2401386110823146864165,
        0.6108626337353257987836,
        0.6108626337353257987836,
        0.2401386110823146864165,
        0.03387439445548106313617,
        0.001343645746781232692202,
        7.64043285523262062916*(10**-6)]).reshape(-1,1)

    w_flat = w.reshape(-1)
    prob_row = np.transpose(prob)          # [1 x M]
    constell_row = np.transpose(constell)  # [1 x M]

    # Evaluate Mutual Information: the 2D integral over the complex noise is
    # a double quadrature (outer index m, inner dimension vectorised over x).
    MI = 0.0
    for l in range(M):
        diff = constell[l] - constell_row
        for m in range(x.shape[0]):
            expo = -(np.abs(diff) ** 2
                     - 2 * sigma_n * np.real((x + 1j * x[m]) * diff)) / sigma_n ** 2
            inner = np.sum(prob_row * np.exp(expo), axis = 1)
            MI = MI - prob[l] / np.pi * w[m] * np.sum(w_flat * np.log2(inner))

    # Evaluate Generalized Mutual Information
    # To be optimized...
    GMI = 0.0
    for k in range(int(np.log2(M))):
        for b in [0, 1]:
            # Sub-constellation (and its probabilities) whose k-th bit equals b
            members = np.where(bit_maping[:,k] == b)
            I = constell[members]
            PI = prob[members]
            PIs = np.sum(PI)
            I_row = np.transpose(I)
            PI_row = np.transpose(PI)
            for i in range(I.shape[0]):
                diff_all = I[i] - constell_row
                diff_sub = I[i] - I_row
                for l in range(x.shape[0]):
                    z = x + 1j * x[l]
                    full = np.sum(prob_row * np.exp(
                        -(np.abs(diff_all) ** 2
                          - 2 * sigma_n * np.real(z * diff_all)) / sigma_n ** 2), 1)
                    sub = np.sum(PI_row * np.exp(
                        -(np.abs(diff_sub) ** 2
                          - 2 * sigma_n * np.real(z * diff_sub)) / sigma_n ** 2), 1)
                    GMI = GMI - w[l] * PI[i] / np.pi * np.sum(
                        w_flat * np.log2(full / sub * PIs))

    # Marginal bit probabilities: first P(bit_k = 0) for all k, then P(bit_k = 1)
    Pb = np.sum(prob * np.concatenate(((bit_maping+1)%2, bit_maping), axis = 1), axis = 0)
    Pb = Pb.reshape((-1, 2))
    GMI = GMI - prob_row @ np.log2(prob) + \
        np.sum(np.sum(Pb * np.log2(Pb)))

    return MI[0], GMI[0,0]

def calcu_mi_mc():
    '''
    Placeholder, not implemented yet: takes no arguments and returns None.
    Presumably reserved for a Monte-Carlo based MI calculation, judging by
    the name - TODO confirm and implement.
    '''
    return None

def calcu_mi_integration():
    '''
    Placeholder, not implemented yet: takes no arguments and returns None.
    Presumably reserved for a numerical-integration based MI calculation,
    judging by the name - TODO confirm and implement.
    '''
    return None

