添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
相关文章推荐
多情的莴苣  ·  ml-mastery-zh/docs/dlt ...·  1 月前    · 
爱搭讪的猴子  ·  进程调度·  3 周前    · 
耍酷的沙滩裤  ·  C# MarshalAs - ...·  1 月前    · 
私奔的牛腩  ·  Solved: Dropbox ...·  11 月前    · 

LSTM层输出中output和hidden

  • output 包含LSTM每个时间步t的输出特征,

  • h_t 表示LSTM最后一层的输出特征。

    在单向LSTM中,output的最后一个时间步维度的输出 output[:, -1, :] 等于hidden;

    在双向LSTM层中,可以通过 拼接 output的最后一个时间步维度正反向的输出,来得到和hidden一样的输出。

    ​另外,注意控制nn.LSTM()中:# batch_first=True ,将喂入LSTM的数据中batchsize维度提前,如果输入维度中batchsize已经在第一个维度,故无需设置

构建简单模型验证

构建简单模型验证

from torch import nn
class Config(object):
    def __init__(self, vocab_size, embed_dim, label_num):
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.label_num = label_num
        self.bidirectional = False
        self.num_directions = 2 if self.bidirectional else 1
        self.hidden_size = 128
        self.num_layer = 1
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.vocab_size - 1)
        self.lstm = nn.LSTM(config.embed_dim, config.hidden_size, config.num_layer, batch_first= True,  # batch_first=True,将output中batchsize维度提前,hidden不受影响
                            bidirectional= config.bidirectional)
        self.fc = nn.Linear(config.hidden_size * config.num_directions, config.label_num)
    def forward(self, input):
        embed = self.embedding(input)  
        lstm_out, (hidden, cell) = self.lstm(embed)
        output = lstm_out[:, -1, :]
        return hidden, output
import torch
import numpy as np
vocab_size = 100
embed_dim = 64
label_num = 2
epoch = 40
config = Config(vocab_size, embed_dim, label_num)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(config).to(device)
print(model)
x = abs(np.random.randn(128, 200))
print(x)
datas =torch.from_numpy(x).long().to(device)
hidden, output = model(datas)
hidden==output
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])
LSTM结构(右图)与普通RNN(左图)的主要输入输出区别如下所示
相比RNN只有一个传递状态h^t, LSTM有两个状态,一个c^t(cell state)理解为长时期记忆,和一个h^t(hidden state)理解为短时强记忆。
其对于传递下去的c^t 改变得很慢,通常输出的c^t 是上一个状态传过来的c^(t-1)加上一些...
				
pytorch LSTMoutputhidden关系1.LSTM模型简介2.pytorchLSTM3.关于h和output之间的关系进行实验 1.LSTM模型简介 能点进来的相信大家也都清楚LSTM是个什么东西,我在这里就不加赘述了。具体介绍模型结构的也有不少。 如何简单的理解LSTM——其实没有那么复杂 人人都能看懂的LSTM 2.pytorchLSTM 这里附上一张pytorch官方文档的截图,h_n和c_n我都理解分别是上图横向箭头的下方箭头和上方的箭头,那output是干什么用的?
官方文档:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html h_n:最后一个时间步的输出,即 h_n = output[:, -1, :](一般可以直接输入到后续的全连接,在 Keras 通过设置参数 return_sequences=False 获得) c_n:最后一个时间步 LSTM cell 的状态(一般用不到) 实例:根据红框可以直观看出,h_n 是最后一个时间步的输出,即是 h_n = ou LSTM的输入 input, (h_0, c_0) – input (seq_len, batch, input_size) – h_0 (num_layers * num_directions, batch, hidden_size) # 初始的隐藏状态 – c_. # num_directions=0, 表示前向结果 # num_directions=1, 表示反向结果 output.view(seq_len, batch, num_directions, hidden_size) concat输出 https://blog.csdn.net/qq_27061325/article/details/89463460 hidden_size = 128 number_layer = 3 input = torch.randint(low=0,high=256,size=[batch_size,seq_len]) #[64,20] embedding = nn.Embedding(num_embeddings,embedding_dim) input_embeded = embeddin 单lstm lstm = nn.LSTM(input_size=100, hidden_size=200, bidirectional=True, batch_first=True) a = torch.randn(32, 512, 100) out, (h, c) = lstm(a) print(out.shape) # 32, 512, 400 print(h.shape) # 2, 32, 400 print(out[0 Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` num_layers: Number of recurrent layers. E.g.
在双向LSTM,通过拼接output的最后一个时间步维度正反向的输出,可以得到和hidden一样的输出。而在单向LSTMoutput的最后一个时间步维度的输出output[:, -1, :]等于hiddenLSTM是一种循环神经网络的变体,通过使用一组基于用户编辑历史模式的特征来识别未公开付费编辑。实验评估结果显示,该方法的AUROC为0.93,平均精度为0.90,优于现有方法。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [LSTM输出outputhidden](https://blog.csdn.net/l_aiya/article/details/126412008)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [检测维基百科未公开的付费编辑(计算机硕士论文英文参考资料).pdf](https://download.csdn.net/download/weixin_44609920/88240778)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]