首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

How to Retrain Inception's Final Layer for New Categories(如何为新类别重新训练盗梦空间的最后一层)

现代物体识别模型有数百万个参数,可能需要数周才能完成训练。迁移学习是一种技巧,通过为ImageNet等一系列类别提供完全培训的模型,并从现有的新类权重中重新训练,从而缩短了大量工作量。在这个例子中,我们将从头开始重新训练最后一层,同时保留所有其他层。有关该方法的更多信息,请参阅Decaf的这篇论文

虽然它不如完整运行的训练,但这对于许多应用程序来说是非常有效的,并且可以在笔记本电脑上短短三十分钟内运行,而无需GPU。本教程将向您展示如何在自己的图像上运行示例脚本,并解释一些可帮助控制培训过程的选项。

注意:本教程的这个版本主要使用bazel。一个bazel免费版本也可以作为codelab提供。

Flowers 训练

在开始任何培训之前,您需要一组图像来向网络传授您想要识别的新课程。后面的章节解释了如何准备自己的图片,但为了方便起见,我们创建了花卉照片的存档,以供初始使用。要获取花卉照片集,请运行以下命令:

代码语言:javascript
复制
cd ~
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz

一旦你有了图像,你可以从你的TensorFlow源代码目录的根目录下建立这样一个修复器:

代码语言:javascript
复制
bazel build tensorflow/examples/image_retraining:retrain

如果您的计算机支持AVX指令集(在过去几年中生产的x86 CPU中常见),您可以通过构建该架构来提高再训练的运行速度,如下所示(在选择适当的选项后configure):

代码语言:javascript
复制
bazel build --config opt tensorflow/examples/image_retraining:retrain

然后可以像这样运行修复器:

代码语言:javascript
复制
bazel-bin/tensorflow/examples/image_retraining/retrain --image_dir ~/flower_photos

脚本加载预先训练的Inception v3模型,删除旧的顶层,并在您下载的花卉照片上训练一个新的顶层。全部网络都没有经过培训的原始ImageNet类别中的花卉种类。转移学习的神奇之处在于已经被训练以区分一些对象的较低层可以被重复用于许多识别任务,而没有任何改变。

瓶颈

该脚本可能需要30分钟或更长时间才能完成,具体取决于机器的速度。第一阶段分析磁盘上的所有图像并计算每个图像的瓶颈值。“瓶颈”是一个非正式的术语,我们经常在最后一个输出层之前使用该层,实际进行分类。这个倒数第二层已经被训练输出一组值,这些值足以让分类器用来区分它被要求识别的所有类。这意味着它必须是对图像有意义和有影响力的总结,因为它必须包含足够的信息才能为分类器在很小的一组值上作出好的选择。

由于每个图像在训练过程中都会重复使用多次,计算每个瓶颈需要花费大量时间,因此它可以加快速度,将这些瓶颈值缓存在磁盘上,因此无需重复计算。默认情况下,它们存储在/tmp/bottleneck目录中,如果您重新运行脚本,它们将被重用,因此您不必再等待此部分。

训练

一旦瓶颈完成,网络顶层的实际培训就开始了。您会看到一系列步骤输出,每个输出都显示训练准确性,验证准确性和交叉熵。训练准确性显示当前训练批次中使用的图像的百分比是否标有正确的分类。验证的准确性是从不同集合中随机选择的一组图像的精度。关键的区别在于,训练的准确性基于网络能够学习的图像,因此网络可以适应训练数据中的噪声。衡量网络性能的一个真正衡量标准是衡量其在训练数据中未包含的数据集上的表现 - 这是通过验证准确度来衡量的。如果训练准确度高但验证准确度仍然较低,那意味着网络过度拟合并记忆训练图像中的特定功能,这些功能通常不会有帮助。交叉熵是一种损失函数,可以让我们看到学习过程的进展情况。训练的目标是让损失尽可能小,因此您可以通过关注损失是否持续下降趋势来判断学习是否奏效,而忽略短期噪音。

