import torch
import torch.nn.functional as f

def conv1d(x, kernel):
    """
    Full 1-D cross-correlation of each signal in ``x`` with its kernel.

    Every length-N signal along the last axis of ``x`` is correlated with
    the matching length-M kernel. Both ends are zero-padded by M - 1, so the
    result has length N + M - 1 ("full" mode). Signals are processed
    independently: batch and channel entries are never mixed.

    Like ``torch.nn.functional.conv1d``, this does NOT flip the kernel. Pass
    ``kernel.flip(-1)`` to get a true (``numpy.convolve``-style) convolution.

    Parameters
    ----------
    x : Tensor of shape [B, C, N], [B, N] or [N]
        Input signals.
    kernel : Tensor of shape [B, C, M], [B, M] or [M]
        Kernels, M >= 1. A 1-D kernel is shared by every signal. A 2-D
        ``[B, M]`` kernel is shared across the channels of each batch entry.
        A kernel with the same rank as ``x`` is applied per signal. Size-1
        leading dimensions are broadcast. Must share dtype/device with ``x``.

    Returns
    -------
    out : Tensor of shape [..., N + M - 1]
        Leading dimensions match those of ``x``.

    Raises
    ------
    AttributeError
        If ``x`` or ``kernel`` has an unsupported number of dimensions.
    RuntimeError
        If the kernel's leading dimensions cannot be expanded to ``x``'s.
    """
    if x.dim() not in (1, 2, 3):
        raise AttributeError('Input shape is not supported')
    if kernel.dim() not in (1, 2, 3):
        raise AttributeError('Kernel shape is not supported')

    lead_shape = x.shape[:-1]
    num = x.shape[-1]
    # Padding (and hence output length) is driven by the KERNEL length.
    kernel_size = kernel.shape[-1]
    out_len = num + kernel_size - 1

    # Normalise x to [B, C, N].
    if x.dim() == 1:
        x3 = x.reshape(1, 1, num)
    elif x.dim() == 2:
        x3 = x.unsqueeze(1)
    else:
        x3 = x
    batch, channels = x3.shape[0], x3.shape[1]

    # Normalise kernel to [B|1, C|1, M]. A 2-D kernel is one kernel per batch
    # entry, shared across that entry's channels.
    if kernel.dim() == 1:
        k3 = kernel.reshape(1, 1, kernel_size)
    elif kernel.dim() == 2:
        k3 = kernel.unsqueeze(1)
    else:
        k3 = kernel
    k3 = k3.expand(batch, channels, kernel_size)

    groups = batch * channels
    if groups == 0:
        # f.conv1d rejects groups=0. An empty batch has an empty result.
        return x.new_zeros(*lead_shape, out_len)

    # Fold (B, C) into the channel axis and use one group per signal, so each
    # signal only sees its own kernel. Passing a [B, C, M] tensor directly as
    # weight would instead treat B as output channels and sum over C.
    out = f.conv1d(
        x3.reshape(1, groups, num),
        k3.reshape(groups, 1, kernel_size),
        bias=None,
        stride=1,
        padding=kernel_size - 1,
        groups=groups,
    )
    return out.reshape(*lead_shape, out_len)