from itertools import cycle
import torch
import numpy as np
import warnings
import time
from IFTS.fiber_simulation.comm_tools.resample.resample_fft import resample
from IFTS.fiber_simulation.channel.channel_trans import awgn
from IFTS.fiber_simulation.utils.show_progress import progress_info_return
from IFTS.fiber_simulation.channel.channel_trans.fiber.ssfm.ssfm_matrix import linearity_matrix

def check_power(x, p_indBm, channel_num):
    """Warn when the measured signal power deviates from the target power.

    The target power (in W) is ``channel_num`` times the per-channel power
    ``p_indBm`` converted from dBm. The measured power is the sum, over the
    entries of ``x``, of the total squared magnitude divided by
    ``x[0].shape[0]``. A ``RuntimeWarning`` is emitted when the two differ
    by more than 10 % of the target.

    Args:
        x: Sequence of signals, each a ``torch.Tensor`` or a NumPy-compatible
            array (presumably one entry per polarization -- confirm with
            callers).
        p_indBm: Per-channel target power in dBm.
        channel_num: Number of channels contributing to the total power.
    """
    target = (10 ** ((p_indBm - 30) / 10)) * channel_num
    num = x[0].shape[0]
    measured = 0.0
    for sig in x:
        if isinstance(sig, torch.Tensor):
            energy = (sig.abs() ** 2).sum().item()
        else:
            energy = np.sum(np.abs(sig) ** 2)
        measured += energy / num
    if np.abs(target - measured) > target / 10:
        # warnings.warn actually emits the warning; merely instantiating
        # RuntimeWarning(...) would build an object and discard it.
        warnings.warn('Power output is different with target power',
                      RuntimeWarning, stacklevel=2)
    
def fiber_ssfm_trans(x, rand_seed, para):
    """Propagate the signal span by span through the SSFM fiber and EDFA.

    For each span ``k`` (0-based, ``para.span_num`` spans), the signal is
    passed to ``para.ssfm_obj`` with span length ``para.len_array[k]`` and the
    integer ``rand_seed + 10648 * k + 1``. It is then passed to
    ``para.edfa_obj`` with ``rand_seed + 302 * k + 1``. When
    ``para.save_data`` is set, the signal after each span is stored through
    ``para.save_data_func_npol`` under the name ``'after_span<k+1>'``.

    Args:
        x: Signal accepted by ``para.ssfm_obj`` / ``para.edfa_obj``.
        rand_seed: Base value from which per-span seeds are offset.
        para: Simulation parameter object.

    Returns:
        The signal after the last span's ``edfa_obj`` call.
    """
    @progress_info_return(
        total_num=para.span_num,
        infor_print=para.infor_print,
        discription=' Fiber SSFM',
    )
    def main(x, **kwargs):
        # The decorator may supply a progress bar; 0 means "no bar".
        progress = kwargs.get('pbar', 0)
        store = para.save_data

        for span in range(para.span_num):
            span_no = span + 1
            x = para.ssfm_obj(x, para.len_array[span], rand_seed + 10648 * span + 1)
            x = para.edfa_obj(x, rand_seed + 302 * span + 1)
            if store:
                para.save_data_func_npol(x, 'after_span' + str(span_no))
            if type(progress) is not int:
                progress.update(1)
        return x

    return main(x)

