import numpy as np

def berawgn(EbNo, mode, M):
    '''
    Bit error rate (BER) and symbol error rate (SER) for uncoded AWGN channels.

    Returns the theoretical BER and SER for PAM or QAM over an uncoded
    AWGN channel with coherent demodulation. The general BER expressions
    use the exact per-dimension sum for Gray-mapped amplitude levels
    (Cho & Yoon); square QAM with M = 4, 16, 64 uses the equivalent
    closed forms.

    Parameters:
    EbNo: float or array_like
        bit energy to noise power spectral density ratio (in dB)
    mode: str
        modulation type, either 'pam' or 'qam'
    M: int
        alphabet size, must be an integer power of 2 with M >= 2

    Returns:
    ber: float or ndarray
        bit error rate (same shape as EbNo)
    ser: float or ndarray
        symbol error rate (same shape as EbNo)

    Raises:
    ValueError
        if M is not a power of 2 with M >= 2, or mode is not 'pam'/'qam'
    '''
    if M < 2 or 2 ** int(round(np.log2(M))) != M:
        raise ValueError(f'M must be an integer power of 2 with M >= 2, got {M!r}')
    k = int(round(np.log2(M)))  # bits per symbol
    # asarray lets EbNo be a scalar, list or ndarray; arithmetic on a 0-d
    # array yields a NumPy scalar, so scalar inputs still give scalar outputs.
    EbNoLin = 10 ** (np.asarray(EbNo, dtype=float) / 10)

    if mode == 'pam':
        scale = 6 * k / (M ** 2 - 1)
        ser = 2 * (M - 1) / M * qfunc(np.sqrt(scale * EbNoLin))
        ber = _pam_bit_error_sum(M, scale, EbNoLin) / k
        return ber, ser

    if mode != 'qam':
        raise ValueError(f"mode must be 'pam' or 'qam', got {mode!r}")

    if k % 2 == 0:
        # k is even - square QAM: sqrt(M)-level PAM on each of I and Q.
        sqrt_M = np.sqrt(M)
        q = qfunc(np.sqrt(3 * k / (M - 1) * EbNoLin))
        ser = 4 * (sqrt_M - 1) / sqrt_M * q \
            - 4 * ((sqrt_M - 1) / sqrt_M) ** 2 * q ** 2
        if M == 4:
            ber = qfunc(np.sqrt(2 * EbNoLin))
        elif M == 16:
            ber = 3 / 4 * qfunc(np.sqrt(4 / 5 * EbNoLin)) \
                + 1 / 2 * qfunc(3 * np.sqrt(4 / 5 * EbNoLin)) \
                - 1 / 4 * qfunc(5 * np.sqrt(4 / 5 * EbNoLin))
        elif M == 64:
            ber = 7 / 12 * qfunc(np.sqrt(2 / 7 * EbNoLin)) \
                + 1 / 2 * qfunc(3 * np.sqrt(2 / 7 * EbNoLin)) \
                - 1 / 12 * qfunc(5 * np.sqrt(2 / 7 * EbNoLin)) \
                + 1 / 12 * qfunc(9 * np.sqrt(2 / 7 * EbNoLin)) \
                - 1 / 12 * qfunc(13 * np.sqrt(2 / 7 * EbNoLin))
        else:
            levels = int(round(sqrt_M))
            # Each dimension carries k/2 bits.
            ber = _pam_bit_error_sum(levels, 3 * k / (M - 1), EbNoLin) / (k / 2)
    else:
        # k is odd - rectangular I x J QAM with I = 2 * J.
        I = 2 ** ((k + 1) // 2)
        J = 2 ** (k // 2)
        scale = 6 * k / (I ** 2 + J ** 2 - 2)
        q = qfunc(np.sqrt(scale * EbNoLin))
        # For M = 8 (I = 4, J = 2) this reduces to 5/2 * Q - 3/2 * Q**2.
        ser = (4 * I * J - 2 * I - 2 * J) / M * q \
            - 4 / M * (1 + I * J - I - J) * q ** 2
        ber = (_pam_bit_error_sum(I, scale, EbNoLin)
               + _pam_bit_error_sum(J, scale, EbNoLin)) / k
    return ber, ser


def _pam_bit_error_sum(L, snr_scale, EbNoLin):
    '''
    Sum, over the log2(L) bit positions, of the bit error probabilities of
    one Gray-mapped L-level amplitude dimension (Cho & Yoon expression).

    The decision-distance argument is sqrt(snr_scale * EbNoLin). Divide the
    result by the bits per symbol to obtain a BER. L must be a power of 2;
    L == 1 carries no bits and contributes zero.
    '''
    arg = np.sqrt(snr_scale * EbNoLin)
    total = np.zeros_like(EbNoLin)
    for i in range(1, int(round(np.log2(L))) + 1):
        step = 2 ** (i - 1)
        berk = np.zeros_like(EbNoLin)
        for j in range(int((1 - 2 ** (-i)) * L)):
            sign = (-1) ** np.floor(j * step / L)
            weight = step - np.floor(j * step / L + 1 / 2)
            berk = berk + sign * weight * qfunc((2 * j + 1) * arg)
        total = total + berk * 2 / L
    return total
        