import torch
import numpy as np

def awgn_real(x, snr_db = 0, n_p = None, device='cpu', data_mode = 'numpy'):
    '''
    awgn 信道
    x [N,] complex
    snr 单位是 dB 
    AWGN 支持tensor以及numpy数据类型
    n_p: 是否给定噪声功率，如果给出噪声功率，则直接加噪声
    如果n_p 是 none
    则根据snr和信号功率加入noise
    
    '''
    if data_mode == 'tensor':
        sig_p = torch.mean(x ** 2).detach()
        noise = torch.randn((x.shape[0]), device = device) 
    elif data_mode == 'numpy':    
        sig_p = np.mean(x ** 2)
        noise = np.random.randn(x.shape[0])
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
    if n_p != None:
        n_power = 10 ** ((n_p - 30 ) / 10)
    else:
        snr_w = 10 ** (snr_db / 10)         # W
        n_power = sig_p / snr_w 
    noise = noise * ((n_power/2) ** 0.5)
    y = x + noise
    return y

def awgn_complex(x, snr_db = 0, n_p = None, device='cpu', data_mode = 'numpy'):
    '''
    awgn 信道
    x [N,] complex
    snr 单位是 dB 
    AWGN 支持tensor以及numpy数据类型
    n_p: 是否给定噪声功率，如果给出噪声功率，则直接加噪声
    如果n_p 是 none
    则根据snr和信号功率加入noise
    
    '''
    if data_mode == 'tensor':
        sig_p = torch.mean(torch.abs(x) ** 2).detach()
        noise = torch.randn((x.shape[0],2), device = device) 
    elif data_mode == 'numpy':    
        sig_p = np.mean(np.abs(x) ** 2)
        noise = np.random.randn(x.shape[0],2)
    else:
        raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
    if n_p != None:
        n_power = 10 ** ((n_p - 30 ) / 10)
    else:
        snr_w = 10 ** (snr_db / 10)         # W
        n_power = sig_p / snr_w 
    noise = (noise[:, 0] + 1j * noise[:, 1]) * ((n_power/2) ** 0.5)
    y = x + noise
    return y

def awgn_pol(sigin, snr_db = 0, n_p = None, device='cpu', data_mode = 'numpy'):
    '''
    awgn 信道
    singin [N,] complex
    snr 单位是 dB 
    AWGN 支持tensor以及numpy数据类型
    n_p: 是否给定噪声功率，如果给出噪声功率，则直接加噪声
    如果n_p 是 none
    则根据snr和信号功率加入noise
    '''
    sigout = []
    for i_p in range(len(sigin)):
        sigout.append(awgn_complex(sigin[i_p], snr_db, n_p, device, data_mode))
    return sigout

if __name__ == '__main__':
    sigin = np.random.randn(1000, 2)
    sigin = sigin[:,0] + 1j * sigin[:,1]
    sigout = awgn(sigin)
    sigin = torch.from_numpy(sigin)
    sigout = awgn(sigin,data_mode = 'tensor')