import functools

import numpy as np
import torch


def np2tensor(x, device = 'cuda:0'):
    """Turn a NumPy array into a torch tensor placed on ``device``.

    Args:
        x: Array handed to ``torch.from_numpy``.
        device: Target passed to ``Tensor.to``; defaults to ``'cuda:0'``.

    Returns:
        The result of ``torch.from_numpy(x).to(device)``.
    """
    as_tensor = torch.from_numpy(x)
    return as_tensor.to(device)

def tensor2np(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph before conversion,
    because ``Tensor.numpy()`` raises ``RuntimeError`` for tensors that
    require grad. It is then moved to the CPU and converted.

    Args:
        x: The tensor to convert.

    Returns:
        The result of ``x.detach().cpu().numpy()``.
    """
    # detach() leaves tensors that do not require grad unaffected, so
    # existing callers see the same result as before.
    return x.detach().cpu().numpy()

def covert_main(x, data_mode = 'tensor', **kwargs):
    """Flip ``x`` between a NumPy array and a torch tensor.

    A conversion happens only when ``data_mode == 'tensor'`` and ``x`` is
    exactly an ``np.ndarray`` or exactly a ``torch.Tensor`` (subclasses are
    not matched):

    * ``np.ndarray``  -> ``np2tensor(x, device)``
    * ``torch.Tensor`` -> ``tensor2np(x)``

    Args:
        x: Value to convert.
        data_mode: Mode selector; only ``'tensor'`` triggers a conversion.
        **kwargs: ``device`` (default ``'cuda:0'``) is read for the
            array-to-tensor direction; other keys are ignored.

    Returns:
        The converted object, or ``None`` when no conversion applies.
    """
    kind = type(x)

    # Neither supported exact type: nothing to convert.
    if kind is not np.ndarray and kind is not torch.Tensor:
        return None

    if not (data_mode == 'tensor'):
        return None

    if kind is np.ndarray:
        return np2tensor(x, kwargs.get('device', 'cuda:0'))

    return tensor2np(x)
    
def covert(x, data_mode = 'tensor', **kwargs):
    """Build a decorator; ``x`` is converted eagerly when ``covert`` is called.

    When ``data_mode == 'tensor'`` and ``x`` is exactly an ``np.ndarray`` it
    is passed through ``np2tensor`` (using ``kwargs['device']``, default
    ``'cuda:0'``). When ``x`` is exactly a ``torch.Tensor`` it is passed
    through ``tensor2np``.

    NOTE(review): the converted value is not handed to the decorated
    function or returned to the caller. Confirm what the conversion was
    meant to feed into; it is kept only so that call-time behaviour
    (including any conversion error) stays as before.

    Args:
        x: Value to convert at decoration time.
        data_mode: Mode selector; only ``'tensor'`` triggers a conversion.
        **kwargs: ``device`` is read for the array-to-tensor direction.

    Returns:
        A decorator. The wrapper it produces forwards all positional and
        keyword arguments to the wrapped function and returns its result.
    """
    x_type = type(x)
    if x_type is np.ndarray and data_mode == 'tensor':
        device = kwargs.get('device', 'cuda:0')
        x = np2tensor(x, device)
    elif x_type is torch.Tensor and data_mode == 'tensor':
        x = tensor2np(x)

    def run_func(func):
        # wraps() keeps func's name and docstring on the wrapper.
        @functools.wraps(func)
        def wrapper(*call_args, **call_kwargs):
            # Forward the arguments and return the result; the previous
            # version called func() bare and discarded its return value.
            return func(*call_args, **call_kwargs)
        return wrapper
    return run_func
    