默认情况下,此脚本将运行4,000个训练步骤。每个步骤从训练集中随机选择10幅图像,从高速缓存中找出它们的瓶颈,并将它们送入最终图层进行预测。然后将这些预测与实际标签进行比较,以通过反向传播过程更新最终图层的权重。随着过程的继续,您应该看到所报告的准确性提高,并且在所有步骤完成后,将对一组图像进行最终测试准确性评估,并将其与训练和验证图片分开保存。此测试评估是对训练模型如何在分类任务上执行的最佳估计。您应该看到90%到95%之间的准确度值,但由于训练过程中存在随机性,所以确切的值会因运行而异。

用TensorBoard可视化再培训

该脚本包含TensorBoard摘要,使其更易于理解,调试和优化再训练。例如,您可以将图表和统计数据可视化,例如训练期间权重或准确度如何变化。

要启动TensorBoard,请在再训练期间或之后运行此命令:

代码语言:javascript
复制
tensorboard --logdir /tmp/retrain_logs

一旦TensorBoard正在运行,请您的网络浏览器导航localhost:6006以查看TensorBoard。

脚本将/tmp/retrain_logs默认记录TensorBoard摘要。您可以使用该--summaries_dir标志更改目录。

TensorBoard的GitHub上对TensorBoard使用情况,包括提示和技巧,和调试信息的更多的信息。

使用再训练模式

该脚本将写出Inception v3网络的一个版本,最后一层重新训练到/tmp/output_graph.pb的类别,以及一个包含/tmp/output_labels.txt标签的文本文件。这些都是C ++和Python图像分类示例可以读入的格式,因此您可以立即开始使用新模型。由于您已替换顶层,因此您需要在脚本中指定新名称,例如,--output_layer=final_result如果您使用label_image,则使用该标记。

以下是如何使用您的再培训图创建和运行label_image示例的示例:

代码语言:javascript
复制
bazel build tensorflow/examples/image_retraining:label_image && \
bazel-bin/tensorflow/examples/image_retraining/label_image \
--graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
--output_layer=final_result:0 \
--image=$HOME/flower_photos/daisy/21652746_cc379e0eea_m.jpg

你应该看到一个花标签列表,在大多数情况下菊花在上面(尽管每个重新训练的模型可能略有不同)。您可以--image使用自己的图像替换参数以尝试这些参数,并将C ++代码作为模板与您自己的应用程序集成。

如果你想在你自己的Python程序中使用再培训模型,那么上面的label_image脚本是一个合理的起点。

如果您发现默认的Inception v3模型对于您的应用程序太大或太慢,请查看下面的“其他模型架构”部分,以了解如何加速和缩小您的网络。

根据自己的类别进行培训

如果你已经设法让这个脚本处理花图像,你可以开始寻找教它识别你关心的类别。从理论上讲,你需要做的就是将它指向一组子文件夹,每个子文件夹以你的一个类别命名,并且只包含该类别的图像。如果你这样做,并传递子目录的根文件夹作为参数--image_dir,那么脚本应像对花一样进行训练。

以下是鲜花档案的文件夹结构,以便为您提供脚本所寻找布局的示例:

在实践中,可能需要一些工作来获得所需的准确性。我会尽力引导您解决下面可能遇到的一些常见问题。

创建一组训练图像

首先要看看你收集的图像,因为我们通过培训看到的最常见的问题来自所馈入的数据。

为使训练工作顺利,您应该至少收集一百张您想要识别的各种物体的照片。你可以收集的越多,训练好的模型的准确性就越好。您还需要确保照片能够很好地表现您的应用程序实际会遇到的情况。例如,如果您将所有照片放在室内的空白墙壁上,并且用户尝试识别室外的物体,则在部署时可能看不到良好的效果。

