import numpy as np
import torch
from functools import reduce, partial

class Optics_Base_Module:
    """Base class with shared helpers for optics modules.

    This class does not define the following; code here reads them, so
    subclasses are expected to provide them:

    * ``data_mode`` -- default target mode for ``data_mode_convert``
      (a key of ``self.data_type``: ``'numpy'`` or ``'tensor'``);
    * ``device`` -- passed to ``Tensor.to`` when converting numpy arrays;
    * ``forward_pass`` -- invoked by ``__call__``.
    """

    def __init__(self):
        # Supported data modes and the container type each one produces.
        self.data_type = {'numpy': np.ndarray, 'tensor': torch.Tensor}

    def data_mode_convert(self, x, data_mode=None):
        """Convert ``x`` between a numpy array and a torch tensor.

        Args:
            x: A ``np.ndarray`` or ``torch.Tensor`` (subclasses such as
                ``torch.nn.Parameter`` are accepted).
            data_mode: Target mode, ``'numpy'`` or ``'tensor'``. Defaults to
                ``self.data_mode``.

        Returns:
            ``x`` itself if it is already of the target type; otherwise either
            ``torch.from_numpy(x)`` moved to ``self.device``, or
            ``x.detach().cpu().numpy()``.

        Raises:
            KeyError: the requested mode is not a key of ``self.data_type``.
            AttributeError: ``x`` is neither a numpy array nor a torch tensor.

        Side effect: stores the target type in ``self.data_mode_type``.
        """
        mode = self.data_mode if data_mode is None else data_mode
        self.data_mode_type = self.data_type[mode]

        # isinstance (not exact type equality) so that subclasses such as
        # torch.nn.Parameter or np.memmap are recognized instead of rejected.
        if isinstance(x, self.data_mode_type):
            return x
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).to(self.device)
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        raise AttributeError('Data type is not supported')

    def _resolve_config(self, name):
        """Return the config entry that ``name`` refers to.

        A string is treated as a key and looked up with ``self[name]``; that
        only works if a subclass implements ``__getitem__``. Any other value is
        assumed to already be the entry (a mapping with ``'type'`` and
        ``'args'`` keys).
        """
        return self[name] if isinstance(name, str) else name

    @staticmethod
    def _merge_config_args(config_args, kwargs):
        """Return a copy of ``config_args`` extended with ``kwargs``.

        Raises:
            AssertionError: a key of ``kwargs`` is already present in
                ``config_args`` (config values may not be overridden).
        """
        module_args = dict(config_args)
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if any(k in module_args for k in kwargs):
            raise AssertionError('Overwriting kwargs given in config file is not allowed')
        module_args.update(kwargs)
        return module_args

    def init_func(self, name, module, *args, **kwargs):
        """Return ``module.<type>`` with config arguments bound via ``functools.partial``.

        ``name`` is a config entry with keys ``'type'`` (the attribute looked
        up on ``module``) and ``'args'`` (keyword arguments). It may also be a
        string key resolved through ``self[name]``, which requires a subclass
        ``__getitem__``. Positional ``args`` are bound first. ``kwargs`` are
        added to the config arguments but may not override them.

        ``self.init_func({'type': 'f', 'args': {'b': 1}}, module, a)`` is
        equivalent to ``functools.partial(module.f, a, b=1)``.

        Raises:
            AssertionError: a key in ``kwargs`` is already set in the config.
        """
        entry = self._resolve_config(name)
        func_name = entry['type']
        module_args = self._merge_config_args(entry['args'], kwargs)
        return partial(getattr(module, func_name), *args, **module_args)

    def init_obj(self, name, module, *args, **kwargs):
        """Call ``module.<type>`` with config arguments and return the result.

        ``name`` is resolved exactly as in ``init_func``: either a config entry
        with ``'type'``/``'args'`` keys, or a string key looked up through
        ``self[name]``.

        ``self.init_obj({'type': 'C', 'args': {'b': 1}}, module, a)`` is
        equivalent to ``module.C(a, b=1)``.

        Raises:
            AssertionError: a key in ``kwargs`` is already set in the config.
        """
        entry = self._resolve_config(name)
        module_name = entry['type']
        module_args = self._merge_config_args(entry['args'], kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def __getarr__(self, name, value):
        """Return attribute ``name``, first storing ``value`` under it if absent.

        Attribute analogue of ``dict.setdefault``. A missing value is written
        directly into the instance ``__dict__`` (bypassing any custom
        ``__setattr__``).
        """
        if not hasattr(self, name):
            self.__dict__[name] = value
            return value
        if name in self.__dict__:
            return self.__dict__[name]
        # hasattr() also sees class attributes, properties and names served by
        # a __getattr__ (e.g. torch.nn.Module parameters). None of these live
        # in the instance dict, so indexing __dict__ would raise KeyError.
        return getattr(self, name)

    def _config_check(self, param_name, default_value, configs):
        """Return ``configs[param_name]``, inserting ``default_value`` if the key is missing.

        ``configs`` is mutated in place when the key is absent.
        """
        # Key membership, not hasattr(): plain dicts expose keys as items, not
        # attributes, so hasattr() made every lookup fall back to the default.
        if param_name not in configs:
            configs[param_name] = default_value
        return configs[param_name]

    def __call__(self, *args, **kwds):
        """Forward the call to ``self.forward_pass`` (provided by subclasses)."""
        return self.forward_pass(*args, **kwds)