import numpy as np
import torch
import torch
import numpy as np
import matplotlib.pyplot as plt
from IFTS.fiber_simulation.utils.show_progress import progress_info
from IFTS.fiber_simulation.base.base_optics import Optics_Base_Module

class ICR(Optics_Base_Module):
    # Seed offset between the in-phase (I) and quadrature (Q) noise generators
    # of a single polarization.
    _QUAD_SEED_OFFSET = 10026
    # Seed offset between consecutive polarizations. It must be at least
    # 2 * _QUAD_SEED_OFFSET: with a stride of only _QUAD_SEED_OFFSET, the Q
    # generator of polarization p would share its seed with the I generator
    # of polarization p + 1, making those two noise components identical.
    _POL_SEED_STRIDE = 2 * _QUAD_SEED_OFFSET

    def __init__(self, args, data_mode='numpy', **kwargs):
        """
        This is a class modeling the integral coherent receiver (ICR).

        ICR is derived from Optics_Base_Module and is used to simulate an ICR.
        In simulation, an ICR object adds receiver noise to the received signal.
        The noise is modeled as additive complex Gaussian noise.

        ...

        Attributes
        ----------
        mode : str, {'naive'}
            The mode of the ICR. Only 'naive' is supported now.
        data_mode : optional, {'numpy', 'tensor'}
            The data type used in operation. Defaults to 'numpy'.
        rand_seed : int
            The seed to control the generator. Defaults to -1.
        upsam : int
            The upsample rate, equal to the channel sample rate divided
            by the transmitter sample rate.
        n_power_dBm : float
            The noise power, in dBm.
        add_noise : bool, optional
            Whether noise is added in the naive pass. Taken from ``args`` if
            present; treated as True when never set.

        Parameters
        ----------
        args : dict
            Keyword arguments forwarded to ``init``; must contain ``upsam``
            and ``n_power_dBm``. Any other entries become attributes.
        data_mode : {'numpy', 'tensor'}
            Backend used to generate the noise.
        **kwargs
            ``device`` (default 'cpu') is used by the tensor backend;
            ``constant_pi`` (default ``np.pi``) is stored on the instance.
        """
        super().__init__()
        self.data_mode = data_mode
        self.rand_seed = -1
        self.device = kwargs.get('device', 'cpu')
        self.constant_pi = kwargs.get('constant_pi', np.pi)
        self.mode = 'naive'
        self.init(**args)

    def init(self, upsam, n_power_dBm, **kwargs):
        """Store the receiver parameters and derive the noise power in watts.

        Parameters
        ----------
        upsam : int
            Upsample rate (channel sample rate / transmitter sample rate).
        n_power_dBm : float
            Noise power in dBm.
        **kwargs
            Extra settings (e.g. ``add_noise``) copied verbatim onto the
            instance's ``__dict__``.
        """
        self.upsam = upsam
        self.n_power_dBm = n_power_dBm
        # dBm -> W, scaled by the upsample rate. Presumably n_power_dBm is
        # specified over the symbol-rate bandwidth and the simulation
        # bandwidth is `upsam` times wider -- confirm with the model owner.
        self.n_power_w = self.upsam * 10 ** ((n_power_dBm - 30) / 10)
        # Write straight into __dict__ (not setattr) so any __setattr__
        # override in the base class is bypassed, as before.
        self.__dict__.update(kwargs)

    def forward_pass(self, sigin, rand_seed):
        """
        Pass the signal onto the ICR module.

        Wrapper function. ``__naive_pass__`` is called when mode='naive'.

        Parameters
        ----------
        sigin : list
            The signal input to ICR, one entry per polarization.
        rand_seed : int
            The seed to control the generator; -1 means unseeded.

        Returns
        -------
        sigout : list
            The output signal of ICR.

        Raises
        ------
        NotImplementedError
            When the set mode is not supported.
        """
        if self.mode == 'naive':
            return self.__naive_pass__(sigin, rand_seed)
        raise NotImplementedError(f"Mode {self.mode} is not supported")

    def __naive_pass__(self, x, rand_seed):
        """Add independent receiver noise to every polarization of ``x``.

        Returns ``x`` itself when ``add_noise`` is False; otherwise a new list
        with one noisy signal per polarization.
        """
        # `add_noise` is only set when supplied through `args`; default to
        # adding noise, which is the documented purpose of this module.
        if not getattr(self, 'add_noise', True):
            return x
        # Add noise to the polarizations respectively, each with its own seeds.
        return [
            self.__add_noise__(x_pol, self._pol_seed(rand_seed, i_p))
            for i_p, x_pol in enumerate(x)
        ]

    def _pol_seed(self, rand_seed, i_p):
        """Return the base seed for polarization ``i_p``.

        The unseeded sentinel -1 is propagated unchanged, so every
        polarization stays unseeded. Otherwise, polarizations are spaced
        ``_POL_SEED_STRIDE`` apart so their I/Q seeds never overlap.
        """
        if rand_seed == -1:
            return -1
        return rand_seed + i_p * self._POL_SEED_STRIDE

    def __add_noise__(self, x, rand_seed):
        """Add receiver noise to the signal x.

        Parameters
        ----------
        x : array-like
            The input signal of ICR (numpy array/list or torch tensor).
        rand_seed : int
            The seed to control the generator; -1 means unseeded. The Q
            component uses ``rand_seed + _QUAD_SEED_OFFSET``.

        Returns
        -------
        x : array-like
            The signal with the receiver noise added; the noise has the same
            shape as ``x``.

        Raises
        ------
        AttributeError
            When the value of data_mode is neither 'numpy' nor 'tensor'.
        """
        if self.data_mode == 'numpy':
            n = self._numpy_noise(self._noise_shape(x), rand_seed)
        elif self.data_mode == 'tensor':
            n = self._tensor_noise(self._noise_shape(x), rand_seed)
        else:
            raise AttributeError("ERROR:data_mode is not defined as 'tensor' or 'numpy' ")
        return x + n

    @staticmethod
    def _noise_shape(x):
        """Return the shape the noise must have to match ``x`` element-wise."""
        if torch.is_tensor(x):
            return tuple(x.shape)
        return np.shape(x)

    def _numpy_noise(self, shape, rand_seed):
        """Draw complex Gaussian noise of power ``n_power_w`` with numpy."""
        if rand_seed == -1:
            rng_i = np.random.default_rng()
            rng_q = np.random.default_rng()
        else:
            # Different seeds for the real and imaginary parts.
            rng_i = np.random.default_rng(rand_seed)
            rng_q = np.random.default_rng(rand_seed + self._QUAD_SEED_OFFSET)
        # Split the total power equally between the I and Q components.
        scale = np.sqrt(self.n_power_w / 2)
        return scale * (rng_i.standard_normal(shape)
                        + 1j * rng_q.standard_normal(shape))

    def _tensor_noise(self, shape, rand_seed):
        """Draw complex Gaussian noise of power ``n_power_w`` with torch."""
        rng_i = torch.Generator(device=self.device)
        rng_q = torch.Generator(device=self.device)
        if rand_seed == -1:
            rng_i.seed()
            rng_q.seed()
        else:
            rng_i.manual_seed(rand_seed)
            rng_q.manual_seed(rand_seed + self._QUAD_SEED_OFFSET)
        # Plain Python float: avoids numpy-scalar * tensor interop.
        scale = float(np.sqrt(self.n_power_w / 2))
        noise_i = torch.randn(shape, generator=rng_i, device=self.device)
        noise_q = torch.randn(shape, generator=rng_q, device=self.device)
        return scale * (noise_i + 1j * noise_q)