def fiber_nn_trans(x, rand_seed, para):
    """Propagate the signal through a neural-network fiber-channel model.

    The model type is selected by ``para.nn_model``:

    * ``'BiLSTM'``: ``para.nn_obj`` is applied once per span
      (``para.span_num`` spans, span length ``para.len_array[k]``). Each span
      is followed by ``para.edfa_obj`` with ``rand_seed + 302 * k + 1``.
    * ``'GAN'``: ``para.nn_obj`` is applied once with ``para.total_len``.
      No ``edfa_obj`` call is made in this function.

    When ``para.save_data`` is set, the output is stored through
    ``para.save_data_func_npol`` under ``'after_span<k>'``.

    Args:
        x: Signal accepted by ``para.nn_obj`` / ``para.edfa_obj``.
        rand_seed: Base value from which the EDFA seeds are offset.
        para: Simulation parameter object.

    Returns:
        The signal after the neural-network channel.

    Raises:
        ValueError: If ``para.nn_model`` is neither ``'BiLSTM'`` nor ``'GAN'``.
    """
    nn_model = para.nn_model
    if nn_model == 'BiLSTM':
        cycle_num = para.span_num
    elif nn_model == 'GAN':
        cycle_num = 1
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on cycle_num below.
        raise ValueError(
            "Unsupported nn_model {!r}; expected 'BiLSTM' or 'GAN'".format(nn_model))

    @progress_info_return(total_num=cycle_num,
                          infor_print=para.infor_print,
                          discription=' Fiber NN')
    def main(x, **kwargs):
        pbar = kwargs.get('pbar', 0)
        save_flag = para.save_data
        # Branch on the same value that sized the progress bar, so the loop
        # count and the selected model always agree.
        if nn_model == 'BiLSTM':
            for i in range(cycle_num):
                x = para.nn_obj(x, para.len_array[i])
                x = para.edfa_obj(x, rand_seed + 302 * i + 1)
                if save_flag:
                    para.save_data_func_npol(x, 'after_span' + str(i + 1))
                if type(pbar) is not int:
                    pbar.update(1)
        else:  # 'GAN': one pass over the whole link length.
            x = para.nn_obj(x, para.total_len)
            if save_flag:
                para.save_data_func_npol(x, 'after_span' + str(para.span_num))
            if type(pbar) is not int:
                pbar.update(1)
        return x

    return main(x)

def tx_laser(rx_sig, rand_seed, para):
    """Apply the transmitter laser model to each channel of ``rx_sig``.

    Entry ``i`` (for ``i`` in ``range(para.channel_num)``) is replaced by
    ``para.tx_laser_obj(entry, rand_seed=rand_seed + 302 * i)``.

    Args:
        rx_sig: Indexable, item-assignable sequence with at least
            ``para.channel_num`` entries.
        rand_seed: Base value; channel ``i`` receives ``rand_seed + 302 * i``.
        para: Simulation parameter object.

    Returns:
        The sequence holding the laser-processed channels.
    """
    @progress_info_return(total_num=para.channel_num,
                          infor_print=para.infor_print,
                          discription=' Tx laser')
    def main(x, **kwargs):
        pbar = kwargs.get('pbar', 0)
        for i in range(para.channel_num):
            x[i] = para.tx_laser_obj(x[i], rand_seed=rand_seed + i * 302)
            if type(pbar) is not int:
                pbar.update(1)
        # Return the sequence that was actually modified. The enclosing
        # rx_sig only matches it if the decorator forwards the same object.
        return x

    return main(rx_sig)

def _save_pol_data(x, para, prefix, data_mode):
    """Save ``x[i_p]`` as ``prefix + 'x'/'y'`` for each of ``para.nPol`` polarizations."""
    pol_name = ['x', 'y']
    for i_p in range(para.nPol):
        para.save_data_func(x[i_p], prefix + pol_name[i_p], data_mode=data_mode)


