前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >轻松理解Keras回调

轻松理解Keras回调

作者头像
云水木石
发布2019-08-09 23:29:12
1.8K0
发布2019-08-09 23:29:12
举报

随着计算机处理能力的提高,人工智能模型的训练时间并没有缩短,主要是人们对模型精确度要求越来越高。为了提升模型精度,人们设计出越来越复杂的深度神经网络模型,喂入越来越海量的数据,导致训练模型也耗时越来越长。这就如同PC产业,虽然CPU遵从摩尔定律,速度越来越快,但由于软件复杂度的提升,我们并没有感觉计算机运行速度有显著提升,反而陷入需要不断升级电脑硬件的怪圈。

不知道大家有没有这种经历,准备数据,选择好模型,启动训练,训练了一天之后,却发现效果不理想。这个时候怎么办?通常调整几个超参数,重新训练,这样折腾几个来回,可能一个星期,甚至一个月的时间就过去了。如果缺少反馈,训练深度学习模型就如同开车没有刹车一样。

这个时候,就需要了解训练中的内部状态以及模型的一些信息,在Keras框架中,回调就能起这样的作用。在本文中,我将介绍如何使用Keras回调(如ModelCheckpoint和EarlyStopping)监控和改进深度学习模型。

什么是回调

Keras文档给出的定义为:

回调是在训练过程的特定阶段调用的一组函数,可以使用回调来获取训练期间内部状态和模型统计信息的视图。

你可以传递一个回调列表,同时获取多种训练期间的内部状态,keras框架将在训练的各个阶段回调相关方法。如果你希望在每个训练的epoch自动执行某些任务,比如保存模型检查点(checkpoint),或者希望控制训练过程,比如达到一定的准确度时停止训练,可以定义回调来做到。

keras内置的回调很多,我们也可以自行实现回调类,下面先深入探讨一些比较常用的回调函数,然后再谈谈如何自定义回调。

EarlyStopping

从字面上理解, EarlyStopping 就是提前终止训练,主要是为了防止过拟合。过拟合是机器学习从业者的噩梦,简单说,就是在训练数据集上精度很高,但在测试数据集上精度很低。解决过拟合有多种手段,有时还需要多种手段并用,其中一种方法是尽早终止训练过程。EarlyStopping 函数有好几种度量参数,通过修改这些参数,可以控制合适的时机停止训练过程。下面是一些相关度量参数:

  • monitor: 监控的度量指标,比如: acc, val_acc, loss和val_loss等
  • min_delta: 监控值的最小变化。 例如,min_delta = 1表示如果监视值的绝对值变化小于1,则将停止训练过程
  • patience: 没有改善的epoch数,如果过了数个epoch之后结果没有改善,训练将停止
  • restore_best_weights: 如果要在停止后保存最佳权重,请将此参数设置为True

下面的代码示例将定义一个跟踪val_loss值的EarlyStopping函数,如果在3个epoch后val_loss没有变化,则停止训练,并在训练停止后保存最佳权重:

代码语言:javascript
复制
from keras.callbacks import EarlyStopping
earlystop = EarlyStopping(monitor = 'val_loss',
                          min_delta = 0,
                          patience = 3,
                          verbose = 1,
                          restore_best_weights = True)
ModelCheckpoint

此回调用于在训练周期中保存模型检查点。保存检查点的作用在于保存训练中间的模型,下次在训练时,可以加载模型,而无需重新训练,减少训练时间。它有以一些相关参数:

  • filepath: 要保存模型的文件路径
  • monitor: 监控的度量指标,比如: acc, val_acc, loss和val_loss等
  • save_best_only: 如果您不想最新的最佳模型被覆盖,请将此值设置为True
  • save_weights_only: 如果设为True,将只保存模型权重
  • mode: auto,min或max。 例如,如果监控的度量指标是val_loss,并且想要最小化它,则设置mode =’min’。
  • period: 检查点之间的间隔(epoch数)。

示例:

代码语言:javascript
复制
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath,
                             monitor='val_loss',
                             mode='min',
                             save_best_only=True,
                             verbose=1)
LearningRateScheduler

在深度学习中,学习率的选择也是一件让人头疼的事情,值选择小了,可能会收敛缓慢,值选大了,可能会导致震荡,无法到达局部最优点。后来专家们设计出一种自适应的学习率,比如在训练开始阶段,选择比较大的学习率值,加速收敛,训练一段时间之后,选择小的学习率值,防止震荡。LearningRateScheduler 用于定义学习率的变化策略,参数如下:

  • schedule: 一个函数,以epoch数(整数,从0开始计数)和当前学习速率,作为输入,返回一个新的学习速率作为输出(浮点数)。
  • verbose: 0: 静默模式,1: 详细输出信息。

示例代码:

代码语言:javascript
复制
from keras.callbacks import LearningRateScheduler
scheduler = LearningRateScheduler(lambda x: 1. / (1. + x), verbose=0)
TensorBoard

TensorBoard是TensorFlow提供的可视化工具。

该回调写入可用于TensorBoard的日志,通过TensorBoard,可视化训练和测试度量的动态图形,以及模型中不同图层的激活直方图。

我们可以从命令行启动TensorBoard:

代码语言:javascript
复制
tensorboard --logdir = / full_path_to_your_logs

该回调的参数比较多,大部分情况下我们只用log_dir这个参数指定log存放的目录,其它参数并不需要了解,使用默认值即可:

代码语言:javascript
复制
from keras.callbacks import TensorBoard
tensorboard = TensorBoard(log_dir="logs/{}".format(time()))
自定义回调

创建自定义回调非常容易,通过扩展基类keras.callbacks.Callback来实现。回调可以通过类属性self.model访问其关联的模型。

下面是一个简单的示例,在训练期间保存每个epoch的损失列表:

代码语言:javascript
复制
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])

print(history.losses)

输出结果:

代码语言:javascript
复制
[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
小结

限于篇幅原因,本文只是介绍了Keras中常用的回调,通过这些示例,想必你已经理解了Keras中的回调,如果你希望详细了解keras中更多的内置回调,可以访问keras文档:

https://keras.io/callbacks/

参考:
  1. Keras Callbacks Explained In Three Minutes
  2. Usage of callbacks
  3. Monitor progress of your Keras based neural network using Tensorboard
本文参与?腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-05,如有侵权请联系?cloudcommunity@tencent.com 删除

本文分享自 云水木石 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 什么是回调
  • EarlyStopping
  • ModelCheckpoint
  • LearningRateScheduler
  • TensorBoard
  • 自定义回调
  • 小结
  • 参考:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com