网问答>>求教pytorch的LSTM网络代码问题
问题
未解决

求教pytorch的LSTM网络代码问题

时间:2022-08-07 00:58:46
class RegLSTM(nn.Module): def __init__(self): super(RegLSTM, self).__init__() # 定义LSTM self.rnn = nn.LSTM(input_siz=6, hidden_size=32, hidden_num_layers=1) # 定义回归层网络,输入的特征维度等于LSTM的输出,输出维度为1 self.reg = nn.Sequential(nn.Linear(hidden_size, 1)) def forward(self, x): x, (ht,ct) = self.rnn(x) seq_len, batch_size, hidden_size= x.shape x = y.view(-1, hidden_size) x = self.reg(x) x = x.view(seq_len, batch_size, -1) return x求教x, (ht,ct) 其中的 ht,ct是做什么的?
本类最有帮助
Copyright © 2008-2013 www.wangwenda.com All rights reserved.冀ICP备12000710号-1
投诉邮箱: