import numpy as np
import torch
from IFTS.fiber_simulation.rx_dsp.carrier_phase_estimation import coarse_estimation, viterbi_viterbi
from IFTS.fiber_simulation.base.base_dsp import DSP_Base_Module
from IFTS.fiber_simulation.utils.show_progress import progress_info, progress_info_return

class CPE(DSP_Base_Module):
    r"""
        Carrier phase estimation (CPE) module for a two-polarization signal.
        The X (rx_sig[0]) and Y (rx_sig[1]) polarizations are processed
        independently, window by window, with either the Viterbi-Viterbi
        estimator (mode='vv') or the coarse estimator (mode='coarse').
    """

    def __init__(self, sym_rate, window_size, mode ='coarse', block_num=1,
                 parallelism=1, data_aided=1, infor_print=1, data_mode = 'numpy', *args, **kwargs):
        r"""
            Initialization function of the carrier phase estimation design class.
            This function initializes the carrier phase estimation parameters.
            Parameters: sym_rate:int
                            The symbol rate, or the baud rate.
                        window_size:int
                            Sliding window size (passed to the Viterbi-Viterbi estimator).
                        mode:str
                            Algorithm mode, 'coarse' (default) or 'vv'.
                        block_num:int
                            Number of signal blocks, defaults to 1.
                        parallelism:int
                            Number of parallel paths, defaults to 1.
                        data_aided:int
                            Used in coarse mode, defaults to 1. When truthy, the
                            transmitted signal is used to assist in solving for
                            the phase offset; otherwise None is passed instead.
                        infor_print:int
                            The information print parameter, defaults to 1.
                        data_mode:str
                            Data mode, defaults to 'numpy'. 'tensor' makes the
                            outputs be concatenated with torch instead of numpy.
                        device (keyword argument):
                            Stored as self.device, defaults to 'cpu'.
            Raises:
                RuntimeError: 'block_num must be an integer multiple of parallelism'
                If the number of signal blocks is not divisible by the number of parallel paths.
        """
        super().__init__()
        self.mode = mode
        self.sym_rate = sym_rate
        self.block_num = block_num
        self.parallelism = parallelism
        if block_num % parallelism != 0:
            raise RuntimeError('block_num must be an integer multiple of parallelism')
        # Exact because of the divisibility check above.
        self.iter_num = block_num // parallelism
        self.window_size = window_size
        self.data_aided = data_aided
        self.infor_print = infor_print
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')

    def forward_pass(self, rx_sig, tx_sig, **kwargs):
        r"""
            This function executes the carrier phase estimation module.
            Parameters: rx_sig:list
                            Signal sequence of receiver, [x_polarization, y_polarization].
                        tx_sig:list
                            Signal sequence of transmitter, same layout as rx_sig.
                            Only read when self.data_aided is truthy.
                        sym_map (keyword argument):
                            Stored as self.sym_map.
            Return:     rx_sig: list
                            [out_x, out_y], the receiver signal of each polarization
                            after phase offset compensation, flattened and concatenated
                            over all windows (torch tensors when data_mode == 'tensor',
                            numpy arrays otherwise).
                        phase: list
                            [phase_x, phase_y], each a list holding the phase
                            estimate returned by the estimator for every window.
            Raises:
                ValueError: if self.mode is neither 'vv' nor 'coarse'.
        """
        if self.mode not in ('vv', 'coarse'):
            # Fail fast with a clear message instead of an unbound-name error
            # on the phase estimate inside the processing loop.
            raise ValueError(
                f"Unsupported CPE mode {self.mode!r}; expected 'vv' or 'coarse'")

        @progress_info_return(total_num = self.iter_num,
            infor_print = self.infor_print, discription = ' --CPE_'+self.mode)
        def run_compensation(rx_sig, tx_sig, **kwargs):
            # Presumably the decorator injects a progress bar as 'pbar';
            # the default integer 0 means "no progress bar".
            pbar = kwargs.get('pbar', 0)
            out_x, out_y = [], []
            phase_x, phase_y = [], []
            # Sliding window over the signal, one window per iteration.
            for i in range(self.iter_num):
                rx_x = self._take_block(rx_sig[0], i)
                rx_y = self._take_block(rx_sig[1], i)
                if self.data_aided:
                    tx_x = self._take_block(tx_sig[0], i)
                    tx_y = self._take_block(tx_sig[1], i)
                else:
                    tx_x = tx_y = None
                if self.mode == 'vv':
                    # viterbi_viterbi mode
                    phase_hat_x, rx_x = viterbi_viterbi.vv_cpe(rx_x, move_mean = 1, window_size = self.window_size)
                    phase_hat_y, rx_y = viterbi_viterbi.vv_cpe(rx_y, move_mean = 1, window_size = self.window_size)
                else:
                    # coarse mode (data aided when tx_x / tx_y are not None)
                    phase_hat_x, rx_x = coarse_estimation.main(rx_x, tx_x)
                    phase_hat_y, rx_y = coarse_estimation.main(rx_y, tx_y)
                # Collect flattened pieces; concatenated once after the loop
                # instead of re-copying the growing output every iteration.
                out_x.append(rx_x.reshape((-1)))
                out_y.append(rx_y.reshape((-1)))
                # The phase offset
                phase_x.append(phase_hat_x)
                phase_y.append(phase_hat_y)
                if type(pbar) is not int:
                    pbar.update(1)

            return [self._concat(out_x), self._concat(out_y)], [phase_x, phase_y]

        self.sym_map = kwargs.get("sym_map")
        self.__get_block_idx__(rx_sig[0].shape[0])
        rx_sig, phase = run_compensation(rx_sig, tx_sig)
        return rx_sig, phase

    def _take_block(self, sig, i):
        r"""
            Cut the i-th sliding window out of one polarization, gather it
            with self.block_idx (a (parallelism, block_size) index matrix) and
            convert it with self.data_mode_convert.
            Assumes sig is a 1-D sequence — TODO confirm against callers.
        """
        start = self.slide_size * i
        window = sig[start: start + self.choose_size]
        return self.data_mode_convert(window[self.block_idx])

    def _concat(self, pieces):
        r"""
            Concatenate the per-window outputs along axis 0, using torch when
            data_mode == 'tensor' and numpy otherwise.
        """
        if self.data_mode == 'tensor':
            return torch.cat(pieces, dim = 0)
        return np.concatenate(pieces, axis = 0)

    def __get_block_idx__(self, sig_size):
        r"""
            Compute the batch / parallel / slide bookkeeping used by forward_pass.
            Parameters: sig_size:int
                            Number of samples in one polarization.
            Sets:       block_size:int
                            sig_size // block_num. Samples beyond
                            block_num * block_size are not processed.
                        block_idx:ndarray
                            int64 index matrix of shape (parallelism, block_size).
                        slide_size:int
                            Window stride, parallelism * block_size.
                        choose_size:int
                            Window length (samples per slide), parallelism * block_size.
        """
        self.block_size = sig_size // self.block_num
        self.block_idx = np.arange(self.parallelism * self.block_size, dtype=np.int64)\
            .reshape((self.parallelism, self.block_size))
        self.slide_size = self.parallelism * self.block_size
        self.choose_size = self.parallelism * self.block_size