https://github.com/keras-team/keras/blob/master/keras/layers/convolutional_recurrent.py
Studying the ConvLSTM implementation in the Keras source code is a good way to understand how ConvLSTM works.
ConvLSTM replaces the fully connected operations of a standard LSTM with convolutions, which lets it extract spatial features from images much more effectively.
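To make that contrast concrete, here is a minimal sketch (our own illustration, not code from the Keras source; all shapes are assumed): an LSTM gate is a dense matmul over flat feature vectors, while a ConvLSTM gate is a convolution over feature maps, so spatial structure is preserved.

import tensorflow as tf

# LSTM-style gate: dense matmul over flat feature vectors.
x_vec = tf.random.normal((2, 16))                 # (batch, features)
h_vec = tf.random.normal((2, 8))                  # (batch, units)
W = tf.random.normal((16, 8))
U = tf.random.normal((8, 8))
dense_gate = tf.sigmoid(tf.matmul(x_vec, W) + tf.matmul(h_vec, U))

# ConvLSTM-style gate: convolution over spatial feature maps.
x_map = tf.random.normal((2, 8, 8, 3))            # (batch, rows, cols, channels)
h_map = tf.random.normal((2, 8, 8, 4))            # (batch, rows, cols, filters)
Wc = tf.random.normal((3, 3, 3, 4))
Uc = tf.random.normal((3, 3, 4, 4))
conv_gate = tf.sigmoid(tf.nn.conv2d(x_map, Wc, 1, "SAME") +
                       tf.nn.conv2d(h_map, Uc, 1, "SAME"))

print(dense_gate.shape)  # (2, 8)       -- no spatial structure
print(conv_gate.shape)   # (2, 8, 8, 4) -- spatial structure preserved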
https://keras.io/api/layers/recurrent_layers/conv_lstm2d/
tf.keras.layers.ConvLSTM2D(
    filters,
    kernel_size,
    strides=(1, 1),
    padding="valid",
    data_format=None,
    dilation_rate=(1, 1),
    activation="tanh",
    recurrent_activation="hard_sigmoid",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    recurrent_initializer="orthogonal",
    bias_initializer="zeros",
    unit_forget_bias=True,
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    dropout=0.0,
    recurrent_dropout=0.0,
    **kwargs
)
Input shape: if data_format='channels_first', a 5D tensor with shape (samples, time, channels, rows, cols); if data_format='channels_last', a 5D tensor with shape (samples, time, rows, cols, channels).
Output shape: if return_sequences=True, a 5D tensor that keeps the time axis, e.g. (samples, time, new_rows, new_cols, filters) for channels_last; otherwise a 4D tensor for the last timestep only, e.g. (samples, new_rows, new_cols, filters). If return_state=True, a list containing the output plus the last hidden and cell states is returned.
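A quick shape check (our own snippet) makes the return_sequences behavior concrete:

import tensorflow as tf

# Dummy input: (samples, time, rows, cols, channels), channels_last.
x = tf.random.normal((2, 5, 32, 32, 3))

seq = tf.keras.layers.ConvLSTM2D(8, (3, 3), padding="same",
                                 return_sequences=True)(x)
last = tf.keras.layers.ConvLSTM2D(8, (3, 3), padding="same")(x)

print(seq.shape)   # (2, 5, 32, 32, 8) -- hidden state for every timestep
print(last.shape)  # (2, 32, 32, 8)    -- hidden state for the last timestep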
https://keras.io/examples/vision/conv_lstm/
# Construct the input layer with no definite frame size.
inp = layers.Input(shape=(None, *x_train.shape[2:]))

# We will construct 3 `ConvLSTM2D` layers with batch normalization,
# followed by a `Conv3D` layer for the spatiotemporal outputs.
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(5, 5),
    padding="same",
    return_sequences=True,
    activation="relu",
)(inp)
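For reference, these are the ConvLSTM gate equations that the code below implements. Note the Keras variant omits the peephole terms of the original Shi et al. paper; * denotes convolution, \circ the Hadamard product, and by default \sigma is the recurrent_activation (hard_sigmoid) while \tanh is the activation:

\begin{aligned}
i_t &= \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + b_f) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + b_o) \\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}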
Mapping these equations onto the code (the lines below are excerpted from the cell's call method, shown in full afterwards): when dropout is disabled, inputs_i, inputs_f, inputs_c and inputs_o are all equal to the input X, and likewise h_tm1_i, h_tm1_f, h_tm1_c and h_tm1_o are all equal to the previous hidden state h_tm1.
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
i = self.recurrent_activation(x_i + h_i)
f = self.recurrent_activation(x_f + h_f)
c = f * c_tm1 + i * self.activation(x_c + h_c)
o = self.recurrent_activation(x_o + h_o)
h = o * self.activation(c)
def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
        inputs_i = inputs * dp_mask[0]
        inputs_f = inputs * dp_mask[1]
        inputs_c = inputs * dp_mask[2]
        inputs_o = inputs * dp_mask[3]
    else:
        inputs_i = inputs
        inputs_f = inputs
        inputs_c = inputs
        inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
        h_tm1_i = h_tm1 * rec_dp_mask[0]
        h_tm1_f = h_tm1 * rec_dp_mask[1]
        h_tm1_c = h_tm1 * rec_dp_mask[2]
        h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
        h_tm1_i = h_tm1
        h_tm1_f = h_tm1
        h_tm1_c = h_tm1
        h_tm1_o = h_tm1

    # the four gate kernels are stacked along the last axis of the weights
    (kernel_i, kernel_f, kernel_c, kernel_o) = tf.split(
        self.kernel, 4, axis=self.rank + 1)
    (recurrent_kernel_i, recurrent_kernel_f, recurrent_kernel_c,
     recurrent_kernel_o) = tf.split(
         self.recurrent_kernel, 4, axis=self.rank + 1)

    if self.use_bias:
        bias_i, bias_f, bias_c, bias_o = tf.split(self.bias, 4)
    else:
        bias_i, bias_f, bias_c, bias_o = None, None, None, None

    # input convolutions (W_x? * X_t) for the four gates
    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    # recurrent convolutions (W_h? * H_{t-1}) for the four gates
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    # gate activations and the LSTM state updates
    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]
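A small sketch (shapes assumed for the 2D case, rank == 2) of why the split axis is self.rank + 1: the four gate kernels are stacked back to back along the last axis of the weight tensors.

import tensorflow as tf

filters, in_channels = 4, 3
# Input kernel: (kernel_h, kernel_w, in_channels, 4 * filters); for rank == 2
# the gate axis is rank + 1 == 3.
kernel = tf.zeros((3, 3, in_channels, 4 * filters))
recurrent_kernel = tf.zeros((3, 3, filters, 4 * filters))

k_i, k_f, k_c, k_o = tf.split(kernel, 4, axis=3)
rk_i, rk_f, rk_c, rk_o = tf.split(recurrent_kernel, 4, axis=3)

print(k_i.shape)   # (3, 3, 3, 4) -- one gate's worth of input filters
print(rk_i.shape)  # (3, 3, 4, 4) -- one gate's worth of recurrent filters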
@property
def _conv_func(self):
    if self.rank == 1:
        return backend.conv1d
    if self.rank == 2:
        return backend.conv2d
    if self.rank == 3:
        return backend.conv3d
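This rank dispatch is what lets the same cell back the ConvLSTM1D, ConvLSTM2D, and ConvLSTM3D layers. A quick check of the 1D and 3D variants (available in recent TF releases, roughly 2.6+; shapes are our own):

import tensorflow as tf

x1 = tf.random.normal((2, 5, 32, 3))        # (samples, time, steps, channels)
x3 = tf.random.normal((2, 5, 8, 8, 8, 3))   # (samples, time, d1, d2, d3, channels)

print(tf.keras.layers.ConvLSTM1D(4, 3, padding="same")(x1).shape)  # (2, 32, 4)
print(tf.keras.layers.ConvLSTM3D(4, 3, padding="same")(x3).shape)  # (2, 8, 8, 8, 4)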
def input_conv(self, x, w, b=None, padding='valid'):
    conv_out = self._conv_func(
        x,
        w,
        strides=self.strides,
        padding=padding,
        data_format=self.data_format,
        dilation_rate=self.dilation_rate)
    if b is not None:
        conv_out = backend.bias_add(conv_out, b, data_format=self.data_format)
    return conv_out
def recurrent_conv(self, x, w):
    # the recurrent convolution always uses stride 1 and 'same' padding, so
    # the hidden state keeps the same spatial size across timesteps
    strides = conv_utils.normalize_tuple(
        1, self.rank, 'strides', allow_zero=True)
    conv_out = self._conv_func(
        x, w, strides=strides, padding='same', data_format=self.data_format)
    return conv_out
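To tie everything together, here is a standalone sketch of one ConvLSTM step using plain TF ops (all names and shapes are our own; random kernels, no biases, channels_last, and plain sigmoid in place of the hard_sigmoid default for brevity). It mirrors the call method above: split the stacked kernels, convolve input and hidden state per gate, then apply the LSTM state updates.

import tensorflow as tf

batch, rows, cols, in_ch, filters = 2, 8, 8, 3, 4

x = tf.random.normal((batch, rows, cols, in_ch))     # X_t
h_tm1 = tf.zeros((batch, rows, cols, filters))       # H_{t-1}
c_tm1 = tf.zeros((batch, rows, cols, filters))       # C_{t-1}

kernel = tf.random.normal((3, 3, in_ch, 4 * filters))
recurrent_kernel = tf.random.normal((3, 3, filters, 4 * filters))

k_i, k_f, k_c, k_o = tf.split(kernel, 4, axis=3)
rk_i, rk_f, rk_c, rk_o = tf.split(recurrent_kernel, 4, axis=3)

def conv(t, w):
    return tf.nn.conv2d(t, w, strides=1, padding="SAME")

i = tf.sigmoid(conv(x, k_i) + conv(h_tm1, rk_i))                # input gate
f = tf.sigmoid(conv(x, k_f) + conv(h_tm1, rk_f))                # forget gate
c = f * c_tm1 + i * tf.tanh(conv(x, k_c) + conv(h_tm1, rk_c))   # new cell state
o = tf.sigmoid(conv(x, k_o) + conv(h_tm1, rk_o))                # output gate
h = o * tf.tanh(c)                                              # new hidden state

print(h.shape, c.shape)  # (2, 8, 8, 4) (2, 8, 8, 4)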