import numpy as np
import torch
from ... base.base_dsp import DSP_Base_Module
from ...comm_tools.base_conversion import sym2bit, bit2sym

class QAMDemod(DSP_Base_Module):
    '''
    QAMDemod is a subclass of DSP_Base_Module that groups the operations used when 
    demodulating a QAM signal: hard decision against a constellation, log likelihood 
    ratio (llr) calculation, llr decoding and conversion of symbol indices to bits.
    '''
    def __init__(self, mode, order, probability = None, data_mode = 'numpy', *args, **kwargs):
        '''
        Store the demodulation settings and build the symbol-index -> bit lookup table.
        Parameters:
        mode: str
        Signal demodulation mode. forward_pass currently handles only 'int' 
        (hard decision); any other value raises RuntimeError there
        order: int
        Modulation order of QAM signal (number of constellation points)
        probability: array_like, optional
        Probability distribution of signal. Stored on the instance; it is not read 
        by any method of this class. Default no input
        data_mode: str, optional
        Data format. The default is numpy
        *args, **kwargs
        Only the keyword 'device' is read here (default 'cpu'); everything else is ignored
        
        Returns:
        This function has no return value
        '''
        super().__init__()
        self.mode = mode
        self.order = order
        # An order-M constellation carries log2(M) bits per symbol. The previous
        # sqrt(M) only matches log2(M) for M = 4 and M = 16 and is wrong for
        # e.g. 64-QAM (8 instead of 6) or 256-QAM (16 instead of 8).
        # round() protects against floating point results such as 5.999...
        self.bits_per_sym = int(round(np.log2(order)))
        self.probability = probability
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.__get_bit_mapping__()

    def init(self, *args, **kwargs):
        '''
        Set arbitrary attributes on the instance from keyword arguments.
        Parameters:
        *args, **kwargs
        Every key value pair in kwargs is written directly into the instance 
        dictionary; positional arguments are ignored

        Returns:
        This function has no return value
        '''
        # Written through __dict__ (not setattr) so any attribute hooks of the
        # base class are bypassed, exactly as before.
        self.__dict__.update(kwargs)
        
    def forward_pass(self, rx_sig, sym_map, noise_power = None):
        '''
        Execution function of demodulation. The current version only supports hard 
        decision based on the constellation (mode 'int').
        Parameters:
        rx_sig: array_like
        Receiving end signal, i.e. signal to be demodulated
        sym_map: array_like
        Modulation signal constellation
        noise_power: float, optional
        Noise power. Accepted for interface compatibility; not used by the hard decision

        Returns:
        sym_idx: array_like
        Index of the nearest constellation point for every received sample. 
        The same array is also kept in self.sym_idx
        Raises:
        RuntimeError
        Raised with 'mode value error' when self.mode is not 'int'
        '''
        if self.mode != 'int':
            raise RuntimeError('mode value error')
        self.__hard_decision__(rx_sig, sym_map)
        return self.sym_idx

    def get_bit(self, x, is_int = 0, is_llr = 0):
        '''
        Obtain the bit sequence of the demodulated signal
        Parameters:
        x: array_like
        Symbol index values when is_int is set, llr values when is_llr is set
        is_int: bool, optional
        Treat x as symbol indices and look the bits up in self.bit_map. 
        Takes precedence over is_llr when both are set
        is_llr: bool, optional
        Treat x as llr values and decode them by sign

        Returns:
        bit_seq: array_like
        Bit sequence of demodulated signal
        Raises:
        RuntimeError
        'Both is_int and is_llr is set False'
        '''
        if is_int:
            return self.bit_map[x]
        if is_llr:
            # data_mode_convert is not defined in this class; presumably it is
            # inherited from DSP_Base_Module and converts to self.data_mode.
            return self.data_mode_convert(self.__decode_llr__(x))
        raise RuntimeError('Both is_int and is_llr is set False')
            
    def __get_llr__(self, y, n_bits, sym_map, c0, c1, noise_var, ** kwargs):
        '''
        According to the input received signals and parameters, the log likelihood ratio 
        is calculated, which can be used for the subsequent calculation of generalized 
        mutual information. For every bit position the result is 
        log( sum_{s in c0} p(s) exp(-|y - s|^2 / noise_var) 
             / sum_{s in c1} p(s) exp(-|y - s|^2 / noise_var) ), 
        so a positive value favours bit 0.
        Parameters:
        y: array_like
        Received signal sequence (reshaped to (num, 1, 1) internally)
        n_bits: array_like
        Bits per symbol. Kept for interface compatibility; the output shape follows 
        from c0 / c1
        sym_map: array_like
        Constellation corresponding to the received signal
        c0: array_like
        Indices of mapping which has 0 at various bit positions 
        (presumably shaped (symbols_with_bit_0, n_bits) - TODO confirm with caller)
        c1: array_like
        Indices of mapping which has 1 at various bit positions
        noise_var: float
        Noise variance obtained from the received signal
        **kwargs
        probability: array_like, optional
        Prior probability of every constellation point. The historical misspelling 
        'probablity' is still accepted. Defaults to all ones (uniform prior). 
        Note that self.probability is not consulted here

        Returns:
        llr: array_like
        Calculated log likelihood ratio, summed over axis 1 of the distance arrays
        '''
        # Accept the correctly spelled keyword while staying compatible with
        # callers that use the original misspelled one.
        probability = kwargs.get("probability", kwargs.get("probablity"))
        if probability is None:
            probability = np.ones(sym_map.shape[0])
        num = y.shape[0]
        # Trailing singleton axes let y broadcast against sym_map[c0] / sym_map[c1].
        y = y.reshape((num, 1, 1))
        d0 = np.abs(y - sym_map[c0]) ** 2
        d1 = np.abs(y - sym_map[c1]) ** 2
        p0 = probability[c0].reshape((1, c0.shape[0], -1))
        p1 = probability[c1].reshape((1, c1.shape[0], -1))
        likelihood_0 = np.sum(np.exp(-d0 / noise_var) * p0, axis = 1)
        likelihood_1 = np.sum(np.exp(-d1 / noise_var) * p1, axis = 1)
        return np.log(likelihood_0 / likelihood_1)

    def __get_bit_mapping__(self):
        '''
        Build self.bit_map, the bit sequence corresponding to every symbol index 
        0 .. order - 1 of the QAM constellation.
        '''
        int_seq = np.arange(self.order)
        # sym2bit is a project helper; 'parall' is assumed to request one row of
        # bits per symbol - TODO confirm against base_conversion.
        bit_map = sym2bit(int_seq, M = self.bits_per_sym, out = 'parall')
        self.bit_map = self.data_mode_convert(bit_map)
    
    def __hard_decision__(self, sigin, sym_map):
        '''
        Demodulate the signal in the way of hard decision: for every sample, argmin over 
        the distance to each constellation point gives the index of the closest one. 
        The result is stored in self.sym_idx.
        Parameters:
        sigin: array_like
        Signal to be demodulated
        sym_map: array_like
        Modulation signal constellation
        '''
        # A new trailing axis on sigin broadcasts every sample against all points.
        distance = np.abs(sigin[..., None] - sym_map)
        self.sym_idx = np.argmin(distance, axis=-1)

    def __decode_llr__(self, llr):
        '''
        llr decoding by sign: a strictly positive llr gives bit 0, anything else 
        (including exactly zero) gives bit 1.
        '''
        return np.where(llr > 0, 0, 1)

        


