import torch 
import numpy as np
# import os
from . constellations import qammap, gs_qammap
from ... comm_tools import base_conversion as conver
from ... comm_tools.normalization import norm_1d



def nn(encoder, condi, bits_per_sym):
    """
    Use the trained encoder neural network to obtain the constellation of the modulated signal.

    Parameters:
    encoder: callable
    Neural network model of the encoder. It is called as ``encoder(bits, condi)`` and its
    output is indexed as ``out[:, 0]`` / ``out[:, 1]``.
    condi: torch.Tensor
    Neural network input conditions related to the current optical fiber transmission
    conditions. Must expose a ``.device`` attribute (i.e. be a torch tensor), because the
    bit labels fed to the encoder are moved onto that device.
    bits_per_sym: int
    Number of bits per symbol; the bit labels of all 2**bits_per_sym symbol indices
    are fed to the encoder in one batch.

    Returns:
    constellation: array_like
    Complex constellation points built from the encoder output, column 0 as the real
    part and column 1 as the imaginary part (a complex torch tensor when the encoder
    returns a tensor).
    """
    num_symbols = 2 ** bits_per_sym
    # Bit labels of every symbol index 0 .. num_symbols - 1, natural binary labelling.
    bits = conver.sym2bit(np.arange(0, num_symbols), bits_per_sym, 'nature')
    # Cast to float64 and place the labels on the same device as the conditions.
    bits = torch.from_numpy(bits).double().to(condi.device)
    sym_map = encoder(bits, condi)
    # Combine the two real output columns into complex constellation points.
    constellation = sym_map[:, 0] + 1j * sym_map[:, 1]
    return constellation

def pam(M, mode = 'single', codemode = 'gray'):
    '''
    Obtain the constellation of a pulse amplitude modulation (PAM) signal from the bit number.

    Parameters:
    M: int
    Number of bits per symbol; the constellation has K = 2**M levels
    mode: str, optional
    Polarity of the levels. 'single' (default) returns the unipolar levels produced by
    the bit-to-symbol mapping. Any other value (e.g. 'bipolar') is treated as bipolar:
    each level y is mapped to 2*y - (K - 1), which is presumably symmetric about zero
    when y spans 0 .. K-1.
    codemode: str, optional
    Encoding passed to ``bit2sym`` when mapping bit labels to levels
    (natural binary or gray; 'gray' by default)

    Returns:
    constellation: array_like
    Flattened (1-D) array of the PAM constellation levels, one per symbol index 0 .. K-1
    '''
    num_levels = 2 ** M
    # Natural-binary bit labels of every symbol index 0 .. K-1.
    bits = conver.sym2bit(np.arange(0, num_levels), M, 'nature')
    # Re-map the bit labels to amplitude levels under the requested encoding.
    levels = conver.bit2sym(bits, M, codemode).reshape(-1)
    if mode == 'single':
        # Unipolar: keep the levels as produced.
        return levels
    # Any mode other than 'single' falls through to the bipolar mapping.
    return 2 * levels - (num_levels - 1)

def qam(M, gs = 0):
    '''
    Build the normalized constellation of a quadrature amplitude modulation (QAM) signal.

    Parameters:
    M: int
    Number of bits per symbol, forwarded to ``qammap``
    gs: bool, optional
    Whether to use a geometrically shaped constellation (default 0, i.e. off).
    Geometric shaping is not supported in the current version.

    Returns:
    constellation: array_like
    Complex QAM constellation points, normalized with ``norm_1d(..., mean = 1)``

    Raises:
    ValueError
    'does not support the gs constellations now' -- raised whenever ``gs`` is truthy
    '''
    # Guard clause: geometric shaping is not implemented yet.
    if gs:
        raise ValueError('does not support the gs constellations now')

    points = qammap(M)
    # Column 0 becomes the real part and column 1 the imaginary part.
    complex_points = points[:, 0] + 1j * points[:, 1]
    return norm_1d(complex_points, mean = 1)
