import numpy as np
import torch

def ber(x, y, data_mode = 'numpy'):
    r'''
    Count the bit errors between two bit sequences and compute the BER.

    Both inputs are flattened and compared element by element; every
    position where they differ counts as one bit error.

    Parameters:
    x: array_like
        Input data array, usually the transmitted bits.
    y: array_like
        Input data array, usually the received bits. Must contain the same
        number of elements as ``x`` (the shapes themselves may differ).
    data_mode: str, optional
        Backend used for the comparison. ``'numpy'`` (default) or ``'np'``
        converts the inputs with ``np.asarray``; any other value (e.g.
        ``'torch'`` or ``'t'``) converts them with ``torch.as_tensor``.

    Returns:
    ber: numpy integer or torch.Tensor
        Number of bit errors (a NumPy integer in NumPy mode, a float32
        tensor in torch mode).
    ber_per_num: numpy float or torch.Tensor
        Bit error rate, i.e. the error count divided by the number of
        compared bits. NaN for empty inputs.

    Raises:
    ValueError
        If ``x`` and ``y`` do not contain the same number of elements.
    '''
    use_numpy = data_mode in ('numpy', 'np')

    # Always flatten both inputs so a 0-d or multi-dimensional y is handled
    # the same way as x.
    if use_numpy:
        x = np.asarray(x).reshape(-1)
        y = np.asarray(y).reshape(-1)
    else:
        x = torch.as_tensor(x).reshape(-1)
        # Place y on x's device so the element-wise comparison is valid.
        y = torch.as_tensor(y, device=x.device).reshape(-1)

    num = y.shape[0]
    if x.shape[0] != num:
        # Without this check broadcasting (e.g. a length-1 y) would silently
        # yield an error count larger than num.
        raise ValueError(
            f"x and y must contain the same number of elements, "
            f"got {x.shape[0]} and {num}"
        )

    if use_numpy:
        # Compare directly rather than via (x + y) % 2, which miscounts
        # boolean inputs (True + True is True in NumPy) and non-0/1 values.
        errors = np.sum(x != y)
    else:
        errors = x.ne(y).sum().float()
    ber_per_num = errors / num
    return errors, ber_per_num