本文主要是介绍convLSTM 理解与实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
本文主要是有关convLSTM的pytorch实现代码的理解,原理请移步其他博客。
在pytorch中实现LSTM或者GRU等RNN一般需要重写cell,每个cell中包含某一个时序的计算,也就是以下:
在传统LSTM中,LSTM每次要调用t次cell,t就是时序的总长度,如果是n层LSTM就相当于一共调用了n*t次cell
class ConvLSTMCell(nn.Module):def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):"""Initialize ConvLSTM cell.Parameters----------input_size: (int, int)Height and width of input tensor as (height, width).input_dim: intNumber of channels of input tensor.hidden_dim: intNumber of channels of hidden state.kernel_size: (int, int)Size of the convolutional kernel.bias: boolWhether or not to
这篇关于convLSTM 理解与实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!