Operation Semantics(操作语义)
以下描述了ComputationBuilder
界面中定义的操作的语义。通常,这些操作一对一映射到RPC接口中定义的操作xla_data.proto
。
关于术语的说明:XLA处理的通用数据类型是一个保存某种统一类型元素(例如32位浮点数)的N维数组。在整个文档中,数组用于表示任意维数组。为了方便起见,特例中有更具体和熟悉的名称; 例如矢量是一维数组,矩阵是二维数组。
Broadcast
通过复制数组中的数据向数组中添加维度。
Broadcast(operand, broadcast_sizes)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | The array to duplicate |
broadcast_sizes | ArraySlice<int64> | The sizes of the new dimensions |
新维度插入在左侧,即如果broadcast_sizes
具有值{a0, ..., aN}
并且操作数形状具有尺寸,{b0, ..., bM}
则输出的形状具有维度{a0, ..., aN, b0, ..., bM}
。
新维度索引到操作数的副本中,即
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
例如,如果operand
是f32
具有值的标量2.0f
,并且broadcast_sizes
是{2, 3}
,则结果将是具有形状的数组,结果中的f32[2, 3]
所有值都将是2.0f
。
调用
调用具有给定参数的计算。
Call(computation, args...)
参数 | 类型 | 含义 |
---|---|---|
computation | Computation | computation of type T_0, T_1, ..., T_N -> S with N parameters of arbitrary type |
args | sequence of N ComputationDataHandles | N arguments of arbitrary type |
参数的 arity
和类型必须与计算的参量相匹配。它可以没有参数。
Clamp
将操作数限制在最小值和最大值之间的范围内。
Clamp(computation, args...)
参数 | 类型 | 含义 |
---|---|---|
computation | Computation | computation of type T_0, T_1, ..., T_N -> S with N parameters of arbitrary type |
operand | ComputationDataHandle | array of type T |
min | ComputationDataHandle | array of type T |
max | ComputationDataHandle | array of type T |
给定操作数以及最小值和最大值,如果操作数处于最小值和最大值之间的范围内,则返回操作数;否则返回操作数低于此范围时的最小值或操作数高于此范围时的最大值。那就是,clamp(x, a, b) = max(min(x, a), b)
。
所有三个阵列必须是相同的形状。或者,作为广播的限制形式,min
和/或max
可以是类型的标量T
。
min
和max
实例:
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(operand, min, max) = s32[3]{0, 5, 6};
Collapse
将数组的维度折叠为一个维度。
Collapse(operand, dimensions)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | array of type T |
dimensions | int64 vector | in-order, consecutive subset of T's dimensions. |
折叠用一个维度替换操作数维度的给定子集。输入参数是T类型的任意数组和维度索引的编译时常量向量。维度索引必须是有序(从低到高维数),是T维连续的子集。因此,{0,1,2},{0,1}或{1,2}都是有效的维度集合,但{1,0}或{0,2}却不是。它们被替换为一个新的维度,它们在维度序列中的位置与它们替换的位置相同,新的维度大小等于原始维度大小的乘积。最小维度数dimensions
是循环嵌套中最慢变化的维度(最主要),其折叠这些维度,并且最高维度数量变化最快(最小)。看到了tf.reshape
如果需要更多的常规折叠顺序,则可以使用它
例如,让v是一个由24个元素组成的数组:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17}},
{ {20, 21, 22}, {25, 26, 27}},
{ {30, 31, 32}, {35, 36, 37}},
{ {40, 41, 42}, {45, 46, 47}}};
// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};
// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47}};
// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47}};
Concatenate
连接从多个数组操作数组成一个数组。该数组与每个输入数组操作数具有相同的级别(它们的级别必须相同),并按照它们指定的顺序包含参数。
Concatenate(operands..., dimension)
参数 | 类型 | 语义 |
---|---|---|
operands | N ComputationDataHandle的序列 | 尺寸为L0,L1,...的N型数组T需要N> = 1。 |
dimension | Int64 | 间隔[0,N)中的一个值,用于命名要在操作数之间连接的维度。 |
除了dimension
所有尺寸必须相同。这是因为XLA不支持“不规则”数组请注意,rank-0值不能连接(因为不可能命名串联发生的维度)。
一维示例:
Concat({ {2, 3}, {4, 5}, {6, 7}}, 0)
>>> {2, 3, 4, 5, 6, 7}
二维示例:
let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}
图:
ConvertElementType
与static_cast
C ++中的元素类似,执行从数据形状到目标形状的按元素转换操作。尺寸必须匹配,并且转换是元素明智的; 例如s32
元件成为f32
经由元件s32
-到- f32
转换例程。
ConvertElementType(operand, new_element_type)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | array of type T with dims D |
new_element_type | PrimitiveType | type U |
如果操作数和目标形状的维度不匹配,或者请求了无效的转换(例如到/从一个元组),则会产生错误。
诸如T=s32
到U=f32
的转换将执行规范化的int-to-float转换例程,例如round-to-nearest-even。
注意:精确的浮动到整数和反向转换目前没有指定,但可能成为未来转换操作的附加参数。并非所有可能的转换都已针对所有目标实施。
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
Conv (卷积)
作为ConvWithGeneralPadding,但是填充以简短方式指定为SAME或VALID。SAME填充填充输入(lhs
)为零,以便输出具有与输入相同的形状,但不考虑跨度。VALID填充只是表示没有填充。
ConvWithGeneralPadding (卷积)
计算神经网络中使用的类型的卷积。这里,卷积可以被认为是移动穿过n维基本区域的n维窗口,并且针对窗口的每个可能的位置执行计算。
参数 | 类型 | 含义 |
---|---|---|
lhs | ComputationDataHandle | rank n+2 array of inputs |
rhs | ComputationDataHandle | rank n+2 array of kernel weights |
window_strides | ArraySlice<int64> | n-d array of kernel strides |
padding | ArraySlice<pair<int64, int64>> | n-d array of (low, high) padding |
lhs_dilation | ArraySlice<int64> | n-d lhs dilation factor array |
rhs_dilation | ArraySlice<int64> | n-d rhs dilation factor array |
设n是空间维数。所述lhs
参数是一个等级n + 2阵列描述底部区域。这被称为输入,尽管当然rhs也是输入。在神经网络中,这些是输入激活。n + 2维度按以下顺序排列:
batch
:这个维度中的每个坐标代表一个独立的输入,对其进行卷积运算。
z/depth/features
:基础区域中的每个(y,x)位置都有一个与其关联的向量,它将进入此维度。
spatial_dims
:描述n定义窗口移动的基本区域的空间维度。
所述rhs
参数是一个等级n + 2阵列描述卷积滤波器/内核/窗口。尺寸按以下顺序排列:
output-z
:z输出的维度。
input-z
:此维度的大小应等于z lhs中维度的大小。
spatial_dims
: 描述n定义穿过基础区域移动的nd窗口的空间尺寸。
window_strides
参数指定空间维度中的卷积窗口的步幅。例如,如果第一个空间维度的步幅为3,那么窗口只能放置在第一个空间索引可以被3整除的坐标上。
padding
参数指定要应用于基础区域的零填充量。填充量可以是负值 - 负填充的绝对值表示在卷积之前要从指定维度移除的元素的数量。padding[0]
指定维度的填充y
并padding[1]
指定维度的填充x
。每一对都有低填充作为第一个元素,高填充作为第二个元素。在较低索引的方向上应用低填充,而在较高索引的方向上应用高填充。例如,如果padding[1]
是(2,3)
那么在第二个空间维度中,左侧会有2个零填充,右侧会有3个零。使用填充相当于lhs
在卷积之前将这些相同的零值插入输入()。
lhs_dilation
和rhs_dilation
参数指定扩张因子被分别施加到LHS
和rhs
,在每个空间维度。如果空间维度中的膨胀因子为d,则在该维度中的每个条目之间隐含地放置d-1个孔,从而增加该数组的大小。这些孔填充了一个无操作值,这对于卷积来说意味着零。
rhs的扩张也被称为无限卷积。有关更多详细信息,请参阅tf.nn.atrous_conv2d
。lhs的扩张也称为解卷积。
输出形状具有这些维度,按此顺序排列:
batch
:与batch
输入(lhs
)上的大小相同。
z
:与output-z
内核(rhs)上的大小相同。
spatial_dims
: 卷积窗口每个有效位置的一个值。
卷积窗口的有效位置由填充后的基础区域的步幅和大小决定。
来形容卷积,考虑二维卷积,并挑选一些固定的batch
,z
,y
,x
坐标输出。然后(y,x)
是基部区域内窗口角落的位置(例如,左上角,取决于您如何解释空间尺寸)。我们现在有一个2d窗口,取自底部区域,每个2d点与1d矢量相关联,所以我们得到一个3d盒子。从卷积核心,因为我们固定输出坐标z
,我们也有一个3d盒子。这两个盒子具有相同的维度,所以我们可以将两个盒子之间的元素明智产品的总和(类似于点积)。那是产值。
注意,如果output-z
是实例5,则窗口的每个位置在输出中产生5个值到输出的z
维度中。这些值在使用卷积核的哪个部分不同 - 每个output-z
坐标都有一个单独的3d值框。所以你可以把它看作5个独立的卷积,每个卷积使用不同的滤波器。
这是用于填充和跨步的2d卷积的伪代码:
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
CrossReplicaSum
计算副本之间的总和。
CrossReplicaSum(operand)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | Array to sum across replicas. |
输出形状与输入形状相同。例如,如果有两个副本和操作数具有值(1.0, 2.5)
和(3.0, 5.1)
分别在两个复制品,然后从该运算的输出值将是(4.0, 7.6)
在两个副本。
计算CrossReplicaSum的结果需要每个副本都有一个输入,因此如果一个副本比另一个副本执行更多次的CrossReplicaSum节点,则前一个副本将永远等待。由于副本全部运行相同的程序,因此没有太多方法可以实现,但是当while循环的条件取决于来自进给的数据并且被infed的数据导致while循环迭代多次时一个副本比另一个副本。
CustomCall
在计算中调用用户提供的函数。
CustomCall(target_name, args..., shape)
参数 | 类型 | 语义 |
---|---|---|
TARGET_NAME | 串 | 功能的名称。将发出一个调用指令,以此符号名称为目标。 |
ARGS | N ComputationDataHandles序列 | N个任意类型的参数,将被传递给该函数。 |
shape | shape | 输出函数的形状 |
函数签名是相同的,不管arg的arity或类型如何:
extern "C" void target_name(void* out, void** in);
例如,如果使用CustomCall,如下所示:
let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60}};
CustomCall("myfunc", {x, y}, f32[3x3])
以下是一个实现的例子myfunc
:
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
用户提供的功能不得有副作用,其执行必须是幂等的。
注意:用户提供的函数的不透明特性限制了编译器的优化机会。尽可能地用本地XLA操作符来表达你的计算;只有使用CustomCall作为最后的手段。
Dot
Dot(lhs, rhs)
参数 | 类型 | 语义 |
---|---|---|
LHS | ComputationDataHandle | T型阵列 |
RHS | ComputationDataHandle | T型阵列 |
这个操作的确切语义取决于操作数的等级:
输入 | 产量 | 语义 |
---|---|---|
矢量n点矢量n | 纯量 | 矢量点产品 |
矩阵mxk点向量k | 矢量m | 矩阵向量乘法 |
矩阵mxk点阵kxn | 矩阵mxn | 矩阵 - 矩阵乘法 |
该操作执行产品总和的上一维lhs
和上一维的维度rhs
。这些是“合约”的维度。合约的尺寸lhs
和尺寸rhs
必须相同。实际上,它可用于执行矢量之间的点积,矢量/矩阵乘法或矩阵/矩阵乘法。
Element-wise 二进制算术运算
支持一组元素的二进制算术运算。
Op(lhs, rhs)
其中Op
是下列之一
Add
(加法), Sub
(减法), Mul
(乘), Div
(除法), Rem
(余数), Max
(最大), Min
(最小), LogicalAnd
(逻辑AND),或LogicalOr
(逻辑或)。
参数 | 类型 | 语义 |
---|---|---|
LHS | ComputationDataHandle | 左侧操作数:T型数组 |
RHS | ComputationDataHandle | 右侧操作数:类型T的数组 |
参数的形状必须相似或兼容。查看广播文档,了解它对于形状的兼容性意味着什么。操作的结果具有广播两个输入阵列的结果的形状。在这个变体中,不支持不同级别的数组之间的操作,除非其中一个操作数是标量。
如果Op
是Rem
,结果的符号取自分红,并且结果的绝对值始终小于除数的绝对值。
对于这些操作,存在具有不同等级广播支持的备选变体:
Op(lhs, rhs, broadcast_dimensions)
Op
是上面相同。操作的这种变体应该用于不同级别数组之间的算术运算(例如向矢量添加矩阵)。
附加broadcast_dimensions
操作数是用于将较低级操作数的级别扩展到较高级操作数的级别的整数片段。broadcast_dimensions
将较低等级形状的尺寸映射到较高等级形状的尺寸。展开后的形状的未映射尺寸用尺寸为1的尺寸填充。简并维度广播然后沿着这些退化维度广播形状以均衡两个操作数的形状。广播页面详细描述了语义。
Element-wise比较操作
支持一组标准的基于元素的二进制比较操作。请注意,标准IEEE 754浮点比较语义适用于比较浮点类型。
Op(lhs, rhs)
其中Op
是的一个Eq
(等于), Ne
(不等于到), Ge
(大于或-等于-比), Gt
(大于), Le
(少-或等于-比), Le
(小于)。
参数 | 类型 | 语义 |
---|---|---|
lhs | ComputationDataHandle | left-hand-side operand: array of type T |
rhs | ComputationDataHandle | right-hand-side operand: array of type T |
参数的形状必须相似或兼容。请参阅广播文档,了解它对形状的兼容性意味着什么。操作的结果具有广播具有元素类型的两个输入数组的结果的形状PRED
。在这个变体中,不支持不同级别的数组之间的操作,除非其中一个操作数是标量。
对于这些操作,存在具有不同等级广播支持的备选变体:
Op(lhs, rhs, broadcast_dimensions)
Op
与上面相同。操作的这种变体应该用于不同级别的数组之间的比较操作(例如向矢量添加矩阵)。
附加broadcast_dimensions
操作数是指定用于广播操作数的维度的整数片段。广播页面详细描述了语义。
Element-wise 一元函数
ComputationBuilder支持这些基于元素的一元函数:
Abs(operand)
Element-wise abs x -> |x|
.
Ceil(operand)
Element-wise ceil x -> ?x?
.
Cos(operand)
Element-wise cosine x -> cos(x)
.
Exp(operand)
Element-wise natural exponential x -> e^x
.
Floor(operand)
Element-wise floor x -> ?x?
.
IsFinite(operand)
测试每个元素operand
是否有限,即不是正的或负的无穷大,而不是NaN
。返回一组具有PRED
与输入相同形状的值,其中每个元素true
当且仅当相应的输入元素是有限的。
Log(operand)
Element-wise natural logarithm x -> ln(x)
.
LogicalNot(operand)
Element-wise logical not x -> !(x)
.
Neg(operand)
Element-wise negation x -> -x
.
Sign(operand)
Element-wise sign operation x -> sgn(x)
where
$$\text{sgn}(x) = \begin{cases} -1 & x < 0\ 0 & x = 0\ 1 & x > 0 \end{cases}$$
使用元素类型的比较运算符operand
。
Tanh(operand)
单元双曲正切x -> tanh(x)
。
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | The operand to the function |
该函数应用于operand
数组中的每个元素,从而生成具有相同形状的数组。允许operand
成为标量(等级0)。
BatchNormTraining
警告:尚未在GPU后端实现。
在批量和空间维度上规范阵列。
BatchNormTraining(operand, scale, offset, epsilon, feature_index)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | n dimensional array to be normalized |
scale | ComputationDataHandle | 1 dimensional array ((\gamma)) |
offset | ComputationDataHandle | 1 dimensional array ((\beta\ ) |
epsilon | float | Epsilon value ((\epsilon)) |
feature_index | int64 | Index to feature dimension in operand |
对于要素维度中的每个要素(feature_index
要素维度是operand
中的索引),该操作将计算所有其他维度的均值和方差,并使用均值和方差对operand
中的每个要素进行归一化。如果feature_index
传递无效,则会产生错误。
operand
(x)中的每个批处理的算法如下,其中包含m
具有w
和h
作为空间维度的大小的元素(假设operand
是4维数组):
- 计算
l
特征维中每个要素的批均值(\ mu_l):(\ mu_l = \ frac {1} {mwh} \ sum_ {i = 1} ^ m \ sum_ {j = 1} ^ w \ sum_ {k = 1 } ^ h x_ {ijkl})
- 计算批量方差(\ sigma ^ 2_1):(\ sigma ^ 2_l = \ frac {1} {mwh} \ sum_ {i = 1} ^ m \ sum_ {j = 1} ^ w \ sum_ {k = 1} ^ h(x_ {ijkl} - \ mu_l)^ 2)
- 归一化,放缩和位移 :(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt2{\sigma^2_l+\epsilon}}+\beta_l)
ε值通常是一个小数值,以避免被零除错误。
输出类型是三个ComputationDataHandles的元组:
输出 | 类型 | 含义 |
---|---|---|
output | ComputationDataHandle | n dimensional array with the same shape as input operand (y) |
batch_mean | ComputationDataHandle | 1 dimensional array ((\mu)) |
batch_var | ComputationDataHandle | 1 dimensional array ((\sigma^2)) |
batch_mean
和batch_var
,在批处理时,使用上面的公式计算。
BatchNormInference
警告:尚未实现。
空间阵列中批处理正则化数组
BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | n dimensional array to be normalized |
scale | ComputationDataHandle | 1 dimensional array |
offset | ComputationDataHandle | 1 dimensional array |
mean | ComputationDataHandle | 1 dimensional array |
variance | ComputationDataHandle | 1 dimensional array |
epsilon | float | Epsilon value |
feature_index | int64 | Index to feature dimension in operand |
对于要素维度中的每个要素(feature_index
是要素维度中的索引operand
),该操作将计算所有其他维度的均值和方差,并使用均值和方差对其中的每个要素进行归一化operand
。如果feature_index
传递一个无效,则会产生错误。
BatchNormInference
相当于BatchNormTraining
不需要计算mean
和variance
每个批次的调用。它使用输入mean
而variance
不是估计值。这个操作的目的是减少推断的延迟,因此也就是名称BatchNormInference
。
输出是一个与输入operand
形状相同的尺寸标准化阵列
BatchNormGrad
警告:尚未实现。
计算批量标准的梯度。
BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)
参数 | 类型 | 含义 |
---|---|---|
operand | ComputationDataHandle | n dimensional array to be normalized (x) |
scale | ComputationDataHandle | 1 dimensional array ((\gamma)) |
mean | ComputationDataHandle | 1 dimensional array ((\mu)) |
variance | ComputationDataHandle | 1 dimensional array ((\sigma^2)) |
grad_output | ComputationDataHandle | Gradients passed to BatchNormTraining (( \nabla y)) |
epsilon | float | Epsilon value ((\epsilon)) |
feature_index | int64 | Index to feature dimension in operand |
对于在特征维度中的每个特征(feature_index
为在特征维度中的索引operand
),则运算来计算梯度相对于operand
,offset
并scale
在所有其它尺寸。如果feature_index
传递无效,则会产生错误。
三个梯度由以下公式定义:
( \nabla x = \nabla y * \gamma * \sqrt{\sigma^2+\epsilon} )
( \nabla \gamma = sum(\nabla y * (x - \mu) * \sqrt{\sigma^2 + \epsilon}) )
( \nabla \beta = sum(\nabla y) )
输入mean
和variance
表示跨越批处理和空间维度的矩值。
输出类型是三个ComputationDataHandles的元组:
Outputs | Type | Semantics |
---|---|---|
grad_operand | ComputationDataHandle | gradient with respect to input |
operand | grad_offset | ComputationDataHandle |
grad_scale | ComputationDataHandle | gradient with respect to input scale |
GetTupleElement
Indexes into a tuple with a compile-time-constant value.
该值必须是编译时常数,以便形状推断可以确定结果值的类型。
这与C ++ std::get<int N>(t)
类似:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
参见 tf.tuple
.
Infeed
Infeed(shape)
参数 | 类型 | 含义 |
---|---|---|
shape | Shape | Shape of the data read from the Infeed interface. The layout field of the shape must be set to match the layout of the data sent to the device; otherwise its behavior is undefined. |
从设备的隐式Infeed流接口读取单个数据项,将数据解释为给定形状及其布局,并返回一个ComputationDataHandle
数据。计算中允许使用多个进给操作,但在进刀操作中必须有一个总顺序。例如,以下代码中的两个Infeeds具有全部顺序,因为while循环之间存在依赖关系。如果没有全部命令,编译器会发出错误。
result1 = while (condition, init = init_value) {
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
}
不支持嵌套的元组形状。对于空元组形状,Infeed 操作实际上是一个 nop,并且不从设备的 Infeed 读取任何数据。
注意:我们计划在没有全部命令的情况下允许多个操作,在这种情况下,编译器将提供关于进料操作在编译程序中如何序列化的信息。
地图(Map)
Map(operands..., computation)
参数 | 类型 | 语义 |
---|---|---|
操作数 | N ComputationDataHandles 的序列 | N 个类型为T_0..T_ {N-1}的数组 |
计算 | 计算 | T 类型的 N 个参数和任意类型的M的类型 T_0,T_1,...,T_ {N + M -1} - > S的计算 |
尺寸 | int64 数组 | 地图尺寸数组 |
static_operands | M ComputationDataHandles 的序列 | M 个任意类型的数组 |
对给定operands
数组应用标量函数,生成一个具有相同维的数组,其中每个元素是映射函数的结果,应用于输入数组中的相应元素,并将其static_operands
作为附加输入computation
。
映射函数是一个任意计算,其限制条件是它有N个标量类型输入T
和一个带有类型的单个输出S
。输出与操作数具有相同的尺寸,只是元素类型T被替换为S.
例如:Map(op1, op2, op3, computation, par1)
映射elem_out <- computation(elem1, elem2, elem3, par1)
输入数组中的每个(多维)索引以生成输出数组。
Pad
另见ComputationBuilder::Pad
。
Pad(operand, padding_value, padding_config)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | T型阵列 |
padding_value | ComputationDataHandle | T型标量填充添加的填充 |
padding_config | PaddingConfig | 两边(低,高)和每个维度的元素之间的填充量 |
通过填充数组以及给定operand
数组之间的元素来扩展给定的数组padding_value
。padding_config
指定每个维度的边缘填充量和内部填充量。
PaddingConfig
是PaddingConfigDimension
的代替品,其包含用于每个维度的三个字段:edge_padding_low
,edge_padding_high
,和interior_padding
。edge_padding_low
并且分别edge_padding_high
指定在每个维度的低端(紧邻索引0)和高端(紧邻最高索引)添加的填充量。边缘填充的数量可以是负数 - 负填充的绝对值表示要从指定的维度移除的元素的数量。interior_padding
指定每个维度中任何两个元素之间添加的填充量。内部填充在逻辑上发生在边缘填充之前,所以在负边缘填充的情况下,元素从内部填充的操作数中删除。如果边缘填充对全部为(0,0)且内部填充值全为0,则此操作为空操作。下图显示了二维数组的不同值edge_padding
和interior_padding
值的示例。
减少
将缩减函数应用于数组。
Reduce(operand, init_value, computation, dimensions)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | T型阵列 |
init_value | ComputationDataHandle | T型标量 |
计算 | 计算 | 计算类型T,T - > T |
尺寸 | int64数组 | 无序的维度数组来减少 |
从概念上讲,该操作将输入数组中的一个或多个维度缩减为标量。结果数组的排名是rank(operand) - len(dimensions)
。init_value
是每次减少时使用的初始值,并且如果后端选择这样做,也可以在计算期间的任何位置插入。所以在大多数情况下init_value
应该是缩减函数的标识(例如,0表示加法)。
简化函数的评估顺序是任意的,并且可能是非确定性的。因此,减少函数不应该过于敏感重新关联。
一些减法函数,如加法,对于浮点数不是严格相关的。但是,如果数据的范围有限,则浮点加法足够接近大多数实际应用的关联。然而,可以设想一些完全不联合的减少,并且这些会在XLA减少中产生不正确或不可预测的结果。
作为一个例子,当在一维数组中减少一个一维数组10,11,12,13,使用减函数f
(这是computation
),那么可以计算为
f(10, f(11, f(12, f(init_value, 13)))
但也有很多其他的可能性,例如
f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(13, init_value))))
下面是一个粗略的伪代码示例, 说明如何实现缩减, 使用求和作为初始值为0的缩减计算。
result_shape <- remove all dims in dimensions from operand_shape
# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
# Initialize this result element
result[r0, r1...] <- 0
# Iterate over all the reduction dimensions
for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
# Increment the result element with the value of the operand's element.
# The index of the operand's element is constructed from all ri's and di's
# in the right order (by construction ri's and di's together index over the
# whole operand shape).
result[r0, r1...] += operand[ri... di]
这是一个减少二维数组(矩阵)的例子。形状具有等级2,尺寸2的尺寸0和尺寸3的尺寸1:
使用“添加”功能缩小尺寸0或1的结果:
请注意,两个缩减结果都是一维数组。为了视觉方便,该图显示了一列作为列,另一列作为行显示。
对于更复杂的示例,这里是一个3D数组。它的等级是3,尺寸4的尺寸0,尺寸2的尺寸1和尺寸3的尺寸2。为了简单起见,将值1到6复制到尺寸0上。
与2D示例类似,我们可以减少一个维度。例如,如果我们减少维数0,则我们得到一个秩为2的数组,其中维度0上的所有值都被折叠为一个标量:
| 4 8 12 |
| 16 20 24 |
如果我们减少维数2,我们也得到一个秩为2的数组,其中所有维数为2的数据被折叠为一个标量:
| 6 15 |
| 6 15 |
| 6 15 |
| 6 15 |
请注意,输入中剩余维度之间的相对顺序将保留在输出中,但某些维度可能会分配新数字(因为排名发生变化)。
我们也可以减少多个维度。Add-reduction维度0和1产生一维数组| 20 28 36 |
。
在所有维度上减少3D数组会产生标量84
。
ReducePrecision
模拟将浮点值转换为较低精度格式(如IEEE-FP16)并恢复为原始格式的效果。低精度格式的指数和尾数比特数可以任意指定,尽管所有硬件实现都可能不支持所有比特尺寸。
ReducePrecision(operand, mantissa_bits, exponent_bits)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 浮点型数组T. |
exponent_bits | INT32 | 低精度格式的指数位数 |
mantissa_bits | INT32 | 低精度格式的尾数比特数 |
结果是一个类型数组T
。输入值被四舍五入为可用给定数量的尾数位表示的最接近的值(使用“连接到偶数”语义),并且任何超过由指数位数指定的范围的值都被钳位为正或负无穷大。NaN
值被保留,虽然它们可能被转换为规范NaN
值。
低精度格式必须至少有一个指数位(为了区分零值和无穷大,因为它们都有一个零尾数),并且必须有一个非负数的尾数位。指数或尾数位的数量可能会超过类型的相应值T
; 那么转换的相应部分就是简单的无操作。
ReduceWindow
将缩小函数应用于输入多维数组的每个窗口中的所有元素,从而生成具有与窗口的有效位置数相同数量的元素的输出多维数组。池化层可以表示为a ReduceWindow
。
ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 包含T型元素的N维数组。这是放置窗口的基础区域。 |
init_value | ComputationDataHandle | 减少的起始价值。详情请参阅减少。 |
计算 | 计算 | T型缩减功能,T - > T,适用于每个窗口中的所有元素 |
window_dimensions | ArraySlice <int64类型> | 窗口尺寸值的整数数组 |
window_strides | ArraySlice <int64类型> | 窗口跨度值的整数数组 |
填充 | 填充 | 窗口的填充类型(Padding \:\:kSame或Padding \:\:kValid) |
下面的代码和图显示了一个使用的例子ReduceWindow
。输入是一个尺寸为4x6的矩阵,window_dimensions和window_stride_dimensions都是2x3。
// Create a computation for the reduction (maximum).
Computation max;
{
ComputationBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
max = builder.Build().ConsumeValueOrDie();
}
// Create a ReduceWindow computation with the max reduction computation.
ComputationBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
input, *max,
/*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
/*window_dimensions=*/{2, 3},
/*window_stride_dimensions=*/{2, 3},
Padding::kValid);
维度中的1步指定窗口在维度中的位置与相邻窗口的距离为1个元素。为了指定没有窗口相互重叠, window_stride_dimensions 应等于 window_dimensions。下图说明了两个不同步长值的使用情况。填充被应用于输入的每个维度, 并且计算与输入随填充后的尺寸相同。
简化函数的评估顺序是任意的,并且可能是非确定性的。因此,减少函数不应该过于敏感重新关联。有关Reduce
更多详细信息,请参阅关于关联性的讨论。
重塑
另请参阅ComputationBuilder::Reshape
和Collapse
操作。
将阵列的尺寸重新整形为新配置。
Reshape(operand, new_sizes)
Reshape(operand, dimensions, new_sizes)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | T型阵列 |
尺寸 | int64向量 | 订单的尺寸折叠 |
new_sizes | int64向量 | 新尺寸的尺寸矢量 |
从概念上讲,重塑首先将数组变成数据值的一维向量,然后将此向量细化为新的形状。输入参数是T类型的任意数组,维度索引的编译时常量向量,以及结果维度大小的编译时常量向量。dimension
向量中的值(如果给出的话)必须是T的所有维度的置换; 如果没有给出默认值{0, ..., rank - 1}
。尺寸的顺序dimensions
是从循环嵌套中最慢变化的尺寸(最主要)到最快变化的尺寸(最小),它将输入数组折叠为单个维度。所述new_sizes
矢量确定输出数组的大小。在0处的索引0处的值new_sizes
是维度0的大小,索引1处的值是维度1的大小,以此类推。在该产品new_size
的尺寸必须等于操作数的尺寸大小的产品。将折叠数组改进到由多维数组定义的数组中时new_sizes
,其中的维new_sizes
将按最慢变化(最主要)和最快变化(最小)排序。
例如,让v是一个由24个元素组成的数组:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17}},
{ {20, 21, 22}, {25, 26, 27}},
{ {30, 31, 32}, {35, 36, 37}},
{ {40, 41, 42}, {45, 46, 47}}};
In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
{20, 21, 22}, {25, 26, 27},
{30, 31, 32}, {35, 36, 37},
{40, 41, 42}, {45, 46, 47}};
Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
{31, 41, 12}, {22, 32, 42},
{15, 25, 35}, {45, 16, 26},
{36, 46, 17}, {27, 37, 47}};
let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
{11, 21}, {31, 41},
{12, 22}, {32, 42}},
{ {15, 25}, {35, 45},
{16, 26}, {36, 46},
{17, 27}, {37, 47}}};
作为特殊情况,重塑可以将单元素数组转换为标量,反之亦然。例如,
Reshape(f32[1x1] { {5}}, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5}};
Rev(反向)
Rev(operand, dimensions)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | T型阵列 |
尺寸 | ArraySlice <int64类型> | 尺寸扭??转 |
沿指定dimensions
方向颠倒数组operand
中元素的顺序,生成相同形状的输出数组。多维索引处的操作数数组的每个元素都存储在输出数组中的变换索引处。通过反转每个维度中的索引以反转(即,如果大小N的维度是反转维度之一,则其索引i被转换为N-1-i)来转换多维索引。
该Rev
操作的一个用途是在神经网络中的梯度计算期间沿着两个窗口维度反转卷积加权阵列。
RngBernoulli
使用伯努利分布生成的随机数构造给定形状的输出。参数需要是标量值的F32操作数,而输出形状需要元素类型U32。
RngBernoulli(mean, shape)
参数 | 类型 | 语义 |
---|---|---|
意思 | ComputationDataHandle | F32类型的标量指定生成数字的平均值 |
形状 | 形状 | U32型输出形状 |
RngNormal
使用随后生成的随机数构造给定形状的输出
$$N(\mu, \sigma)$$ $$正态分布。参数有mu
,sigma
和
输出形状必须有元素类型F32。这些参数还必须是标量值。
RngNormal(mean, sigma, shape)
参数 | 类型 | 语义 |
---|---|---|
mu | ComputationDataHandle | F32类型的标量指定生成数字的平均值 |
sigma | ComputationDataHandle | 指定生成数字的标准偏差的F32型标量 |
形状 | 形状 | 输出F32型的形状 |
RngUniform
使用随后生成的随机数构造给定形状的输出
区间$$ [a,b)$$的均匀分布。参数和输出
形状可以是F32,S32或U32,但类型必须一致。
此外,参数需要标量赋值。如果$$ b <= a $$结果
是实现定义的。
RngUniform(a, b, shape)
参数 | 类型 | 含义 |
---|---|---|
a | ComputationDataHandle | Scalar of type T specifying lower limit of interval |
b | ComputationDataHandle | Scalar of type T specifying upper limit of interval |
shape | Shape | Output shape of type T |
SelectAndScatter
这个操作可以被认为是一个复合操作,它首先ReduceWindow
在operand
数组上计算以从每个窗口中选择一个元素,然后将source
数组散布到选定元素的索引处,以构造一个与操作数数组具有相同形状的输出数组。二元select
函数用于从每个窗口中选择一个元素,并将其应用到每个窗口中,并用属性调用第一个参数的索引向量按字典顺序小于第二个参数的索引向量。如果选择了第一个参数,则select
返回该函数,true
如果选择false
了第二个参数,则返回该函数,并且该函数必须保持传递性(即,如果select(a, b)
和select(b, c)
是true
,那么select(a, c)
也是true
),以便所选元素不依赖于给定窗口遍历的元素的顺序。
该函数scatter
应用于输出数组中的每个选定索引。它需要两个标量参数:
1. 输出数组中所选索引的当前值
2. 来自source
它的分散值适用于选定的索引
它组合了这两个参数并返回一个标量值,该值用于更新输出数组中所选索引处的值。最初,输出数组的所有索引都设置为init_value
。
输出数组具有与数组相同的形状,operand
并且该source
数组必须具有与在数组上应用ReduceWindow
操作的结果相同的形状operand
。SelectAndScatter
可以用于反向传播神经网络中的汇聚层的梯度值。
SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)
Arguments | Type | Semantics |
---|---|---|
operand | ComputationDataHandle | array of type T over which the windows slide |
select | Computation | binary computation of type T, T -> PRED, to apply to all elements in each window; returns true if the first parameter is selected and returns false if the second parameter is selected |
window_dimensions | ArraySlice<int64> | array of integers for window dimension values |
window_strides | ArraySlice<int64> | array of integers for window stride values |
padding | Padding | padding type for window (Padding\:\:kSame or Padding\:\:kValid) |
source | ComputationDataHandle | array of type T with the values to scatter |
init_value | ComputationDataHandle | scalar value of type T for the initial value of the output array |
scatter | Computation | binary computation of type T, T -> T, to apply each scatter source element with its destination element |
下图显示了使用计算其参数中最大值SelectAndScatter
的select
函数的示例。请注意,当窗口重叠时,如下图(2)所示,operand
可通过不同的窗口多次选择阵列的索引。在该图中,值9的元素由两个顶部窗口(蓝色和红色)选择,二进制加法scatter
函数产生值8(2 + 6)的输出元素。
scatter
函数的评估顺序是任意的,可能是非确定性的。因此,该scatter
函数不应该过于敏感重关联。有关Reduce
更多详细信息,请参阅关于关联性的讨论。
Select
根据谓词数组的值,从两个输入数组的元素构造一个输出数组。
Select(pred, on_true, on_false)
参数 | 类型 | 含义 |
---|---|---|
pred | ComputationDataHandle | array of type PRED |
on_true | ComputationDataHandle | array of type T |
on_false | ComputationDataHandle | array of type T |
数组on_true
和on_false
必须具有相同的形状。这也是输出数组的形状。所述阵列pred
必须具有相同的维数on_true
和on_false
,与PRED
元件的类型。
对于每个元素P
的pred
,输出阵列的相应元素取自on_true
如果值P
是true
,并且从on_false
若的值P
是false
。作为广播的限制性形式,pred
可以是一种类型的标量PRED
。在这种情况下,输出阵列从全取on_true
如果pred
是true
,并且从on_false
如果pred
是false
。
非标量示例pred
:
let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
标量示例pred
:
let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
支持元组之间的选择。为此目的,元组被认为是标量类型。如果on_true
和on_false
是元组(它们必须具有相同的形状!),则pred
必须是类型的标量PRED
。
Slice
切片从输入数组中提取一个子数组。子数组与输入级别相同,并包含输入数组内边界框内的值,边界框的维度和索引作为切片操作的参数给出。
Slice(operand, start_indices, limit_indices)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 类型T的N维数组 |
start_indices | ArraySlice <int64> | 包含每个维度的切片起始索引的N个整数列表。值必须大于或等于零。 |
limit_indices | ArraySlice <int64> | 包含每个维度的切片的结束索引(不包括)的N个整数的列表。每个值必须严格大于维度的相应start_indices值并且小于或等于维度的大小。 |
1维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
{2.0, 3.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
Slice(b, {2, 1}, {4, 3}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicSlice
DynamicSlice以动态方式从输入数组中提取一个子数组start_indices
。传入每个维度中切片的大小size_indices
,这将指定每个维度中排他切片间隔的结束点:[start,start + size)。start_indices
必须是rank == 1,维度大小等于operand
。注意:处理超出边界切片索引(由'start_indices'的错误运行时计算生成)当前是实现定义的。目前,切片索引是通过计算模输入维度大小来防止出现数组访问的情况,但是这种行为在未来的实现中可能会发生变化。
DynamicSlice(operand, start_indices, size_indices)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 类型T的N维数组 |
start_indices | ComputationDataHandle | 包含每个维度的切片的起始索引的N个整数的秩1数组。值必须大于或等于零。 |
size_indices | ArraySlice <int64> | 包含每个维度的切片大小的N个整数列表。每个值必须严格大于零,并且start + size必须小于或等于该维度的大小以避免包裹模维。 |
1维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}
DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}
DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicUpdateSlice
DynamicUpdateSlice生成一个结果,该结果是输入数组的值operand
,其中update
覆盖了一个片段start_indices
。形状update
决定了更新结果的子数组的形状。形状start_indices
必须是rank == 1,维度大小等于operand
。注意:处理超出边界切片索引(由'start_indices'的错误运行时计算生成)当前是实现定义的。目前,分片索引是以模更新维度大小来计算的,以防止出现数组越界访问,但是这种行为在未来的实现中可能会改变。
DynamicUpdateSlice(operand, update, start_indices)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 类型T的N维数组 |
更新 | ComputationDataHandle | 包含切片更新的类型T的N维数组。更新形状的每个维必须严格大于零,并且start + update必须小于每个维的操作数大小,以避免生成越界更新索引。 |
start_indices | ComputationDataHandle | 包含每个维度的切片的起始索引的N个整数的秩1数组。值必须大于或等于零。 |
1维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}
DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
{14.0, 15.0},
{16.0, 17.0} }
let s = {1, 1}
DynamicUpdateSlice(b, u, s) produces:
{ {0.0, 1.0, 2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }
Sort
对操作数中的元素进行排序。
Sort(operand)
Arguments | Type | Semantics |
---|---|---|
operand | ComputationDataHandle | The operand to sort |
Transpose
另请参阅tf.reshape
操作。
Transpose(operand)
参数 | 类型 | 语义 |
---|---|---|
操作数 | ComputationDataHandle | 转置的操作数。 |
排列 | ArraySlice <int64> | 如何排列维度。 |
用给定的置换来排列操作数维度,所以? i . 0 ≤ i < rank ? input_dimensions[permutation[i]] = output_dimensions[i]
。
这与Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions))相同。
Tuple
包含可变数量数据句柄的元组,每个元素都有自己的形状。
这与std::tuple
C ++ 类似。概念:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
元组可以通过GetTupleElement
操作被解构(访问)。
While
While(condition, body, init)
参数 | 类型 | 含义 |
---|---|---|
condition | Computation | Computation of type T -> PRED which defines the termination condition of the loop. |
body | Computation | Computation of type T -> T which defines the body of the loop. |
init | T | Initial value for the parameter of condition and body. |
按顺序执行body
直到condition
失败。除了下面列出的差异和限制之外,这与许多其他语言中的典型while循环类似。
- 一个
While
节点返回一个类型值T
,这是上一次执行的结果body
。
- 类型的形状
T
是静态确定的,并且在所有迭代中必须相同。
While
节点不允许嵌套。(有些目标将来可能会取消此限制。)
计算的T参数用init
第一次迭代中的值进行初始化,并自动更新为body
每次后续迭代中的新结果。
该While
节点的一个主要用例是实现神经网络中训练的重复执行。简化的伪代码如下所示,代表计算的图形。代码可以在中找到while_test.cc
。T
这个例子中的类型Tuple
由一个int32
迭代计数和一个vector[10]
累加器组成。对于1000次迭代,循环不断向累加器添加一个常量向量。
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
iteration = result(0) + 1;
new_vector = result(1) + constant_vector[10];
result = {iteration, new_vector};
}
本文档系腾讯云开发者社区成员共同维护,如有问题请联系 cloudcommunity@tencent.com