import torch
import numpy as np


def bit2sym(x, M, mode = 'nature'):
    '''
    Base conversion, converting a binary bit sequence to a decimal symbol sequence.

    Bits are consumed in row-major order, M at a time, with the most significant
    bit first.

    Parameters:
    x: array_like
    Input bit sequence to be converted (values 0/1). Any shape is accepted; the
    input is flattened before grouping.
    M: int
    Bits per symbol, must be >= 1
    mode: str, optional
    Encoding mode, 'nature' (natural binary coding, default) or 'gray'
    (gray coding). In gray mode each group of bits is read as a natural binary
    index n and the returned symbol is the value of the n-th reflected Gray
    codeword, i.e. n ^ (n >> 1).

    Returns:
    sym_seq: ndarray of int
    Symbol sequence obtained after base conversion. The shape is
    (number_of_symbols, 1), except when the input has more than one dimension
    and more than one column, in which case the result is reshaped to
    (x.shape[0], -1).

    Raises:
    ValueError
    'bit2sym error: input bits number does not match bits number per symbol' if the
    length of the input sequence is not an integer multiple of the bits per symbol.
    A ValueError is also raised when M < 1 or when mode is not 'nature' or 'gray'.
    '''
    if mode not in ('nature', 'gray'):
        raise ValueError("bit2sym error: mode must be 'nature' or 'gray'")
    if M < 1:
        raise ValueError('bit2sym error: bits number per symbol must be a positive integer')

    x = np.asarray(x)
    size = x.shape
    bits = x.reshape(-1)
    if bits.shape[0] % M != 0:
        raise ValueError('bit2sym error: input bits number does not match bits number per symbol')

    sym_num = bits.shape[0] // M
    # Positional weights 2**(M-1), ..., 2, 1 so that the first bit of each group is the MSB.
    weights = 2 ** np.arange(M - 1, -1, -1)
    sym = (bits.reshape(sym_num, M) @ weights).reshape(sym_num, 1).astype(int)

    if mode == 'gray':
        # The n-th codeword of the reflected Gray code is n XOR (n >> 1); this
        # replaces building the whole codeword table and looking symbols up in it.
        sym = sym ^ (sym >> 1)

    if len(size) > 1 and size[1] > 1:
        # Multi-column input: keep the leading dimension of the input.
        sym = sym.reshape((size[0], -1))
    return sym

def sym2bit(sym, M, mode = 'nature', out = 'parall'):
    """
    Base conversion, converting a decimal symbol sequence to a binary bit sequence.

    Each symbol is expanded into M bits, most significant bit first.

    Parameters:
    sym: array_like
    Input symbol sequence to be converted (non-negative values below 2 ** M)
    M: int
    Bits per symbol, must be >= 1
    mode: str, optional
    Encoding mode, 'nature' (natural binary coding, default) or 'gray'
    (gray coding). In gray mode the symbol value is treated as a Gray codeword
    and decoded back to the natural binary bits, which makes this the inverse
    of bit2sym(..., mode='gray').
    out: str, optional
    Output mode of the bit sequence. 'seq' returns a flat (serial) 1-D sequence;
    any other value, including the default 'parall', returns a parallel 2-D
    array of shape (sym.shape[0], -1).

    Returns:
    bit_seq: ndarray
    Bit sequence obtained after base conversion. In 'nature' mode the dtype
    follows the input; in 'gray' mode the result is float64.

    Raises:
    ValueError
    If M < 1 or mode is not 'nature' or 'gray'.
    """
    if mode not in ('nature', 'gray'):
        raise ValueError("sym2bit error: mode must be 'nature' or 'gray'")
    if M < 1:
        raise ValueError('sym2bit error: bits number per symbol must be a positive integer')

    sym = np.atleast_1d(sym)
    size = sym.shape

    # Repeated division by 2 yields the bits LSB first; collect them on a new
    # trailing axis and reverse the order so the MSB comes first.
    quotient = sym.reshape(size + (1,))
    remainders = []
    for _ in range(M):
        quotient, remainder = np.divmod(quotient, 2)
        remainders.append(remainder)
    bits = np.concatenate(remainders[::-1], axis=-1)

    if mode == 'gray':
        # Gray code to natural binary:
        #   1. the most significant bit is unchanged: n[0] = g[0]
        #   2. every other bit is the Gray bit XOR the previously decoded
        #      natural bit: n[i] = g[i] XOR n[i-1]
        # Unrolled, n[i] is the parity of g[0..i], i.e. a running sum modulo 2.
        # float64 is kept as the return dtype of this mode for backward compatibility.
        bits = (np.cumsum(bits, axis=-1) % 2).astype(np.float64)

    if out == 'seq':
        return bits.reshape(-1)
    # Explicit column count (instead of -1) so that empty input also reshapes cleanly.
    columns = int(np.prod(size[1:], dtype=int)) * M
    return bits.reshape((size[0], columns))
