import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

'''
Neural network model: conditional GAN generator.
'''
class Generator(nn.Module):
    """Conditional GAN generator: an MLP over [condition, noise].

    The condition and noise tensors are concatenated along dim 1. The result
    then passes through three Linear -> BatchNorm1d -> leaky_relu stages
    (512 -> 256 -> 64 units). A final Linear layer (``fc4``) projects to
    ``args.outdim`` features.

    Args:
        args: configuration object exposing integer attributes ``noise_dim``,
            ``indim`` and ``outdim``.
    """

    def __init__(self, args):
        super().__init__()

        joint_dim = args.noise_dim + args.indim
        # Layers are created in the original order, so under a fixed seed the
        # random initialisation is unchanged. The attribute names (i.e. the
        # state_dict keys) are also unchanged.
        self.fc1 = nn.Linear(joint_dim, 512)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 64)
        self.fc3_bn = nn.BatchNorm1d(64)
        self.fc4 = nn.Linear(64, args.outdim)
        # NOTE(review): fc51 / fc52 are never used by forward(). They are
        # presumably kept so existing checkpoints still load; confirm before
        # removing.
        self.fc51 = nn.Linear(64, args.outdim)
        self.fc52 = nn.Linear(64, args.outdim)

    def forward(self, input, latent):
        """Generate an output vector for each (condition, noise) pair.

        Args:
            input: condition data, concatenated before ``latent`` along
                dim 1. Presumably shaped (batch, indim); TODO confirm.
            latent: noise. Presumably shaped (batch, noise_dim); TODO confirm.

        Returns:
            The output of ``fc4``, with ``args.outdim`` features in the last
            dimension.
        """
        # `input` shadows the builtin. The name is kept because callers may
        # pass it by keyword.
        features = torch.cat([input, latent], 1)
        stages = (
            (self.fc1, self.fc1_bn),
            (self.fc2, self.fc2_bn),
            (self.fc3, self.fc3_bn),
        )
        for linear, norm in stages:
            features = F.leaky_relu(norm(linear(features)))
        return self.fc4(features)