Beaver.MLIR.Dialect.Shard (beaver v0.4.7)
Summary
Functions
shard.all_gather - All-gather over a device grid.
shard.all_reduce - All-reduce over a device grid.
shard.all_slice - All-slice over a device grid. This is the inverse of all-gather.
shard.all_to_all - All-to-all over a device grid.
shard.broadcast - Broadcast over a device grid.
shard.gather - Gather over a device grid.
shard.get_sharding - Get the sharding of the given tensor.
shard.grid
shard.grid_shape
shard.neighbors_linear_indices - For given grid index get the linear indices of the direct neighbor processes along the given split.
shard.process_linear_index
shard.process_multi_index
shard.recv - Receive over a device grid.
shard.reduce - Reduce over a device grid.
shard.reduce_scatter - Reduce-scatter over a device grid.
shard.scatter - Scatter over a device grid.
shard.send - Send over a device grid.
shard.shard - Annotate on how a tensor is sharded across a grid.
shard.shard_shape - Get the shard shape for a given process/device.
shard.sharding - Define a sharding of a tensor.
shard.shift - Shift over a device grid.
shard.update_halo - Update halo data.
Functions
shard.all_gather - All-gather over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
gather_axis - Single, IndexAttr, index attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Results
result - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Description
Gathers along the gather_axis tensor axis.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Input:
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
shard.all_reduce - All-reduce over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
reduction - Single, Shard_ReductionKindAttr, reduction of an iterator/grid dimension
Operands
input - Single, anonymous/composite constraint, memref of any type values or ranked tensor of any type values
Results
result - Single, anonymous/composite constraint, memref of any type values or ranked tensor of any type values
Description
The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
shard.all_slice - All-slice over a device grid. This is the inverse of all-gather.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
slice_axis - Single, IndexAttr, index attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Results
result - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Description
Slice along the slice_axis tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
slice tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
shard.all_to_all - All-to-all over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
split_axis - Single, IndexAttr, index attribute
concat_axis - Single, IndexAttr, index attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Results
result - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Description
Performs an all-to-all on tensor pieces split along split_axis.
The resulting pieces are concatenated along concat_axis on each device.
Example:
shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 | ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
shard.broadcast - Broadcast over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
root - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyRankedTensor, ranked tensor of any type values
root_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
Broadcast the tensor on root to all devices in each respective group.
The operation broadcasts along grid axes grid_axes.
The root device specifies the in-group multi-index that is broadcast to
all other devices in the group.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.broadcast %0 on @grid0
grid_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
Input:
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
Output:
+-------+-------+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
+-------+-------+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
+-------+-------+
shard.gather - Gather over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
gather_axis - Single, IndexAttr, index attribute
root - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
root_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Description
Gathers on device root along the gather_axis tensor axis.
root specifies the coordinates of a device along grid_axes.
It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined.
Using it will result in undefined behavior.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
+-------------+
| 1 2 5 6 | <- device (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+
Devices (0, 0) and (1, 0) have undefined results.
shard.get_sharding - Get the sharding of the given tensor.
This op has support for result type inference.
Operands
source - Single, AnyRankedTensor, ranked tensor of any type values
Results
result - Single, Shard_Sharding, sharding definition
Description
This operation returns the sharding of the given tensor as a Sharding.
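A minimal usage sketch (assuming %sharded is a ranked tensor value, e.g. one produced by a shard.shard annotation; the trailing source-type -> !shard.sharding syntax is an assumption based on the operand and result types above):
// Recover the sharding attached to a tensor value.
%sharding = shard.get_sharding %sharded : tensor<4x8xf32> -> !shard.sharding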
shard.grid
shard.grid_shape
shard.neighbors_linear_indices - For given grid index get the linear indices of the direct neighbor processes along the given split.
This op has support for result type inference.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
split_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
Operands
device - Variadic, Index, variadic of index
Results
neighbor_down - Single, Index, index
neighbor_up - Single, Index, index
Description
Example:
shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%idx = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index
The above returns two indices, 633 and 693, which correspond to the index of the previous process (1, 1, 3) and the next process (1, 3, 3) along split axis 1: the linear index of process (1, 2, 3) in a 10x20x30 grid is 1*600 + 2*30 + 3 = 663, and a step along axis 1 changes it by 30. A negative value is returned if there is no neighbor in the respective direction along the given split_axes.
shard.process_linear_index
shard.process_multi_index
shard.recv - Receive over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
source - Optional, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
source_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
Receive from a device within a device group.
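No example is given above; the following hedged sketch mirrors the broadcast example, assuming source takes the in-group multi-index of the sending device and that the assembly format parallels shard.broadcast:
shard.grid @grid0(shape = 2x2)
// Receive from the device at in-group index 0 along grid axis 1.
%1 = shard.recv %0 on @grid0 grid_axes = [1] source = [0]
: (tensor<2xi8>) -> tensor<2xi8>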
shard.reduce - Reduce over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
reduction - Single, Shard_ReductionKindAttr, reduction of an iterator/grid dimension
root - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyRankedTensor, ranked tensor of any type values
root_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
Reduces on device root within each device group.
root specifies the coordinates of a device along grid_axes.
It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and
it does not need to match the input element type.
The input element is converted to the result element type before
performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
shard.reduce_scatter - Reduce-scatter over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
reduction - Single, Shard_ReductionKindAttr, reduction of an iterator/grid dimension
scatter_axis - Single, IndexAttr, index attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
After the reduction, the result is scattered within each device group.
The tensor is split along scatter_axis and the pieces distributed
across the device group.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <sum> scatter_axis = 0
: tensor<3x4xf32> -> tensor<1x4xf64>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 | ↓
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
↑
device
(1, 1)
Result:
+-------+
| 6 8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
shard.scatter - Scatter over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
scatter_axis - Single, IndexAttr, index attribute
root - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
root_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
For each device group split the input tensor on the root device along
axis scatter_axis and scatter the parts across the group devices.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | | | | axis 0
| | | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
Result:
device
(0, 1)
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
+-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
shard.send - Send over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
destination - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
destination_dynamic - Variadic, Index, variadic of index
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
Send from one device to another within a device group.
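No example is given above; the following hedged sketch mirrors the broadcast example, assuming destination takes the in-group multi-index of the receiving device and that the assembly format parallels shard.broadcast:
shard.grid @grid0(shape = 2x2)
// Send the local tensor to the device at in-group index 1 along grid axis 1.
%1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
: (tensor<2xi8>) -> tensor<2xi8>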
shard.shard - Annotate on how a tensor is sharded across a grid.
This op has support for result type inference.
Attributes
annotate_for_users - Optional, UnitAttr, unit attribute
Operands
src - Single, AnyRankedTensor, ranked tensor of any type values
sharding - Single, Shard_Sharding, sharding definition
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and one optional attribute:
input: This operand represents the tensor value that needs to be annotated for sharding.
sharding: This operand is of type ShardingType, which is the core data structure to represent the distribution of a tensor on a grid. It is typically defined by a shard.sharding operation.
annotate_for_users: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
// The first shard.shard op applies to %arg0, the second shard.shard op
// applies for the operand of op0, the third shard.shard op applies for the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
%sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
%2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
shard.shard_shape - Get the shard shape for a given process/device.
Attributes
dims - Single, DenseI64ArrayAttr, i64 dense array attribute
device - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
dims_dynamic - Variadic, Index, variadic of index
sharding - Single, Shard_Sharding, sharding definition
device_dynamic - Variadic, Index, variadic of index
Results
result - Variadic, Index, variadic of index
Description
The device/process id is a multi-index of the device/process in the grid.
This operation might be used during partition when the shard shape depends
on (non-constant) values used in shard.sharding.
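A hedged sketch of a possible invocation (the dims =, sharding =, and device = keyword syntax and the two index results are assumptions inferred from the operands listed above):
// Local shard shape of a 4x8 tensor for the device at multi-index [0],
// given the sharding %sharding; one index result per tensor dimension.
%shp:2 = shard.shard_shape dims = [4, 8] sharding = %sharding device = [0] : index, index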
shard.sharding - Define a sharding of a tensor.
This op has support for result type inference.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
split_axes - Single, Shard_GridAxesArrayAttr
static_sharded_dims_offsets - Single, DenseI64ArrayAttr, i64 dense array attribute
static_halo_sizes - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
dynamic_sharded_dims_offsets - Variadic, I64, variadic of 64-bit signless integer
dynamic_halo_sizes - Variadic, I64, variadic of 64-bit signless integer
Results
result - Single, Shard_Sharding, sharding definition
Description
The Sharding specifies how a tensor is sharded and distributed across the
process grid. It is typically used in a shard.shard operation.
The operation has the following attributes and operands:
grid: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a shard.grid operation.
split_axes: is an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is split along the x and y axes of the device grid.
[Optional] Sizes of halos to be added for each sharded tensor dimension.
halo_sizes is provided as a flattened 1d array of i64s, 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions, e.g. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.
[Optional] Offsets for each shard and sharded tensor dimension.
sharded_dims_offsets is provided as a flattened 1d array of i64s. For each sharded tensor dimension the offsets (starting index) of all shards in that dimension and an additional value for the end of the last shard are provided. For a 1d sharding this means that position i has the exclusive prefix sum for shard i, and since only contiguous sharding is supported, its inclusive prefix sum is at position i+1.
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
sharded_dims_offsets = [0, 24, 32, 0, 20, 32] means that the first device of
the device-grid will get a shard of shape 24x20x32 and the second device will get
a shard of shape 8x12x32. ? indicates dynamic shard dimensions.
halo_sizes and sharded_dims_offsets are mutually exclusive.
Examples:
shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)
// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]]
// The tensor is sharded on the first dimension along axis 0 of @grid0
%sharding1 = shard.sharding @grid0 split_axes = [[0]]
// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
// The tensor is sharded on its first dimension along axis 0 of @grid0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
shard.shift - Shift over a device grid.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
shift_axis - Single, IndexAttr, index attribute
offset - Single, I64Attr, 64-bit signless integer attribute
rotate - Optional, UnitAttr, unit attribute
Operands
input - Single, AnyNon0RankedTensor, non-0-ranked.tensor of any type values
Results
result - Single, AnyRankedTensor, ranked tensor of any type values
Description
Within each device group shift along grid axis shift_axis by an offset
offset.
The result on devices that do not have a corresponding source is undefined.
shift_axis must be one of grid_axes.
If the rotate attribute is present, a rotation is done instead of a shift.
Example:
shard.grid @grid0(shape = 2x4)
%1 = shard.shift on @grid0 grid_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Input:
grid axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
Result:
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
shard.update_halo - Update halo data.
Attributes
grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
split_axes - Single, Shard_GridAxesArrayAttr
static_halo_sizes - Single, DenseI64ArrayAttr, i64 dense array attribute
Operands
destination - Single, anonymous/composite constraint, non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values
halo_sizes - Variadic, I64, variadic of 64-bit signless integer
Results
result - Single, anonymous/composite constraint, non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values
Description
This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor/memref data might have changed on the remote devices. Changes might be caused by mutating operations and/or if the new halo regions are larger than the existing ones.
Destination is supposed to be initialized with the local data (not halos).
Assumes all devices hold tensors with same-sized halo data as specified
by source_halo_sizes/static_source_halo_sizes and
destination_halo_sizes/static_destination_halo_sizes in source shard
and destination/result shard.
split_axes specifies for each tensor axis along which grid axes its halo
data is updated.
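A hedged sketch (the grid name @grid0, the value %dest, and the exact assembly format are assumptions; the destination type includes the halo regions):
shard.grid @grid0(shape = 4)
// The first tensor dimension is split along grid axis 0 and carries a low
// halo of 1 and a high halo of 2 (4 local rows + 1 + 2 = 7); only the halo
// regions are exchanged between neighboring devices.
%updated = shard.update_halo %dest on @grid0 split_axes = [[0]]
halo_sizes = [1, 2] : tensor<7x8xf32>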