# @Author ：Hang Yang 
# @Datatime：2021/04/28 9:49
# @File: model_BiLSTM.py
# @Last Modify Time: 2021/04/28 9:49
# @Contact : hangyang@sjtu.edu.cn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

'''
Neural network model: bidirectional LSTM (BiLSTM) regressor.
'''

class BiLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear regression head.

    The LSTM output at every time step (both directions) is flattened into a
    single vector per sample and fed to one ``nn.Linear`` layer.

    Args:
        args: configuration object read for the attributes
            ``BiLSTMinDim`` (LSTM input feature size),
            ``BiLSTMoutDim`` (LSTM hidden size per direction),
            ``layers`` (number of stacked LSTM layers),
            ``time_step`` (sequence length the linear head is sized for) and
            ``outdim`` (number of regression outputs).
    """

    def __init__(self, args):
        super().__init__()

        self.rnn = nn.LSTM(
            input_size=args.BiLSTMinDim,
            hidden_size=args.BiLSTMoutDim,
            num_layers=args.layers,
            batch_first=True,
            bidirectional=True,
        )
        # Width of the flattened LSTM output: hidden size x 2 directions
        # x number of time steps. The head therefore only accepts
        # sequences of exactly ``args.time_step`` steps.
        head_in = args.BiLSTMoutDim * 2 * args.time_step
        self.reg = nn.Linear(head_in, args.outdim)

    def forward(self, x):
        """Run the BiLSTM and regress from its flattened per-step outputs.

        Args:
            x: input tensor laid out as (batch, seq, BiLSTMinDim), since the
                LSTM is built with ``batch_first=True``; ``seq`` must equal
                ``args.time_step`` for the linear head's input width to match.

        Returns:
            Tensor of shape (batch, outdim).
        """
        seq_out, _ = self.rnn(x)  # (batch, seq, 2 * hidden)
        batch, steps, feats = seq_out.shape
        # Concatenate all time steps into one feature vector per sample.
        flat = seq_out.reshape(batch, steps * feats)
        return self.reg(flat)
