Beaver.MLIR.Dialect.Shard (beaver v0.4.7)

Summary

Functions

shard.all_gather - All-gather over a device grid.

shard.all_reduce - All-reduce over a device grid.

shard.all_slice - All-slice over a device grid. This is the inverse of all-gather.

shard.all_to_all - All-to-all over a device grid.

shard.broadcast - Broadcast over a device grid.

shard.gather - Gather over a device grid.

shard.get_sharding - Get the sharding of the given tensor.

shard.grid

shard.grid_shape

shard.neighbors_linear_indices - For a given grid index, get the linear indices of the direct neighbor processes along the given split.

shard.process_linear_index

shard.process_multi_index

shard.recv - Receive over a device grid.

shard.reduce - Reduce over a device grid.

shard.reduce_scatter - Reduce-scatter over a device grid.

shard.scatter - Scatter over a device grid.

shard.send - Send over a device grid.

shard.shard - Annotate how a tensor is sharded across a device grid.

shard.shard_shape - Get the shard shape for a given process/device.

shard.sharding - Define a sharding of a tensor.

shard.shift - Shift over a device grid.

shard.update_halo - Update halo data.

Functions

all_gather(ssa)

shard.all_gather - All-gather over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • gather_axis - Single, IndexAttr, index attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Results

  • result - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Description

Gathers along the gather_axis tensor axis.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>

Input:

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
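To make the data movement concrete, here is a minimal pure-Python model of the semantics shown above. It is not part of Beaver or MLIR: the helper all_gather, the dict-of-devices representation, and the 2-D list tensors are illustrative assumptions, and the in-group order is assumed to be row-major over grid_axes in the order listed.

```python
from typing import Dict, List, Tuple

Device = Tuple[int, ...]   # multi-index of a device in the grid
Matrix = List[List[int]]   # a 2-D tensor as a list of rows


def all_gather(local: Dict[Device, Matrix], grid_axes: List[int],
               gather_axis: int) -> Dict[Device, Matrix]:
    """Hypothetical reference model of shard.all_gather for 2-D tensors."""

    def group_key(dev: Device) -> Device:
        # Devices that agree on every coordinate outside grid_axes share a group.
        return tuple(c for i, c in enumerate(dev) if i not in grid_axes)

    def in_group_index(dev: Device) -> Device:
        # Assumed ordering: row-major over grid_axes in the order listed.
        return tuple(dev[a] for a in grid_axes)

    result: Dict[Device, Matrix] = {}
    for dev in local:
        members = sorted((d for d in local if group_key(d) == group_key(dev)),
                         key=in_group_index)
        pieces = [local[d] for d in members]
        if gather_axis == 0:
            # Stack the pieces' rows one after another.
            result[dev] = [list(row) for piece in pieces for row in piece]
        else:
            # Concatenate row r of every piece, left to right.
            result[dev] = [sum((piece[r] for piece in pieces), [])
                           for r in range(len(pieces[0]))]
    return result


# The 2x2 grid from the example: each device holds a 2x2 tensor.
inputs = {
    (0, 0): [[1, 2], [3, 4]],   (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = all_gather(inputs, grid_axes=[1], gather_axis=1)
print(out[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```

Every member of a group ends up with the same gathered tensor, which is what distinguishes all-gather from gather.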

all_reduce(ssa)

shard.all_reduce - All-reduce over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • reduction - Single, Shard_ReductionKindAttr, Reduction of an iterator/grid dimension.

Operands

  • input - Single, anonymous/composite constraint, memref of any type values or ranked tensor of any type values

Results

  • result - Single, anonymous/composite constraint, memref of any type values or ranked tensor of any type values

Description

The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

The reduction attribute selects the reduction method (<max> in the example below).

Example:

%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
  : tensor<3x4xf32> -> tensor<3x4xf64>
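The type-conversion rule can be illustrated with a small pure-Python model. The helper all_reduce, the string reduction names, and the use of Python float as a stand-in for the wider accumulation type (f64 in the example) are illustrative assumptions, not Beaver or MLIR API.

```python
from typing import Callable, Dict, List, Tuple

Device = Tuple[int, ...]

# Hypothetical names mirroring reduction kinds such as <sum>, <max>, <min>.
REDUCTIONS: Dict[str, Callable[[float, float], float]] = {
    "sum": lambda a, b: a + b,
    "max": max,
    "min": min,
}


def all_reduce(local: Dict[Device, List[List[int]]], grid_axes: List[int],
               reduction: str,
               to_result_type: Callable[[int], float] = float):
    """Hypothetical reference model of shard.all_reduce for 2-D tensors."""
    op = REDUCTIONS[reduction]

    def group_key(dev: Device) -> Device:
        return tuple(c for i, c in enumerate(dev) if i not in grid_axes)

    result = {}
    for dev in local:
        members = [d for d in local if group_key(d) == group_key(dev)]
        # Convert each input element to the result element type first,
        # then accumulate element-wise in that type.
        acc = [[to_result_type(x) for x in row] for row in local[members[0]]]
        for d in members[1:]:
            for r, row in enumerate(local[d]):
                for c, x in enumerate(row):
                    acc[r][c] = op(acc[r][c], to_result_type(x))
        result[dev] = acc
    return result


# grid_axes = [1, 0] puts all four devices of a 2x2 grid in one group.
inputs = {(0, 0): [[1, 8]], (0, 1): [[5, 2]], (1, 0): [[3, 4]], (1, 1): [[7, 6]]}
out = all_reduce(inputs, grid_axes=[1, 0], reduction="max")
print(out[(1, 0)])  # [[7.0, 8.0]]
```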

all_slice(ssa)

shard.all_slice - All-slice over a device grid. This is the inverse of all-gather.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • slice_axis - Single, IndexAttr, index attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Results

  • result - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Description

Slices along the slice_axis tensor axis. This operation can be thought of as the inverse of all-gather. Technically, it is not required that all processes have the same input tensor. Each process slices a piece of its local tensor based on its in-group device index. The operation does not communicate data between devices.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
  : tensor<2x4xi8> -> tensor<2x2xi8>

Input:

+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+

Result:

slice tensor
axis 1
------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
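Because all-slice involves no communication, each device's result depends only on its own input and its in-group index. The sketch below models that in pure Python; all_slice, the row-major in-group index over grid_axes, and the even-divisibility requirement are assumptions made for illustration.

```python
from typing import Dict, List, Sequence, Tuple

Device = Tuple[int, ...]
Matrix = List[List[int]]


def all_slice(local: Dict[Device, Matrix], grid_axes: List[int],
              slice_axis: int, grid_shape: Sequence[int]) -> Dict[Device, Matrix]:
    """Hypothetical reference model of shard.all_slice for 2-D tensors."""
    result: Dict[Device, Matrix] = {}
    for dev, tensor in local.items():
        # In-group linear index and group size (assumed row-major over grid_axes).
        idx, size = 0, 1
        for a in grid_axes:
            idx = idx * grid_shape[a] + dev[a]
            size *= grid_shape[a]
        if slice_axis == 0:
            n = len(tensor) // size      # assumes the axis divides evenly
            result[dev] = [list(row) for row in tensor[idx * n:(idx + 1) * n]]
        else:
            n = len(tensor[0]) // size
            result[dev] = [row[idx * n:(idx + 1) * n] for row in tensor]
    return result


top = [[1, 2, 5, 6], [3, 4, 7, 8]]
bottom = [[9, 10, 13, 14], [11, 12, 15, 16]]
inputs = {(0, 0): top, (0, 1): top, (1, 0): bottom, (1, 1): bottom}
out = all_slice(inputs, grid_axes=[1], slice_axis=1, grid_shape=(2, 2))
print(out[(0, 1)])  # [[5, 6], [7, 8]]
```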

all_to_all(ssa)

shard.all_to_all - All-to-all over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • split_axis - Single, IndexAttr, index attribute
  • concat_axis - Single, IndexAttr, index attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Results

  • result - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Description

Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.

Example:

shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>

Input:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+

Result:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
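A single device group is enough to see the exchange pattern: device i sends its j-th piece to device j, which concatenates what it receives in device order. The pure-Python helpers below (split, concat, all_to_all) are hypothetical illustrations, not Beaver functions, and they assume the split axis divides evenly.

```python
from typing import List

Matrix = List[List[int]]


def split(t: Matrix, axis: int, parts: int) -> List[Matrix]:
    """Split a 2-D tensor into `parts` equal pieces along `axis`."""
    if axis == 0:
        n = len(t) // parts
        return [t[i * n:(i + 1) * n] for i in range(parts)]
    n = len(t[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in t] for i in range(parts)]


def concat(pieces: List[Matrix], axis: int) -> Matrix:
    """Concatenate 2-D tensors along `axis`."""
    if axis == 0:
        return [list(row) for p in pieces for row in p]
    return [sum((p[r] for p in pieces), []) for r in range(len(pieces[0]))]


def all_to_all(group: List[Matrix], split_axis: int, concat_axis: int) -> List[Matrix]:
    """Hypothetical model of shard.all_to_all for a single device group.

    group[i] is the input of the device with in-group index i.
    """
    n = len(group)
    chunks = [split(t, split_axis, n) for t in group]
    # Device j receives chunk j from every device i, concatenated in order of i.
    return [concat([chunks[i][j] for i in range(n)], concat_axis) for j in range(n)]


inputs = [
    [[11, 12], [13, 14], [15, 16]],   # device (0)
    [[21, 22], [23, 24], [25, 26]],   # device (1)
    [[31, 32], [33, 34], [35, 36]],   # device (2)
]
out = all_to_all(inputs, split_axis=0, concat_axis=0)
print(out[0])  # [[11, 12], [21, 22], [31, 32]]
```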

broadcast(ssa)

shard.broadcast - Broadcast over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • root - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyRankedTensor, ranked tensor of any type values
  • root_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

Broadcasts the tensor held by the root device to all devices in its group. Groups are formed along the grid axes grid_axes, and root specifies the in-group multi-index of the device whose tensor is broadcast to all other devices in the group.

Example:

shard.grid @grid0(shape = 2x2)

%1 = shard.broadcast %0 on @grid0
  grid_axes = [0]
  root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>

Input:

                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   
device (1, 0) -> |       |       | <- device (1, 1) 
                 +-------+-------+

Result:

                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
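In other words, every device copies the tensor of the device in its own group whose coordinates along grid_axes equal root. A pure-Python sketch of that rule follows; broadcast and the use of None for don't-care inputs on non-root devices are illustrative assumptions.

```python
import copy
from typing import Dict, List, Optional, Tuple

Device = Tuple[int, ...]


def broadcast(local: Dict[Device, Optional[list]], grid_axes: List[int],
              root: List[int]) -> Dict[Device, list]:
    """Hypothetical reference model of shard.broadcast."""
    result = {}
    for dev in local:
        # The source keeps dev's coordinates on the other grid axes and
        # takes the root's coordinates along grid_axes.
        src = list(dev)
        for axis, coord in zip(grid_axes, root):
            src[axis] = coord
        result[dev] = copy.deepcopy(local[tuple(src)])
    return result


# Example above: only row 0 of the 2x2 grid holds meaningful data.
inputs = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): None, (1, 1): None}
out = broadcast(inputs, grid_axes=[0], root=[0])
print(out[(1, 0)], out[(1, 1)])  # [1, 2] [3, 4]
```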

gather(ssa)

shard.gather - Gather over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • gather_axis - Single, IndexAttr, index attribute
  • root - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values
  • root_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Description

Gathers on device root along the gather_axis tensor axis. root specifies the coordinates of a device along grid_axes. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1]
  gather_axis = 1 root = [1]
  : (tensor<2x2xi8>) -> tensor<2x4xi8>

Input:

                  gather tensor
                  axis 1
                  ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+

Devices (0, 0) and (1, 0) have an undefined result.
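Compared with all-gather, only the root of each group ends up with the gathered tensor. The hypothetical pure-Python model below marks the undefined non-root results as None; the helper name, the None convention, and the row-major in-group ordering are assumptions.

```python
from typing import Dict, List, Optional, Tuple

Device = Tuple[int, ...]
Matrix = List[List[int]]


def gather(local: Dict[Device, Matrix], grid_axes: List[int], gather_axis: int,
           root: List[int]) -> Dict[Device, Optional[Matrix]]:
    """Hypothetical reference model of shard.gather for 2-D tensors."""
    result: Dict[Device, Optional[Matrix]] = {}
    for dev in local:
        if [dev[a] for a in grid_axes] != list(root):
            result[dev] = None  # undefined on non-root devices
            continue
        # Group members agree with dev on every axis outside grid_axes.
        members = sorted(
            (d for d in local
             if all(d[i] == dev[i] for i in range(len(dev)) if i not in grid_axes)),
            key=lambda d: [d[a] for a in grid_axes])
        pieces = [local[d] for d in members]
        if gather_axis == 0:
            result[dev] = [list(row) for p in pieces for row in p]
        else:
            result[dev] = [sum((p[r] for p in pieces), [])
                           for r in range(len(pieces[0]))]
    return result


inputs = {
    (0, 0): [[1, 2], [3, 4]],   (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = gather(inputs, grid_axes=[1], gather_axis=1, root=[1])
print(out[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```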

get_sharding(ssa)

shard.get_sharding - Get the sharding of the given tensor.

This op has support for result type inference.

Operands

  • source - Single, AnyRankedTensor, ranked tensor of any type values

Results

  • result - Single, Shard_Sharding, sharding definition

Description

This operation returns the sharding of the given tensor as a Sharding.

grid(ssa)

shard.grid

grid_shape(ssa)

shard.grid_shape

neighbors_linear_indices(ssa)

shard.neighbors_linear_indices - For a given grid index, get the linear indices of the direct neighbor processes along the given split.

This op has support for result type inference.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • split_axes - Single, Shard_GridAxesAttr, i16 dense array attribute

Operands

  • device - Variadic, Index, variadic of index

Results

  • neighbor_down - Single, Index, index
  • neighbor_up - Single, Index, index

Description

Example:

shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%idx = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index

The above returns two indices, 633 and 693, which are the linear indices of the previous process (1, 1, 3) and the next process (1, 3, 3) along split axis 1. A negative value is returned if there is no neighbor in the respective direction along the given split_axes.
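The numbers follow from a row-major linearization of the grid index: for shape 10x20x30, the linear index of (i, j, k) is i*600 + j*30 + k, so (1, 1, 3) maps to 633 and (1, 3, 3) to 693. A pure-Python sketch of that arithmetic follows; it handles a single split axis only, uses -1 as its negative "no neighbor" value, and the helper names are hypothetical.

```python
from typing import Sequence, Tuple


def linear_index(multi: Sequence[int], shape: Sequence[int]) -> int:
    """Row-major linearization of a grid multi-index."""
    idx = 0
    for coord, size in zip(multi, shape):
        idx = idx * size + coord
    return idx


def neighbors_linear_indices(device: Sequence[int], shape: Sequence[int],
                             split_axes: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical model for one split axis; -1 marks a missing neighbor."""
    (axis,) = split_axes  # this sketch only supports a single split axis
    down, up = list(device), list(device)
    down[axis] -= 1
    up[axis] += 1
    neighbor_down = linear_index(down, shape) if down[axis] >= 0 else -1
    neighbor_up = linear_index(up, shape) if up[axis] < shape[axis] else -1
    return neighbor_down, neighbor_up


print(neighbors_linear_indices((1, 2, 3), (10, 20, 30), [1]))  # (633, 693)
```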

process_linear_index(ssa)

shard.process_linear_index

process_multi_index(ssa)

shard.process_multi_index

recv(ssa)

shard.recv - Receive over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • source - Optional, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values
  • source_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

Receive from a device within a device group.

reduce(ssa)

shard.reduce - Reduce over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • reduction - Single, Shard_ReductionKindAttr, Reduction of an iterator/grid dimension.
  • root - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyRankedTensor, ranked tensor of any type values
  • root_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

Reduces on device root within each device group. root specifies the coordinates of a device along grid_axes. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

The reduction attribute selects the reduction method (<max> in the example below).

Example:

%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
  reduction = <max> root = [2, 3]
  : (tensor<3x4xf32>) -> tensor<3x4xf64>

reduce_scatter(ssa)

shard.reduce_scatter - Reduce-scatter over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • reduction - Single, Shard_ReductionKindAttr, Reduction of an iterator/grid dimension.
  • scatter_axis - Single, IndexAttr, index attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

After the reduction, the result is scattered within each device group. The tensor is split along scatter_axis and the pieces are distributed across the device group.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>

Input:

                          device
                          (0, 1)
                             
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                            
                          device
                          (1, 1)

Result:

+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
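The result diagram corresponds to a sum reduction within each row of the grid (grid_axes = [1]) followed by splitting the reduced 2x2 tensor along axis 0. The pure-Python sketch below reproduces it; reduce_scatter, the default sum operator, and the row-major in-group index are assumptions for illustration.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Device = Tuple[int, ...]
Matrix = List[List[int]]


def reduce_scatter(local: Dict[Device, Matrix], grid_axes: List[int],
                   scatter_axis: int, grid_shape: Sequence[int],
                   op: Callable[[int, int], int] = lambda a, b: a + b
                   ) -> Dict[Device, Matrix]:
    """Hypothetical reference model of shard.reduce_scatter for 2-D tensors."""
    result: Dict[Device, Matrix] = {}
    for dev in local:
        members = [d for d in local
                   if all(d[i] == dev[i] for i in range(len(dev)) if i not in grid_axes)]
        # 1. Element-wise reduction over the group.
        acc = [list(row) for row in local[members[0]]]
        for d in members[1:]:
            for r, row in enumerate(local[d]):
                for c, x in enumerate(row):
                    acc[r][c] = op(acc[r][c], x)
        # 2. Keep the piece that belongs to this device's in-group index.
        idx, size = 0, 1
        for a in grid_axes:
            idx = idx * grid_shape[a] + dev[a]
            size *= grid_shape[a]
        if scatter_axis == 0:
            n = len(acc) // size
            result[dev] = acc[idx * n:(idx + 1) * n]
        else:
            n = len(acc[0]) // size
            result[dev] = [row[idx * n:(idx + 1) * n] for row in acc]
    return result


inputs = {
    (0, 0): [[1, 2], [3, 4]],   (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = reduce_scatter(inputs, grid_axes=[1], scatter_axis=0, grid_shape=(2, 2))
print(out[(0, 0)], out[(0, 1)])  # [[6, 8]] [[10, 12]]
```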

scatter(ssa)

shard.scatter - Scatter over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • scatter_axis - Single, IndexAttr, index attribute
  • root - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values
  • root_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

For each device group, split the input tensor on the root device along scatter_axis and scatter the parts across the devices of the group.

Example:

shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
  scatter_axis = 0
  root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>

Input:

                          device
                          (0, 1)
                             
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                            
                          device
                          (1, 1)

Result:

                          device
                          (0, 1)
                             
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+ 
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                            
                          device
                          (1, 1)
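Put differently, each device takes the root's tensor from its own group and keeps the piece matching its in-group index. The hypothetical pure-Python model below follows the example (inputs on non-root devices are ignored and shown as None); the helper name and the row-major in-group index are assumptions.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Device = Tuple[int, ...]
Matrix = List[List[int]]


def scatter(local: Dict[Device, Optional[Matrix]], grid_axes: List[int],
            scatter_axis: int, root: List[int],
            grid_shape: Sequence[int]) -> Dict[Device, Matrix]:
    """Hypothetical reference model of shard.scatter for 2-D tensors."""
    result: Dict[Device, Matrix] = {}
    for dev in local:
        # Locate the root of dev's group: same coordinates off grid_axes.
        src = list(dev)
        for axis, coord in zip(grid_axes, root):
            src[axis] = coord
        tensor = local[tuple(src)]
        # In-group linear index and group size (assumed row-major).
        idx, size = 0, 1
        for a in grid_axes:
            idx = idx * grid_shape[a] + dev[a]
            size *= grid_shape[a]
        if scatter_axis == 0:
            n = len(tensor) // size
            result[dev] = [list(row) for row in tensor[idx * n:(idx + 1) * n]]
        else:
            n = len(tensor[0]) // size
            result[dev] = [row[idx * n:(idx + 1) * n] for row in tensor]
    return result


inputs = {(0, 0): None, (0, 1): None,
          (1, 0): [[1, 2], [3, 4]], (1, 1): [[5, 6], [7, 8]]}
out = scatter(inputs, grid_axes=[0], scatter_axis=0, root=[1], grid_shape=(2, 2))
print(out[(0, 0)], out[(1, 0)])  # [[1, 2]] [[3, 4]]
```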

send(ssa)

shard.send - Send over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • destination - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values
  • destination_dynamic - Variadic, Index, variadic of index

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

Send from one device to another within a device group.

shard(ssa)

shard.shard - Annotate how a tensor is sharded across a device grid.

This op has support for result type inference.

Attributes

  • annotate_for_users - Optional, UnitAttr, unit attribute

Operands

  • src - Single, AnyRankedTensor, ranked tensor of any type values
  • sharding - Single, Shard_Sharding, sharding definition

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and one optional attribute:

  1. src: This operand represents the tensor value that needs to be annotated for sharding.

  2. sharding: This operand is of type !shard.sharding, the core data structure representing the distribution of a tensor across a device grid. It is typically defined by a shard.sharding operation.

  3. annotate_for_users: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:

  func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
    %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
    ...
  }

  func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
    %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
    ...
  }

  func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
    %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
    %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
    ...
  }

  // The first shard.shard op applies to %arg0, the second shard.shard op
  // applies to the operand of op0, and the third shard.shard op applies to
  // the operand of op1.
  func.func @both_result_and_multi_operands_annotated(
      %arg0 : tensor<4x8xf32>) -> () {
    %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
    %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
    %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
    %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
    %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
    "op0"(%1) : ...
    "op1"(%2) : ...
    ...
  }

The following usages are undefined:

  func.func @annotate_on_same_result_with_different_sharding(
      %arg0 : tensor<4x8xf32>) -> () {
    %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
    %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
    ...
  }

  func.func @annotate_on_same_result_same_value_with_different_sharding(
      %arg0 : tensor<4x8xf32>) -> () {
    %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
    %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
    ...
  }

  func.func @annotate_on_same_operand_with_different_sharding(
      %arg0 : tensor<4x8xf32>) -> () {
    %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
    %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
    ...
  }

  func.func @result_annotated_after_operand(
      %arg0 : tensor<4x8xf32>) -> () {
    %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
    %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
    %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
    ...
  }

shard_shape(ssa)

shard.shard_shape - Get the shard shape for a given process/device.

Attributes

  • dims - Single, DenseI64ArrayAttr, i64 dense array attribute
  • device - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • dims_dynamic - Variadic, Index, variadic of index
  • sharding - Single, Shard_Sharding, sharding definition
  • device_dynamic - Variadic, Index, variadic of index

Results

  • result - Variadic, Index, variadic of index

Description

The device/process id is a multi-index of the device/process in the device grid. This operation might be used during partitioning when the shard shape depends on (non-constant) values used in shard.sharding.

sharding(ssa)

shard.sharding - Define a sharding of a tensor.

This op has support for result type inference.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • split_axes - Single, Shard_GridAxesArrayAttr
  • static_sharded_dims_offsets - Single, DenseI64ArrayAttr, i64 dense array attribute
  • static_halo_sizes - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • dynamic_sharded_dims_offsets - Variadic, I64, variadic of 64-bit signless integer
  • dynamic_halo_sizes - Variadic, I64, variadic of 64-bit signless integer

Results

  • result - Single, Shard_Sharding, sharding definition

Description

The Sharding specifies how a tensor is sharded and distributed across the process grid. It is typically used in a shard.shard operation. The operation has the following attributes and operands:

  1. grid: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a shard.grid operation.

  2. split_axes: an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], the tensor's i-th dimension is split along the x and y axes of the device grid.

  3. halo_sizes [Optional]: Sizes of halos to be added for each sharded tensor dimension, provided as a flattened 1d array of i64s with 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions, i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.

  4. sharded_dims_offsets [Optional]: Offsets for each shard and sharded tensor dimension, provided as a flattened 1d array of i64s. For each sharded tensor dimension, the offsets (starting indices) of all shards in that dimension are given, followed by one additional value for the end of the last shard. For a 1d sharding this means that position i holds the exclusive prefix sum for shard i and, since only contiguous sharding is supported, its inclusive prefix sum is at position i+1.

Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, sharded_dims_offsets = [0, 24, 32, 0, 20, 32] means that the device with index 0 along the split axes of both dimensions gets a shard of shape 24x20x32, and the device with index 1 along both gets a shard of shape 8x12x32. ? indicates dynamic shard dimensions.

halo_sizes and sharded_dims_offsets are mutually exclusive.

Examples:

shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)

// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]] : !shard.sharding

// The tensor is sharded on the first dimension along axis 0 of @grid0
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding

// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>

// The tensor is sharded on its first dimension along axis 0 of @grid0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] : !shard.sharding
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>

// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] : !shard.sharding
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
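The shard shapes quoted above can be reproduced with a little arithmetic. The pure-Python helper below is a hypothetical sketch, not Beaver or MLIR API: it assumes even division when no offsets are given, a row-major in-group index over each sub-array of split_axes, and that halo sizes are added to every shard of a sharded dimension.

```python
from math import prod
from typing import List, Optional, Sequence


def local_shard_shape(global_shape: Sequence[int], split_axes: List[List[int]],
                      grid_shape: Sequence[int], device: Sequence[int],
                      sharded_dims_offsets: Optional[List[int]] = None,
                      halo_sizes: Optional[List[int]] = None) -> List[int]:
    """Hypothetical sketch of the local shard shape implied by shard.sharding."""
    if sharded_dims_offsets and halo_sizes:
        raise ValueError("halo_sizes and sharded_dims_offsets are mutually exclusive")
    shape = list(global_shape)
    pos = 0      # read position in the flattened sharded_dims_offsets
    sharded = 0  # number of sharded dimensions seen so far
    for dim, axes in enumerate(split_axes):
        if not axes:
            continue  # replicated dimension keeps its full size
        num_shards = prod(grid_shape[a] for a in axes)
        idx = 0
        for a in axes:  # in-group index, assumed row-major over axes
            idx = idx * grid_shape[a] + device[a]
        if sharded_dims_offsets:
            offsets = sharded_dims_offsets[pos:pos + num_shards + 1]
            shape[dim] = offsets[idx + 1] - offsets[idx]
            pos += num_shards + 1
        else:
            shape[dim] = global_shape[dim] // num_shards  # assumes even division
        if halo_sizes:
            shape[dim] += halo_sizes[2 * sharded] + halo_sizes[2 * sharded + 1]
        sharded += 1
    return shape


# @grid1d_4 example: tensor<4x14xf32> with split_axes = [[], [0]].
print([local_shard_shape([4, 14], [[], [0]], (4,), (i,),
                         sharded_dims_offsets=[0, 2, 5, 9, 14]) for i in range(4)])
# [[4, 2], [4, 3], [4, 4], [4, 5]]
```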

shift(ssa)

shard.shift - Shift over a device grid.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • grid_axes - Single, Shard_GridAxesAttr, i16 dense array attribute
  • shift_axis - Single, IndexAttr, index attribute
  • offset - Single, I64Attr, 64-bit signless integer attribute
  • rotate - Optional, UnitAttr, unit attribute

Operands

  • input - Single, AnyNon0RankedTensor, non-0-ranked tensor of any type values

Results

  • result - Single, AnyRankedTensor, ranked tensor of any type values

Description

Within each device group, shift the tensors along grid axis shift_axis by offset positions. The result on devices that do not have a corresponding source is undefined. shift_axis must be one of grid_axes. If the rotate attribute is present, a rotation (cyclic shift) is performed instead of a shift.

Example:

shard.grid @grid0(shape = 2x4)
%1 = shard.shift %0 on @grid0 grid_axes = [1]
  shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>

Input:

grid axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+

Result:

+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
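The example can be reproduced by moving each value offset positions along grid axis 1 and wrapping around. The pure-Python sketch below is hypothetical; in particular, the direction (device p receiving from p - offset) is an assumption that the rotate-by-2 example on a 4-wide axis cannot distinguish, and None stands in for undefined results.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Device = Tuple[int, ...]


def shift(local: Dict[Device, int], grid_shape: Sequence[int], grid_axes: List[int],
          shift_axis: int, offset: int, rotate: bool = False) -> Dict[Device, Optional[int]]:
    """Hypothetical reference model of shard.shift."""
    if shift_axis not in grid_axes:
        raise ValueError("shift_axis must be one of grid_axes")
    n = grid_shape[shift_axis]
    result: Dict[Device, Optional[int]] = {}
    for dev in local:
        s = dev[shift_axis] - offset  # assumed direction: data moves to higher indices
        if rotate:
            s %= n
        elif not 0 <= s < n:
            result[dev] = None  # no corresponding source: undefined
            continue
        src = list(dev)
        src[shift_axis] = s
        result[dev] = local[tuple(src)]
    return result


# Values 1..8 laid out on a 2x4 grid, as in the diagram.
inputs = {(r, c): 4 * r + c + 1 for r in range(2) for c in range(4)}
out = shift(inputs, (2, 4), grid_axes=[1], shift_axis=1, offset=2, rotate=True)
print([out[(0, c)] for c in range(4)])  # [3, 4, 1, 2]
```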

update_halo(ssa)

shard.update_halo - Update halo data.

Attributes

  • grid - Single, FlatSymbolRefAttr, flat symbol reference attribute
  • split_axes - Single, Shard_GridAxesArrayAttr
  • static_halo_sizes - Single, DenseI64ArrayAttr, i64 dense array attribute

Operands

  • destination - Single, anonymous/composite constraint, non-0-ranked memref of any type values or non-0-ranked tensor of any type values
  • halo_sizes - Variadic, I64, variadic of 64-bit signless integer

Results

  • result - Single, anonymous/composite constraint, non-0-ranked memref of any type values or non-0-ranked tensor of any type values

Description

This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor/memref data might have changed on the remote devices. Changes might be caused by mutating operations and/or by the new halo regions being larger than the existing ones.

Destination is supposed to be initialized with the local data (not halos).

Assumes all devices hold tensors with same-sized halo data as specified by halo_sizes/static_halo_sizes.

split_axes specifies, for each tensor axis, the grid axes along which its halo data is updated.
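For a single split dimension, the effect can be pictured as each shard overwriting its halo cells with the boundary cells owned by its neighbors. The 1-D pure-Python sketch below is hypothetical: it assumes non-periodic boundaries (outermost halos are left as they are), uniform halo sizes, and owned regions at least as large as the halos.

```python
from typing import List


def update_halo_1d(shards: List[List[int]], halo_lo: int, halo_hi: int) -> List[List[int]]:
    """Hypothetical model: shards[i] = [low halo | owned data | high halo]."""
    owned = [s[halo_lo:len(s) - halo_hi] for s in shards]
    out = []
    for i, shard in enumerate(shards):
        shard = list(shard)  # destination starts with the local data
        if i > 0 and halo_lo:
            # Low halo <- last halo_lo owned cells of the left neighbor.
            shard[:halo_lo] = owned[i - 1][-halo_lo:]
        if i + 1 < len(shards) and halo_hi:
            # High halo <- first halo_hi owned cells of the right neighbor.
            shard[len(shard) - halo_hi:] = owned[i + 1][:halo_hi]
        out.append(shard)
    return out


# 0..7 split over two devices with halo sizes [1, 1]; -1 marks stale halo cells.
stale = [[-1, 0, 1, 2, 3, -1], [-1, 4, 5, 6, 7, -1]]
print(update_halo_1d(stale, 1, 1))  # [[-1, 0, 1, 2, 3, 4], [3, 4, 5, 6, 7, -1]]
```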