另一个要避免的陷阱是, 学习过程将会在任何被标记的图像有共同之处的东西上找到, 如果你不小心, 可能是一些没用的东西。例如,如果您在蓝色房间中拍摄一种物体,而另一种物体拍摄为绿色,则模型最终将根据背景颜色进行预测,而不是您实际关心的物体的特征。为避免这种情况,请尝试在各种情况下尽可能在不同时间和不同设备上拍摄照片。如果你想知道更多关于这个问题的信息,你可以阅读经典的(也可能是伪装的)坦克识别问题

您可能还想考虑您使用的类别。将大量不同的物理形式分成较小的物体可能是值得的,这些小物体在视觉上更加独特。例如,可以使用“汽车”,“摩托车”和“卡车”来代替“车辆”。还值得考虑一下你是否有“封闭世界”或“开放世界”问题。在封闭的世界里,你唯一要求分类的东西就是你所了解的对象类。这可能适用于植物识别应用程序,你知道用户可能正在拍摄一朵花,所以你只需要确定哪些物种。相比之下,漫游机器人可能会通过其摄像头在世界各地漫游时看到各种不同的东西。在那种情况下,希望分类器报告它是否不确定它看到的是什么。这可能很难做到,但是如果您经常收集大量没有相关对象的典型“背景”照片,则可以将它们添加到图像文件夹中额外的“未知”类。

这也值得检查,以确保所有的图像都标有正确的标签。通常,用户生成的标签对于我们的目的不可靠,例如,将#daisy用于名为Daisy的人的照片。如果你仔细检查你的图片并清除任何错误,它可以为你的整体准确性做出奇迹。

训练步骤

如果您对图像满意,可以通过改变学习过程的细节来改善您的结果。最简单的尝试是--how_many_training_steps。默认值为4,000,但如果将其增加到8,000,则训练时间会延长两倍。准确度提高的速度减慢了你训练的时间,并且在某些时候会完全停止,但是你可以试验看看你何时达到了你的模型的极限。

畸变

改善图像训练结果的一种常见方式是以随机方式对训练输入进行变形,裁剪或增亮。这有利于扩大训练数据的有效大小,这要归功于相同图像的所有可能的变化,并且倾向于帮助网络学会应对在分类器的实际使用中将发生的所有失真。在我们的脚本中实现这些扭曲的最大缺点是瓶颈缓存不再有用,因为输入图像永远不会重复使用。这意味着培训过程需要更长的时间,所以我建议您尝试这种方式,以便在您有一个合理满意的模型后对模型进行微调。

您可以通过将启用这些扭曲传递指令--random_crop--random_scale--random_brightness给脚本。这些都是控制每个图像应用了多少扭曲的百分比值。对每个人开始5或10的值是合理的,然后试验看看他们哪些人可以帮助你的应用程序。--flip_left_right将水平地随机镜像一半图像,只要这些反转可能发生在您的应用程序中,这就很有意义。例如,如果你试图识别字母,这不是一个好主意,因为翻转它们会破坏它们的意思。

超参数

还有其他几个参数可以尝试调整,以查看它们是否有助于您的结果。--learning_rate控件更新到最后一层的训练过程中的大小。直观地说,如果这比学习要花费的时间更长,但最终可以帮助整体精度。但情况并非总是如此,所以您需要仔细试验以了解适合您的情况。--train_batch_size控制在一个训练阶段多少图像进行检查,而且由于每批次应用学习率,你需要减少它,如果你有更大的批量获得相同的整体效果。

训练,验证和测试集

当你将脚本指向一个图像文件夹时,脚本在底层做的事情之一是将它们分成三个不同的集合。最大的通常是训练集,它是训练期间输入网络的所有图像,结果用于更新模型的权重。您可能想知道为什么我们不使用所有图像进行培训?当我们进行机器学习时,一个很大的潜在问题是我们的模型可能只是记住训练图像的不相关细节,以提出正确的答案。例如,您可以想象一个网络在每张照片的背景中记住一个模式,并使用它来将标签与对象进行匹配。它可以在训练过程中看到的所有图像上产生良好的效果,但是在新图像上会失败,

