import numpy as np
import torch
from IFTS.fiber_simulation.base.base_dsp import DSP_Base_Module

class IQ_Balance(DSP_Base_Module):
    r"""
        IQ imbalance compensation for complex signals.

        Only the Gram-Schmidt orthogonalization procedure (GSOP) is
        implemented. The real part (I) is kept as the reference. The part of
        the imaginary component (Q) that is correlated with I is removed.
        Optionally, both components are rescaled to unit mean power.
    """

    def __init__(self, mode, data_mode='tensor', *args, **kwargs):
        r"""
            Initialization function of the IQ_Balance class.
            This function initializes the iq balance parameters.
            Parameters: mode:str
                            IQ balancing algorithm type; only 'GSOP' is supported now.
                        data_mode:str
                            data mode. 'tensor' selects the torch code path, any
                            other value selects the numpy code path.
                            Defaults to 'tensor'.
                        device:str, optional keyword
                            stored as self.device. Defaults to 'cpu'.
                        normalize:bool, optional keyword
                            if True, I and Q are rescaled to unit mean power after
                            orthogonalization (the full GSOP). Defaults to False,
                            which only removes the I/Q correlation.
        """
        super().__init__()
        self.mode = mode
        self.data_mode = data_mode
        self.device = kwargs.get('device', 'cpu')
        self.normalize = kwargs.get('normalize', False)

    def forward_pass(self, x):
        r"""
            Gram-Schmidt orthogonalization procedure (GSOP).

            I = Re(x) is kept as the reference. The component of Q = Im(x)
            that is parallel to I is removed:
                Q' = Q - (sum(I * Q) / sum(I ** 2)) * I
            After this step, I and Q' are orthogonal. The sums run over every
            element of x because no axis is given, so all samples share a
            single correction coefficient.
            NOTE(review): if x holds several independent channels, they are
            not corrected separately. Confirm this is intended.

            Parameters: x
                            signal to balance. It is first passed through
                            self.data_mode_convert, which is inherited from
                            DSP_Base_Module.
            Returns:    I + 1j * Q'. When self.normalize is True, I and Q'
                        are each scaled to unit mean power first.
            Raises:     NotImplementedError if self.mode is not 'GSOP'.
        """
        # Fail fast instead of silently returning None for unknown modes.
        if self.mode != 'GSOP':
            raise NotImplementedError(
                f"IQ balance mode {self.mode!r} is not supported; "
                "only 'GSOP' is implemented."
            )
        x = self.data_mode_convert(x)
        if self.data_mode == 'tensor':
            x_real, x_imag = x.real, x.imag
        else:
            x_real, x_imag = np.real(x), np.imag(x)
        # getattr keeps instances built without __init__ on the old behaviour.
        normalize = getattr(self, 'normalize', False)
        x_real, x_imag = self._gsop(x_real, x_imag, normalize)
        return x_real + 1j * x_imag

    @classmethod
    def _gsop(cls, x_real, x_imag, normalize=False):
        r"""
            Orthogonalize x_imag against x_real and optionally normalize both.

            The same code serves numpy arrays and torch tensors, because it
            only uses element-wise arithmetic together with .sum() and
            .mean().
        """
        rho = (x_imag * x_real).sum()
        # If sum(I**2) == 0, then I is all zero and rho == 0 too. The
        # guarded divisor therefore gives a coefficient of 0 instead of NaN.
        p_real = cls._nonzero((x_real ** 2).sum())
        x_imag = x_imag - (rho / p_real) * x_real
        if normalize:
            x_real = x_real / cls._nonzero((x_real ** 2).mean()) ** 0.5
            x_imag = x_imag / cls._nonzero((x_imag ** 2).mean()) ** 0.5
        return x_real, x_imag

    @staticmethod
    def _nonzero(power):
        r"""
            Return power unchanged if it is > 0, otherwise 1.

            This keeps divisions by a component's power finite. The value
            being divided is zero whenever the power is zero.
        """
        if isinstance(power, torch.Tensor):
            # torch.where avoids a Python-level branch on the tensor value.
            return torch.where(power > 0, power, torch.ones_like(power))
        return power if power > 0 else 1.0
        