首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Custom Data Readers(自定义数据读取器)

先决条件:

  • 熟悉C ++。

我们将支持文件格式的任务分成两部分:

  • 文件格式:我们使用Reader Op 从文件中读取记录(可以是任何字符串)。
  • 记录格式:我们使用解码器或解析Ops将一个字符串记录转换为TensorFlow可用的张量。

为文件格式编写Reader

Reader是从文件读取记录的东西。TensorFlow中已经内置了一些Reader Ops的例子:

你可以看到这些都暴露了相同的接口,唯一的区别在于它们的构造函数。最重要的方法是read。它需要一个队列参数,这是它获取文件名以从需要时读取的文件名(例如,当readop首次运行时,或前read一次从文件读取最后一个记录时)。它产生两个标量张量:一个字符串键和一个字符串值。

要创建一个新的读者SomeReader,你需要:

1. 在C ++中,定义一个tensorflow::ReaderBase被调用的子类SomeReader

2. 在C ++中,用名称注册一个新的读取器操作系统和内核"SomeReader"

3. 在Python中,定义一个tf.ReaderBase被调用的子类SomeReader

你可以把所有的C ++代码放在一个文件tensorflow/core/user_ops/some_reader_op.cc中。读取文件的代码将存放在C ++ ReaderBase类的后代中,C ++ 类定义在后者中tensorflow/core/kernels/reader_base.h。您将需要实施以下方法:

  • OnWorkStartedLocked:打开下一个文件
  • ReadLocked:读取记录或报告EOF /错误
  • OnWorkFinishedLocked:关闭当前文件,并
  • ResetLocked:在例如错误之后得到干净的平板

这些方法的名称以“Locked”结尾,因为ReaderBase在调用这些方法之前确保获得互斥体,所以您通常不必担心线程安全性(尽管只保护类的成员,而不是全局状态) 。

对于OnWorkStartedLocked要打开的文件的名称是该current_work()方法返回的值。ReadLocked有这样的签名:

代码语言:javascript
复制
Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)

如果ReadLocked成功从文件中读取记录,则应填写:

  • *key:带有记录的标识符,人可以用来再次查找该记录。你可以包含文件名current_work(),并附加一个记录号码或其他。
  • *value:与记录的内容。
  • *produced:设为true

如果您点击文件末尾(EOF),请设置*at_endtrue。无论哪种情况,都会返回Status::OK()。如果出现错误,只需使用其中一个辅助函数即可返回它,tensorflow/core/lib/core/errors.h而无需修改任何参数。

接下来,您将创建实际的Reader操作。如果您熟悉添加操作方法,这将有所帮助。主要步骤是:

  • 注册操作。
  • 定义并注册一个OpKernel

要注册该操作,您将使用在中REGISTER_OP定义的呼叫tensorflow/core/framework/op.h。读者操作系统从不接受任何输入,并且始终只有一个带有类型的输出resource。他们应该有字符串containershared_nameattrs。您可以选择定义额外的attrs进行配置或在文档中包含一个Doc。例如,请参阅tensorflow/core/ops/io_ops.cc:例如:

代码语言:javascript
复制
#include "tensorflow/core/framework/op.h"

REGISTER_OP("TextLineReader")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");

要定义一个OpKernel,读者可以使用降序的快捷方式ReaderOpKernel,定义tensorflow/core/framework/reader_op_kernel.h和实现调用的构造函数SetReaderFactory。定义你的课程后,你需要使用注册REGISTER_KERNEL_BUILDER(...)。没有attrs的例子:

代码语言:javascript
复制
#include "tensorflow/core/framework/reader_op_kernel.h"

class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};

REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);

有attrs的一个例子:

代码语言:javascript
复制
#include "tensorflow/core/framework/reader_op_kernel.h"

class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};

REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);

最后一步是添加Python包装器。你可以通过编译一个动态库来实现,或者如果你是从源代码构建TensorFlow,添加到user_ops.py。对于后者,您将导入tensorflow.python.ops.io_opstensorflow/python/user_ops/user_ops.py添加的后裔io_ops.ReaderBase

代码语言:javascript
复制
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops

class SomeReader(io_ops.ReaderBase):

    def __init__(self, name=None):
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)

ops.NotDifferentiable("SomeReader")

你可以看到一些例子tensorflow/python/ops/io_ops.py

为记录格式编写操作

通常这是一个普通的操作,它将标量字符串记录作为输入,因此按照说明添加操作。您可以选择使用标量字符串键作为输入,并将其包含在报告格式不正确的数据的错误消息中。这样用户可以更轻松地追踪坏数据的来源。

可用于解码记录的Ops示例:

请注意,使用多个Ops来解码特定的记录格式会很有用。例如,可能必须保存为一个字符串的图像一个tf.train.Example协议缓冲器。根据该图像的格式,你可能会采取相应的输出从tf.parse_single_exampleOP和呼叫tf.image.decode_jpegtf.image.decode_pngtf.decode_raw。采用输出tf.decode_raw和使用tf.slice以及tf.reshape提取碎片是很常见的。

扫码关注腾讯云开发者

领取腾讯云代金券

http://www.vxiaotou.com