import numpy as np
import torch

class DSP_Base_Module:
    """Base class for DSP processing modules that are not ``torch.nn.Module``s.

    Provides numpy <-> torch conversion, lazy attribute/config helpers, and a
    ``__call__`` that forwards to ``forward_pass``.

    NOTE(review): ``forward_pass``, ``data_mode`` and ``device`` are not
    defined in this class; presumably subclasses provide them — confirm.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored here.
        # Maps ``data_mode`` names to the container type they denote.
        self.data_type = {'numpy': np.ndarray, 'tensor': torch.Tensor}

    @staticmethod
    def _is_empty_value(value):
        """Return True if ``value`` is the ``'empty_value'`` "no default given" marker."""
        # The isinstance guard avoids elementwise ``==`` on arrays/tensors,
        # whose truth value would be ambiguous.
        return isinstance(value, str) and value == 'empty_value'

    def data_mode_convert(self, x, data_mode=None):
        """Convert ``x`` between numpy arrays and torch tensors.

        Args:
            x: a ``np.ndarray`` or ``torch.Tensor``; subclasses (e.g.
                ``torch.nn.Parameter``) are accepted.
            data_mode: target mode, ``'numpy'`` or ``'tensor'``. Defaults to
                ``self.data_mode``.

        Returns:
            ``x`` unchanged if it already has the target type; otherwise a
            tensor moved to ``self.device`` (numpy input) or a detached CPU
            numpy array (tensor input).

        Raises:
            KeyError: if ``data_mode`` is not a key of ``self.data_type``.
            AttributeError: if ``x`` is neither a numpy array nor a tensor.
        """
        if data_mode is None:
            data_mode = self.data_mode
        # Stored on the instance because earlier versions exposed it.
        self.data_mode_type = self.data_type[data_mode]
        # isinstance (not exact type equality) so subclasses are handled.
        if isinstance(x, self.data_mode_type):
            return x
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).to(self.device)
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        raise AttributeError('Data type is not supported')

    def __getarr__(self, name, value='empty_value'):
        """Return attribute ``name``, initialising it to ``value`` if it is absent.

        Raises:
            RuntimeError: if the attribute is absent and no ``value`` was given.
        """
        if hasattr(self, name):
            # getattr also finds class attributes and properties, which are
            # not stored in the instance ``__dict__``.
            return getattr(self, name)
        if self._is_empty_value(value):
            raise RuntimeError(f'{name} is an empty value')
        self.__dict__[name] = value
        return value

    def __call__(self, *args, **kwds):
        """Invoke ``self.forward_pass`` with the given arguments."""
        return self.forward_pass(*args, **kwds)

    def _check(self, param_name, default_value='empty_value', configs=None):
        """Resolve parameter ``param_name`` from the instance, ``configs`` or a default.

        Resolution order:
          1. If the instance already has the attribute, its value wins and is
             written back into ``configs`` (when given).
          2. Otherwise, if ``configs`` is given, ``configs[param_name]`` is
             used. If that key is absent, ``default_value`` is used and also
             stored into ``configs``.
          3. Otherwise ``default_value`` is used. If no default was supplied,
             a ``RuntimeError`` is raised.

        Values resolved in steps 2-3 are stored on the instance.

        Returns:
            The resolved value.

        Raises:
            RuntimeError: if the attribute is absent, ``configs`` is None and
                no ``default_value`` was supplied.
        """
        if hasattr(self, param_name):
            value = getattr(self, param_name)
            if configs is not None:
                configs[param_name] = value
            return value

        if configs is not None:
            if param_name in configs:
                value = configs[param_name]
            else:
                # NOTE(review): when no default is supplied, the
                # 'empty_value' marker itself is stored here without error
                # (kept for backward compatibility) — confirm intent.
                value = default_value
                configs[param_name] = value
        elif self._is_empty_value(default_value):
            raise RuntimeError(f'{param_name} is an empty value')
        else:
            value = default_value

        self.__dict__[param_name] = value
        return value

    def _config_check(self, param_name, default_value=None, configs=None):
        """Return ``configs[param_name]``, inserting ``default_value`` if the key is absent.

        Args:
            param_name: key to look up.
            default_value: value stored and returned when the key is missing.
            configs: dict read and updated in place. When omitted, a fresh
                empty dict is used, so ``default_value`` is simply returned.
        """
        # Fresh dict per call: a ``configs={}`` default would be shared across
        # calls, letting earlier defaults leak into later ones.
        if configs is None:
            configs = {}
        if param_name not in configs:
            configs[param_name] = default_value
        return configs[param_name]

class DSP_Base_Module_with_grad(torch.nn.Module):
    """``torch.nn.Module``-based counterpart of ``DSP_Base_Module``.

    Offers the same numpy <-> torch conversion and attribute/config helpers.
    Calling the module uses ``nn.Module.__call__`` (which dispatches to
    ``forward``) rather than a ``forward_pass`` hook.

    NOTE(review): ``data_mode`` and ``device`` are not defined in this class;
    presumably subclasses set them — confirm.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored here.
        super().__init__()
        # Maps ``data_mode`` names to the container type they denote.
        self.data_type = {'numpy': np.ndarray, 'tensor': torch.Tensor}

    @staticmethod
    def _is_empty_value(value):
        """Return True if ``value`` is the ``'empty_value'`` "no default given" marker."""
        # The isinstance guard avoids elementwise ``==`` on arrays/tensors,
        # whose truth value would be ambiguous.
        return isinstance(value, str) and value == 'empty_value'

    def data_mode_convert(self, x, data_mode=None):
        """Convert ``x`` between numpy arrays and torch tensors.

        Args:
            x: a ``np.ndarray`` or ``torch.Tensor``; subclasses (e.g.
                ``torch.nn.Parameter``) are accepted.
            data_mode: target mode, ``'numpy'`` or ``'tensor'``. Defaults to
                ``self.data_mode``.

        Returns:
            ``x`` unchanged if it already has the target type; otherwise a
            tensor moved to ``self.device`` (numpy input) or a detached CPU
            numpy array (tensor input).

        Raises:
            KeyError: if ``data_mode`` is not a key of ``self.data_type``.
            AttributeError: if ``x`` is neither a numpy array nor a tensor.
        """
        if data_mode is None:
            data_mode = self.data_mode
        # Stored on the instance because earlier versions exposed it.
        self.data_mode_type = self.data_type[data_mode]
        # isinstance (not exact type equality) so subclasses are handled.
        if isinstance(x, self.data_mode_type):
            return x
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).to(self.device)
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        raise AttributeError('Data type is not supported')

    def __getarr__(self, name, value='empty_value'):
        """Return attribute ``name``, initialising it to ``value`` if it is absent.

        Raises:
            RuntimeError: if the attribute is absent and no ``value`` was given.
        """
        if hasattr(self, name):
            # getattr also finds registered parameters/buffers/submodules,
            # which nn.Module keeps outside the instance ``__dict__``.
            return getattr(self, name)
        if self._is_empty_value(value):
            raise RuntimeError(f'{name} is an empty value')
        # Written straight into __dict__, bypassing nn.Module.__setattr__, so
        # the value is not registered as a parameter/submodule.
        self.__dict__[name] = value
        return value

    def _check(self, param_name, default_value='empty_value', configs=None):
        """Resolve parameter ``param_name`` from the instance, ``configs`` or a default.

        Resolution order:
          1. If the instance already has the attribute (including registered
             parameters/buffers/submodules), its value wins and is written
             back into ``configs`` (when given).
          2. Otherwise, if ``configs`` is given, ``configs[param_name]`` is
             used. If that key is absent, ``default_value`` is used and also
             stored into ``configs``.
          3. Otherwise ``default_value`` is used. If no default was supplied,
             a ``RuntimeError`` is raised.

        Values resolved in steps 2-3 are stored directly in the instance
        ``__dict__`` (not registered with nn.Module).

        Returns:
            The resolved value.

        Raises:
            RuntimeError: if the attribute is absent, ``configs`` is None and
                no ``default_value`` was supplied.
        """
        if hasattr(self, param_name):
            value = getattr(self, param_name)
            if configs is not None:
                configs[param_name] = value
            return value

        if configs is not None:
            if param_name in configs:
                value = configs[param_name]
            else:
                # NOTE(review): when no default is supplied, the
                # 'empty_value' marker itself is stored here without error
                # (kept for backward compatibility) — confirm intent.
                value = default_value
                configs[param_name] = value
        elif self._is_empty_value(default_value):
            raise RuntimeError(f'{param_name} is an empty value')
        else:
            value = default_value

        self.__dict__[param_name] = value
        return value

    def _config_check(self, param_name, default_value=None, configs=None):
        """Return ``configs[param_name]``, inserting ``default_value`` if the key is absent.

        Args:
            param_name: key to look up.
            default_value: value stored and returned when the key is missing.
            configs: dict read and updated in place. When omitted, a fresh
                empty dict is used, so ``default_value`` is simply returned.
        """
        # Fresh dict per call: a ``configs={}`` default would be shared across
        # calls, letting earlier defaults leak into later ones.
        if configs is None:
            configs = {}
        if param_name not in configs:
            configs[param_name] = default_value
        return configs[param_name]