这个问题被称为过度拟合,为了避免它,我们将一些数据保留在训练过程之外,以便模型不能记住它们。然后,我们使用这些图像作为检查来确保过度拟合不会发生,因为如果我们看到它们有很好的准确性,这是一个很好的迹象表明网络不是过度拟合。通常的做法是将80%的图像放入主要的训练集中,保留10%作为训练期间的频繁验证运行,然后最终使用10%作为测试集来预测实时数据,分类器的全局表现。这些比率可以使用--testing_percentage--validation_percentage标志进行控制。一般来说,您应该能够将这些值保留为默认值,因为您通常无法找到训练的优势来调整它们。

请注意,该脚本使用图像文件名(而不是完全随机的函数)在训练,验证和测试集之间划分图像。这样做是为了确保图像不会在不同运行中的训练和测试集之间移动,因为如果用于训练模型的图像随后用于验证集,那么这可能是个问题。

您可能会注意到验证准确性在迭代中波动。这种波动的很大一部分源自这样的事实,即为每个验证准确性测量选择验证集合的随机子集。以一些训练时间的增加为代价,可以大大降低波动,通过选择--validation_batch_size=-1使用整个验证集来进行每个精度计算。

一旦训练完成,您可能会发现在测试集中检查错误分类的图像是很有有意义的。这可以通过添加标志--print_misclassified_test_images来完成。这可能有助于您了解模型中哪些类型的图像最容易混淆,哪些类别最难区分。例如,您可能会发现某个特定类别的某个子类型或某种不寻常的照片角度特别难以识别,这可能会鼓励您添加更多该子类型的训练图像。通常,检查错误分类的图像还可能会指出输入数据集中的错误,如错误标记,低质量或模糊的图像。但是,通常应避免在测试集中修正个别错误,因为它们可能仅仅反映(更大)训练集中的更普遍的问题。

其他模型结构

默认情况下,脚本使用Inception v3模型架构的预训练版本。这是一个开始的好地方,因为它提供了高精度的结果,但是如果您打算在移动设备或其他资源受限的环境中部署模型,则可能需要对小文件大小或更快的速度进行一点精确性权衡。为了解决这个问题,rebin.py脚本Mobilenet架构上支持32种不同的变体。

这些与Inception v3相比精确度稍低,但可以生成更小的文件大小(小于1兆字节),并且运行速度可以快很多倍。要用这些模型之一进行训练,请输入--architecture标志,例如:

代码语言:javascript
复制
python tensorflow/examples/image_retraining/retrain.py \
    --image_dir ~/flower_photos --architecture mobilenet_0.25_128_quantized

这将创建一个941KB的模型文件/tmp/output_graph.pb,其中全部Mobilenet的参数的25%,采用128x128大小的输入图像,并在磁盘上将其权重量化为8位。您可以选择'1.0','0.75','0.50'或'0.25'来控制权重参数的数量,因此文件大小(以及某种程度上的速度),'224','192',' 160'或128'作为输入图像尺寸,较小的尺寸可以提供更快的速度,并在结尾处有一个可选的'_quantized'来指示文件是否应包含8位或32位浮点权重。

速度和尺寸的优势当然会损失精度,但对于许多目的来说,这并不重要。他们也可以通过改进的训练数据有所抵消。例如,即使使用上述的0.25 / 128 /量化图,变形训练也可以使花数据集的准确度达到80%以上。

如果您打算在label_image或您自己的程序中使用Mobilenet模型,则需要将转换为浮点范围的指定大小的图像导入到“输入”张量中。通常,24位图像的范围是0,255,您必须将它们转换为模型预期的-1,1浮点范围(image - 128.)/128.

扫码关注腾讯云开发者

领取腾讯云代金券

http://www.vxiaotou.com