import IFTS.fiber_simulation as f
from IFTS.fiber_simulation.visualization import plot_constellation as plt_con
from IFTS.fiber_simulation.visualization import plot_err_analysis as plt_analysis
from IFTS.fiber_simulation.visualization import plot_wave as plt_wav
from IFTS.fiber_simulation.visualization import plot_power_density as plt_psd
from IFTS.simulation_main.modul_para.signal_para import Sig_Para
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

class Plot_Para(Sig_Para):
    """Plotting helpers for simulation signals, layered on ``Sig_Para``.

    Each plotting method saves a PNG named ``self.path + name + '.png'`` (or a
    name built from ``self.path``) and then closes all open figures. Signals
    may be NumPy arrays or torch tensors; tensors are converted with
    ``data_mode_check`` before they are drawn.
    """

    # Cap on the number of symbols drawn in a constellation when the
    # per-polarization symbol count exceeds it (applied in ``__init__``).
    max_constellation_points = 100000
    # Minimum number of symbol errors before ``err_analysis_plot`` draws the
    # error histogram.
    min_hist_errors = 256

    def __init__(self, rand_seed, config_path):
        super().__init__(rand_seed, config_path)
        if self.sym_num_per_pol > self.max_constellation_points:
            points = self.max_constellation_points
        else:
            points = self.sym_num_wo_padding
        self._set_plot_attrs(points)

    def __para_init__(self):
        """Recompute the plotting attributes from the current signal parameters.

        Unlike ``__init__`` this starts from ``sym_num_per_pol`` (no cap) and
        then applies the same padding clamp.
        """
        self._set_plot_attrs(self.sym_num_per_pol)

    def _set_plot_attrs(self, constellation_points):
        """Set the symbol window, output path and colour shared by both initialisers."""
        # Plotted symbols are offset by half of the padding (see get_idx).
        self.front_sym_num = int(self.padding_num / 2)
        # Keep the plotted window inside the per-polarization symbol count.
        if constellation_points + self.padding_num > self.sym_num_per_pol:
            constellation_points = self.sym_num_per_pol - self.padding_num
        self.constellation_points = constellation_points
        self.path = self.figure_path
        self.scatter_colour = 'gold'

    def data_mode_check(self, x):
        """Return ``x`` as a NumPy array if it is a torch tensor, else unchanged.

        Tensors (including subclasses such as ``nn.Parameter``) are detached,
        copied and moved to the CPU, so the result never aliases the tensor.
        """
        if isinstance(x, torch.Tensor):
            return x.detach().clone().cpu().numpy()
        return x

    def get_max_value(self, x):
        """Return the largest absolute real or imaginary component of ``x``.

        Accepts torch tensors or anything ``np.asarray`` understands. For
        real-valued input the imaginary part counts as zero.
        """
        if isinstance(x, torch.Tensor):
            # ``Tensor.imag`` raises for non-complex dtypes.
            if torch.is_complex(x):
                return max(x.real.abs().max().item(), x.imag.abs().max().item())
            return x.abs().max().item()
        x = np.asarray(x)
        return max(np.max(np.abs(x.real)), np.max(np.abs(x.imag)))

    def get_colour(self, x, sam_num = 1):
        """Build per-polarization colour sequences from label arrays.

        For each polarization, ``x[i_p][get_idx(sam_num)]`` is used to index a
        ``class_num``-colour husl palette (presumably integer class labels —
        confirm with callers). Stores the colour arrays in ``self.colour_seq``
        and the positions ``0 .. len(idx) - 1`` in ``self.colour_index``.
        """
        idx = self.get_idx(sam_num)
        palette = np.array(sns.color_palette("husl", self.class_num))
        self.colour_seq = [palette[self.data_mode_check(x_pol[idx])] for x_pol in x]
        self.colour_index = np.arange(idx.shape[0])

    def get_idx(self, sam_num):
        """Return sample indices of the plotted symbols.

        Symbols ``front_sym_num .. front_sym_num + constellation_points - 1``
        are selected and multiplied by ``sam_num`` (samples per symbol).
        """
        return (self.front_sym_num + np.arange(self.constellation_points)) * sam_num

    def scatter_plot_nPol(self, x, sam_num = 1, name = 'Constellation', set_c = 0, xlim = None, ylim = None, s = None, c_index = None):
        """Save a constellation diagram with one panel per polarization.

        Args:
            x: sequence of per-polarization signals (NumPy arrays or tensors).
            sam_num: samples per symbol, forwarded to ``get_idx``.
            name: figure title and output file stem.
            set_c: if truthy, draw every point in ``self.scatter_colour``;
                otherwise use ``self.colour_seq`` when it has been built.
            xlim, ylim: axis limits; if either is None both become a symmetric
                range covering the largest |real| / |imag| value of ``x``.
            s: marker size forwarded to ``plt_con.scatter_plot_black``.
            c_index: unused; kept for interface compatibility.
        """
        n_pol = len(x)
        pol_name = ["X Polarization", "Y Polarization"]
        # squeeze=False keeps ``axs`` indexable even for a single polarization.
        fig, axs = plt.subplots(1, n_pol, figsize=(6.5 * n_pol, 6.5), edgecolor='#282C34', squeeze=False)
        axs = axs[0]
        fig.set_facecolor('#282C34')
        fig.set_edgecolor('#282C34')
        fig.suptitle(name, color='#B4B6BD', fontsize='xx-large')
        # ``is None`` (not ``== None``): array-valued limits would otherwise
        # produce an ambiguous element-wise comparison.
        if xlim is None or ylim is None:
            xy_max = max(self.get_max_value(x[i_p]) for i_p in range(n_pol))
            xlim = [-xy_max, xy_max]
            ylim = [-xy_max, xy_max]
        idx = self.get_idx(sam_num=sam_num)
        has_colours = hasattr(self, 'colour_seq') and len(self.colour_seq) > 0
        for i_p in range(n_pol):
            y = self.data_mode_check(x[i_p][idx])
            if set_c:
                c = self.scatter_colour
            elif has_colours:
                self.colour_index = np.arange(idx.shape[0])
                c = self.colour_seq[i_p][self.colour_index]
            else:
                c = None
            axs[i_p] = plt_con.scatter_plot_black(y, axs[i_p], pol_name[i_p], c, xlim, ylim, s)

        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    def scatter_plot(self, x, sam_num = 1, name = 'Constellation', set_c = 0, xlim = None, ylim = None, s = None, c_index = None):
        """Save a single-panel constellation diagram of the first polarization ``x[0]``.

        Arguments mirror ``scatter_plot_nPol``; ``c_index`` is unused. The
        figure is saved at 900 dpi.
        """
        idx = self.get_idx(sam_num=sam_num)
        y = self.data_mode_check(x[0][idx])
        if set_c:
            c = self.scatter_colour
        elif hasattr(self, 'colour_seq') and len(self.colour_seq) > 0:
            c = self.colour_seq[0][self.colour_index]
        else:
            c = None
        fig, axs = plt.subplots(1, 1, figsize=(6.5, 6.5))
        # Drawn for every colour source, including ``set_c``.
        axs = plt_con.scatter_plot(y, axs, "X Polarization", c, xlim, ylim, s)

        fig.savefig(self.path + name + '.png', dpi=900)
        plt.close('all')

    def scatter_heat_map_nPol(self, x, sam_num = 1, name = 'Constellation', xlim = None, ylim = None, s = None):
        """Save one ``plt_con.scatter_heat_map`` panel per polarization.

        Points are selected with ``get_idx(sam_num)``; ``xlim``, ``ylim`` and
        ``s`` are forwarded unchanged to the helper.
        """
        n_pol = len(x)
        pol_name = ["X Polarization", "Y Polarization"]
        fig, axs = plt.subplots(1, n_pol, figsize=(6.5 * n_pol, 6.5), edgecolor='#282C34', squeeze=False)
        axs = axs[0]
        fig.set_facecolor('#282C34')
        fig.set_edgecolor('#282C34')
        fig.suptitle(name, color='#B4B6BD', fontsize='xx-large')
        idx = self.get_idx(sam_num=sam_num)
        for i_p in range(n_pol):
            y = self.data_mode_check(x[i_p][idx])
            axs[i_p] = plt_con.scatter_heat_map(y, axs[i_p], pol_name[i_p], xlim, ylim, s)

        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    def _freq_scale(self, hz):
        """Return ``(sample_rate, unit)`` for the requested frequency unit.

        'GHz' keeps ``self.sam_rate`` unchanged; any other value is treated as
        THz and scales the rate by 1e-3.
        """
        if hz == 'GHz':
            return self.sam_rate, 'GHz'
        return self.sam_rate * 10 ** -3, 'THz'

    def psd(self, x, fc=0, hz = 'GHz', name = 'power_density', xlim=None, ylim=None):
        """Save a power-spectral-density plot of one signal (``Axes.psd``, NFFT=4096).

        Args:
            x: signal samples (NumPy array or torch tensor).
            fc: centre frequency passed as ``Fc``; marked by a dashed red line
                and a text label.
            hz: 'GHz' or anything else for THz (see ``_freq_scale``).
            name: figure title and output file stem.
            xlim: optional x-axis limits.
            ylim: optional y-axis limits; defaults to 60 dB below the spectral
                peak up to 1 dB above it.
        """
        x = self.data_mode_check(x)
        sam_rate, unit = self._freq_scale(hz)
        fig, axs = plt.subplots(1, figsize=(6.5, 6.7))
        fig.suptitle(name, fontsize='xx-large')
        # Distinct local names avoid shadowing this method and the module alias ``f``.
        pxx, _ = axs.psd(x, 4096, Fs=sam_rate, Fc=fc)
        peak_db = np.max(10 * np.log10(pxx))
        axs.grid(True)
        axs.set_ylabel('')
        axs.set_xlabel('Frequency(' + unit + ')')
        if xlim is not None:
            axs.set_xlim(xlim)
        if ylim is not None:
            axs.set_ylim(ylim)
        else:
            axs.set_ylim([peak_db - 60, peak_db + 1])
        axs.axvline(x=fc, ls='--', color='r')
        # The label uses the same unit as the frequency axis.
        axs.text(x=fc + 0.01, y=peak_db - 58, s=str(round(fc, 2)) + ' ' + unit, color='r')
        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    def psd_nPol(self, x, fc=0, hz = 'GHz', name = 'power_density', sign = None):
        """Save power-spectral-density plots with one panel per polarization.

        Args:
            x: sequence of per-polarization signals.
            fc, hz, name: as in ``psd``.
            sign: optional frequency marker. When given, a horizontal line at
                10 dB below the peak and a vertical line at ``sign`` are drawn,
                and the view is limited to +/-1.5*sign and 60 dB below the peak.
        """
        n_pol = len(x)
        pol_name = ["X Polarization", "Y Polarization"]
        sam_rate, unit = self._freq_scale(hz)
        fig, axs = plt.subplots(1, n_pol, figsize=(6.5 * n_pol, 6.5), squeeze=False)
        axs = axs[0]
        fig.suptitle(name, fontsize='xx-large')
        for i_p, ax in enumerate(axs):
            y = self.data_mode_check(x[i_p])
            pxx, _ = ax.psd(y, 4096, Fs=sam_rate, Fc=fc)
            peak_db = np.max(10 * np.log10(pxx))
            ax.grid(True)
            ax.set_ylabel('')
            ax.set_xlabel('Frequency(' + unit + ')')
            if sign is not None:
                ax.axhline(y=peak_db - 10)
                ax.axvline(x=sign)
                ax.text(x=sign - 5, y=peak_db - 9,
                        s='(' + str(np.around(sign, 1)) + ' ' + unit + ', -10 dB)')
                ax.set_ylim([peak_db - 60, peak_db + 1])
                ax.set_xlim([-sign * 1.5, sign * 1.5])
            ax.set_title(pol_name[i_p], fontsize='x-large')
        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    def err_hist_plot_nPol(self, sym_map, para):
        """Call ``plt_analysis.err_analysis_plot`` once per polarization.

        Reads ``para.nPol``, ``para.sym_err`` and ``para.sym_err_idx``; the
        output name is ``self.path + 'X' | 'Y' + 'hist'``.
        """
        pol_name = ["X", "Y"]
        for i_p in range(para.nPol):
            plt_analysis.err_analysis_plot(para.sym_err[i_p], para.sym_err_idx[i_p], sym_map,
                                           name=self.path + pol_name[i_p] + 'hist')
        plt.close('all')

    def err_analysis_plot(self, rx_sym, tx_sym_idx, err_sym, err_sym_idx, sym_map, i_p):
        """Run ``plt_analysis.err_analysis_plot_v2`` for polarization ``i_p``.

        With fewer than ``min_hist_errors`` symbol errors a message is printed
        and the helper is called with ``hist = 0``.
        """
        pol_name = ["X", "Y"]
        out_name = self.path + pol_name[i_p] + 'hist'
        if err_sym_idx.shape[0] >= self.min_hist_errors:
            plt_analysis.err_analysis_plot_v2(rx_sym, tx_sym_idx, err_sym, err_sym_idx, sym_map, name=out_name)
        else:
            print('Few symbol errors of ' + pol_name[i_p] + ' are not allowed to plot error histgram here')
            # Histogram disabled; only the per-constellation-point SNR view remains.
            plt_analysis.err_analysis_plot_v2(rx_sym, tx_sym_idx, err_sym, err_sym_idx, sym_map, hist=0, name=out_name)
        plt.close('all')

    def loss_plot_nPol(self, loss_x, loss_y, name = 'Loss function'):
        """Save side-by-side plots of |loss| for the X and Y outputs.

        Every 10th sample is plotted with ``smooth=1, EMA_SPAN=40`` and the
        y-axis is fixed to [0, 1.5].
        """
        fig, axs = plt.subplots(1, 2, figsize=(8 * 2, 6.5))
        fig.suptitle(name, fontsize='xx-large')
        for i, (loss, title) in enumerate(((loss_x, 'loss x'), (loss_y, 'loss y'))):
            loss = self.data_mode_check(loss)
            axs[i] = plt_wav.wave1d_plot(np.abs(loss[::10]), axs[i],
                                         title=title, smooth=1, EMA_SPAN=40)
            axs[i].set_ylim(0, 1.5)

        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    @staticmethod
    def _firtap_layout(fir_type):
        """Return ``(tap_names, n_rows, n_cols)`` for a '2x2' or '4x2' FIR structure.

        Raises:
            ValueError: for any other ``fir_type``.
        """
        if fir_type == '2x2':
            return ['hxx', 'hxy', 'hyx', 'hyy'], 2, 2
        if fir_type == '4x2':
            return ['hx_xi', 'hx_xq', 'hx_yi', 'hx_yq',
                    'hy_xi', 'hy_xq', 'hy_yi', 'hy_yq'], 2, 4
        raise ValueError("unsupported fir_type %r; expected '2x2' or '4x2'" % (fir_type,))

    def _plot_firtaps(self, taps, subname, n_rows, n_cols, name):
        """Plot real and imaginary parts of each tap vector in its own panel and save."""
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(6.5 * n_cols, 6.5 * n_rows))
        sns.set_theme(style="darkgrid")
        axs = axs.reshape(-1)
        fig.suptitle(name, fontsize='xx-large')
        for i, tap_name in enumerate(subname):
            axs[i] = plt_wav.wave1d_plot(np.real(taps[i]), axs[i], title=tap_name, smooth=0)
            axs[i] = plt_wav.wave1d_plot(np.imag(taps[i]), axs[i], title=tap_name, smooth=0)
        fig.savefig(self.path + name + '.png', dpi=500)
        plt.close('all')

    def firtap_plot_nPol_old(self, mimo_obj, name = 'Firtap'):
        """Legacy FIR-tap plot reading taps from ``mimo_obj``.

        With ``mimo_obj.mat_op`` the taps come from ``mimo_obj.h`` transposed
        to (2, 0, 1) and flattened per tap. Otherwise they come from the
        attributes named like the taps (``hxx`` ... or ``hx_xi`` ...).

        Raises:
            ValueError: if ``mimo_obj.fir_type`` is not '2x2' or '4x2'.
        """
        subname, n_rows, n_cols = self._firtap_layout(mimo_obj.fir_type)
        if mimo_obj.mat_op:
            taps = mimo_obj.h.transpose((2, 0, 1)).reshape((len(subname), -1))
        else:
            taps = [getattr(mimo_obj, tap_name) for tap_name in subname]
        self._plot_firtaps(taps, subname, n_rows, n_cols, name)

    def firtap_plot_nPol(self, mimo_obj, name = 'Firtap'):
        """Plot the FIR taps in ``mimo_obj.h`` (index 0 of its last axis).

        Raises:
            ValueError: if ``mimo_obj.fir_type`` is not '2x2' or '4x2'.
        """
        subname, n_rows, n_cols = self._firtap_layout(mimo_obj.fir_type)
        taps = mimo_obj.h[..., 0].reshape((len(subname), -1))
        taps = self.data_mode_check(taps)
        self._plot_firtaps(taps, subname, n_rows, n_cols, name)

    def wave_plot_nPol(self, x, name = 'Wave'):
        """Save one waveform panel per entry of ``x`` (titled 'x', 'y')."""
        plot_num = len(x)
        sns.set_theme(style="darkgrid")
        subname = ['x', 'y']
        fig, axs = plt.subplots(1, plot_num, figsize=(6.5 * plot_num, 6.5), squeeze=False)
        axs = axs[0]
        fig.suptitle(name, fontsize='xx-large')
        for i in range(plot_num):
            y = self.data_mode_check(x[i])
            axs[i] = plt_wav.wave1d_plot(y, axs[i], title=subname[i], smooth=0)

        fig.savefig(self.path + name + '.png')
        plt.close('all')

    def corr_plot(self, x, name = 'Symchronization'):
        """Save a grid of correlation traces, one panel per ``x[i, j]``.

        ``x`` must have 2 rows (titles over x/y pairs) or 4 rows (titles over
        xr/xi/yr/yi pairs). Titles are ``subname[i * n_cols + j]``, so the grid
        is expected to be square.

        Raises:
            ValueError: if ``x`` does not have 2 or 4 rows.
        """
        x = self.data_mode_check(x)
        row_num, column_num = x.shape[0], x.shape[1]
        if row_num == 2:
            channels = ['x', 'y']
        elif row_num == 4:
            channels = ['xr', 'xi', 'yr', 'yi']
        else:
            raise ValueError('corr_plot expects 2 or 4 rows, got %d' % row_num)
        subname = ['corr_' + a + b for a in channels for b in channels]

        # Width follows the column count, height the row count.
        fig, axs = plt.subplots(row_num, column_num, figsize=(5 * column_num, 5 * row_num), squeeze=False)
        sns.set_theme(style="darkgrid")
        fig.suptitle(name, fontsize='xx-large')
        for i in range(row_num):
            for j in range(column_num):
                axs[i, j] = plt_wav.wave1d_plot(x[i, j], axs[i, j], title=subname[i * column_num + j], smooth=0)
                axs[i, j].set_xticks([])

        fig.savefig(self.path + name + '.png', bbox_inches='tight')
        plt.close('all')