import numpy as np
import torch
from ...fiber_simulation.comm_tools import calculation as calcu
from ...fiber_simulation.tx_dsp.shaping import shaping
from ...fiber_simulation.comm_tools.normalization import norm_1d, norm_dual_pol
from ...fiber_simulation.comm_tools.resample.resample_fft import resample
from ...fiber_simulation.comm_tools.filter_design import filter_design


def _shape_and_resample(sig, para):
    """Pulse-shape one polarization, then resample it.

    Calls ``shaping`` with ``para.ps_obj.upsam`` and ``para.ps_obj``, then
    ``resample`` with ``rate = para.sam_rate / para.ps_obj.sam_rate``. Both
    attributes are read at call time, so this matches evaluating them inline.
    """
    shaped = shaping(sig, para.ps_obj.upsam, para.ps_obj)
    return resample(shaped, rate=para.sam_rate / para.ps_obj.sam_rate)


def tx(sig_in, para, plot_para, **kwargs):
    r"""Apply the Tx DSP chain (pulse shaping + resampling) to both polarizations.

    The frequency-domain shaping filter stored on ``para.ps_obj`` is rebuilt
    on every call. Each of the two polarizations is then pulse-shaped and
    resampled.

    Args:
        sig_in: Indexable container. ``sig_in[0]`` and ``sig_in[1]`` are the
            two polarization signals (presumably X and Y; confirm against
            the caller). Any further elements are ignored.
        para: Simulation parameter object.
            Reads ``tx_data_mode``, ``device``, ``infor_print``, ``fig_plot``,
            ``fft_num``, ``pulse_shaping``, ``sam_rate`` and ``ps_obj``.
            Writes ``ps_obj.shaping_filter`` as a side effect.
        plot_para: Plotting helper. ``scatter_plot_nPol`` is called on it
            when ``para.fig_plot`` is truthy.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        list: ``[x, y]``, the processed polarization signals (the DAC input).
    """
    data_mode, device = para.tx_data_mode, para.device
    print_flag, fig_plot = para.infor_print, para.fig_plot

    # Reset to a flat (all-ones) response in the backend matching the data
    # mode, so the pulse-shaping response below is applied fresh each call.
    if data_mode == 'tensor':
        para.ps_obj.shaping_filter = torch.ones(para.fft_num, device=device)
    else:
        para.ps_obj.shaping_filter = np.ones(para.fft_num)
    if para.pulse_shaping:
        # get_freq is expected to populate ps_obj.h for fft_num bins
        # (not visible here -- confirm in the shaping module).
        para.ps_obj.get_freq(para.fft_num)
        para.ps_obj.shaping_filter = para.ps_obj.shaping_filter * para.ps_obj.h
        if print_flag:
            print(' --Pulse shaping finish--')

    # Same processing for each polarization, in the original X-then-Y order.
    x = _shape_and_resample(sig_in[0], para)
    y = _shape_and_resample(sig_in[1], para)

    if fig_plot:
        plot_para.scatter_plot_nPol([x, y], name = 'Constellation_After_TXDSP', set_c = 1)

    return [x, y]