def channel_transmission(x, para, plot_para=None):
    """Run the optical channel simulation on the transmitter output.

    Stages, in order (several are gated by flags on ``para``):
      1. ``para.dac``: resample every ``x[channel][pol]`` by ``para.upsam``.
      2. WDM multiplexing via ``para.wdm_obj``, with an optional PSD plot
         (``para.fig_plot``) and save (``para.save_data``).
      3. Channel: ``channel_type == 1`` selects SSFM or NN fiber
         (``para.fiber_mode``); ``channel_type == 2`` selects AWGN at
         ``para.snr_db``. Any other value leaves the signal unchanged.
         ``check_power`` runs before and after this stage.
      4. WDM demultiplexing.
      5. ``para.receiver``: ``para.icr_obj`` (the ICR stage).
      6. ``para.adc``: resample each polarization by
         ``para.rx_sam_rate / para.ch_sam_rate``.

    Random effects use ``para.rand_seed``. They use the fixed value -501
    instead when ``para.ch_random`` is set or ``para.rand_seed <= 0``.

    Args:
        x: Per-channel, per-polarization signals, indexed ``x[channel][pol]``.
        para: Simulation parameter object.
        plot_para: Plotting helper. It is needed only when ``para.fig_plot``
            is set. If None, plotting is skipped with a warning.

    Returns:
        The processed signal after the enabled stages.
    """
    print_flag = para.infor_print
    plot_flag = para.fig_plot
    save_flag = para.save_data

    # plot_para defaults to None, so it must not be dereferenced unguarded.
    if plot_para is not None:
        plot_para.sam_rate = para.sam_rate
    elif plot_flag:
        warnings.warn('fig_plot is set but plot_para is None; skipping plots',
                      RuntimeWarning, stacklevel=2)
        plot_flag = False

    if para.ch_random or para.rand_seed <= 0:
        rand_seed = -501
        if print_flag:
            print(' Channel random effects are non-determinstic')
    else:
        rand_seed = para.rand_seed
        if print_flag:
            print(' Channel random effects are determinstic')
    start_time = time.perf_counter()

    # Resample to the channel sample rate.
    if para.dac:
        if print_flag:
            print(' Resample to channel sample rate')
        for i_c in range(para.channel_num):
            for i_p in range(para.nPol):
                x[i_c][i_p] = resample(x[i_c][i_p], rate=para.upsam)

    # Optical signal transmission.
    x = para.wdm_obj.wdm_multiplexing(x)
    if plot_flag:
        f_light = para.optical_carrier * 10**-3
        lim = (1.1 * para.channel_space * para.channel_num / 2) * 10**-3
        xlim = [-lim + f_light, lim + f_light]
        plot_para.psd(x[0], name='WDM_PSD', fc=f_light, hz='THz', xlim=xlim)
    if save_flag:
        _save_pol_data(x, para, 'after_wdm_', para.wdm_data_mode)
    check_power(x, para.sig_power_dbm, para.channel_num)

    if para.channel_type == 1:
        if para.fiber_mode == 'SSFM':
            if print_flag:
                print(' Input into the SSFM channel with {} spans and {}km'.format(para.span_num, para.total_len))
            x = fiber_ssfm_trans(x, rand_seed + 10137, para)
        elif para.fiber_mode == 'NN':
            if print_flag:
                print(' Input into the NN channel with {} spans and {}km'.format(para.span_num, para.total_len))
            x = fiber_nn_trans(x, rand_seed + 10137, para)
    elif para.channel_type == 2:
        if print_flag:
            print(' Input into the AWGN channel with SNR {} dB'.format(para.snr_db))
        x = awgn.awgn(x, para.snr_db, para.ch_random, rand_seed + 10260, device=para.device)
    check_power(x, para.sig_power_dbm, para.channel_num)

    if save_flag:
        _save_pol_data(x, para, 'after_fiber_', para.ch_data_mode)
    x = para.wdm_obj.wdm_demultiplexing(x, para_update=0)

    if para.receiver:
        if print_flag:
            print(' Input to the ICR ...')
        x = para.icr_obj(x, rand_seed=rand_seed + 22034)
    if para.adc:
        for i_p in range(para.nPol):
            x[i_p] = resample(x[i_p], rate=para.rx_sam_rate / para.ch_sam_rate)
        if print_flag:
            print(' ADC: Downsampling to {:.0f} Gsam/s'.format(para.rx_sam_rate))

    elapsed = time.perf_counter() - start_time
    if print_flag:
        print(' Channel simulation time: %.4f s' % elapsed)
    return x