PyTorch ProcessGroup wrapper for a group of processes.
A PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, or MPI. GroupCoordinator takes charge of all the
communication operations among the processes in the group. It manages
both CPU and device communication.
Summary
Functions
Python method GroupCoordinator._all_gather_out_place.
Python method GroupCoordinator._all_reduce_out_place.
Python method GroupCoordinator._reduce_scatter_out_place.
Python method GroupCoordinator.all_gather.
Python method GroupCoordinator.all_gatherv.
User-facing all-reduce function invoked before the actual all-reduce operation.
Barrier synchronization among the group.
Broadcast the input tensor.
Broadcast the input object.
Broadcast the input object list.
Broadcast the input tensor dictionary.
Python method GroupCoordinator.combine.
Python method GroupCoordinator.create_mq_broadcaster.
Python method GroupCoordinator.create_single_reader_mq_broadcasters.
Python method GroupCoordinator.destroy.
Python method GroupCoordinator.dispatch.
NOTE: We assume that the input tensor is on the same device across all the ranks.
Python method GroupCoordinator.graph_capture.
Initialize self. See help(type(self)) for accurate signature.
Python method GroupCoordinator.prepare_communication_buffer_for_model.
Receives a tensor from the source rank.
Receive the input object list from the source rank.
Receive the input tensor dictionary.
Python method GroupCoordinator.reduce_scatter.
Python method GroupCoordinator.reduce_scatterv.
Sends a tensor to the destination rank in a blocking way.
Send the input object list to the destination rank.
Send the input tensor dictionary.
Types
Functions
@spec _all_gather_out_place(SnakeBridge.Ref.t(), term(), integer(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator._all_gather_out_place.
Parameters
input_ (term())
dim (integer())
Returns
term()
@spec _all_reduce_out_place(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator._all_reduce_out_place.
Parameters
input_ (term())
Returns
term()
@spec _reduce_scatter_out_place(SnakeBridge.Ref.t(), term(), integer(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator._reduce_scatter_out_place.
Parameters
input_ (term())
dim (integer())
Returns
term()
@spec all_gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.all_gather.
Parameters
input_ (term())
dim (integer(), default: -1)
Returns
term()
@spec all_gatherv(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.all_gatherv.
Parameters
input_ (term())
dim (integer(), default: 0)
sizes (term(), default: None)
Returns
term()
@spec all_reduce(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
User-facing all-reduce function, invoked before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (self in this case) to a custom op. Instead, we pass the group
name as a string, look up the group coordinator from the group name,
and dispatch the all-reduce operation to that coordinator.
In addition, PyTorch custom ops do not support both mutating an input and returning a new tensor in the same op, so we always make the all-reduce operation out-of-place.
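The name-based indirection described above can be sketched in plain Python, with no torch required. `_groups`, `Coordinator`, and the doubling reduction are hypothetical stand-ins for vLLM's registry, GroupCoordinator, and the real collective; tensors are modeled as lists:

```python
# Sketch: Dynamo-friendly custom ops receive the group NAME (a string),
# not the coordinator object, and look the coordinator up in a
# module-level registry at call time.
_groups = {}  # group name -> coordinator (hypothetical registry)

class Coordinator:
    def __init__(self, name):
        self.name = name
        _groups[name] = self  # register under the group name

    def _all_reduce_out_place(self, input_):
        # Out-of-place: return a NEW "tensor", never mutate input_
        # (a custom op may not mutate and return in the same op).
        return [x * 2 for x in input_]  # stand-in for a 2-rank sum

def all_reduce(input_, group_name):
    # The custom op sees only the string; self is recovered here.
    coordinator = _groups[group_name]
    return coordinator._all_reduce_out_place(input_)

tp = Coordinator("tp")
t = [1.0, 2.0, 3.0]
out = all_reduce(t, "tp")
```

The point of the pattern is that `group_name` is a plain string, which Dynamo can trace through, while the stateful coordinator stays outside the op.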
Parameters
input_ (term())
Returns
term()
@spec barrier( SnakeBridge.Ref.t(), keyword() ) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Barrier synchronization among the group.
NOTE: don't use device_group here! barrier in NCCL is
terrible because it is internally a broadcast operation with
secretly created GPU tensors. It is easy to mess up the current
device. Use the CPU group instead.
Returns
term()
@spec broadcast(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Broadcast the input tensor.
NOTE: src is the local rank of the source rank.
Parameters
input_ (term())
src (integer(), default: 0)
Returns
term()
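Since `src` is a group-local rank, a collective backend needs the corresponding global rank. A minimal sketch of that translation, assuming (as is typical) the coordinator stores the list of global ranks belonging to the group:

```python
# Translating a group-local src rank into a global rank.
# `ranks` is the group's list of global ranks (hypothetical example).
ranks = [4, 5, 6, 7]           # global ranks in this group
src_local = 2                  # src as accepted by broadcast()
global_src = ranks[src_local]  # global rank to broadcast from
```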
@spec broadcast_object(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Broadcast the input object.
NOTE: src is the local rank of the source rank.
Parameters
obj (term(), default: None)
src (integer(), default: 0)
Returns
term()
@spec broadcast_object_list(SnakeBridge.Ref.t(), [term()], [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Broadcast the input object list.
NOTE: src is the local rank of the source rank.
Parameters
obj_list (list(term()))
src (integer(), default: 0)
group (term(), default: None)
Returns
term()
@spec broadcast_tensor_dict(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Broadcast the input tensor dictionary.
NOTE: src is the local rank of the source rank.
Parameters
tensor_dict (term(), default: None)
src (integer(), default: 0)
group (term(), default: None)
metadata_group (term(), default: None)
Returns
term()
@spec combine(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.combine.
Parameters
hidden_states (term())
is_sequence_parallel (boolean(), default: False)
Returns
term()
@spec create_mq_broadcaster(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.create_mq_broadcaster.
Parameters
writer_rank (term(), default: 0)
external_writer_handle (term(), default: None)
blocking (term(), default: True)
Returns
term()
@spec create_single_reader_mq_broadcasters(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.create_single_reader_mq_broadcasters.
Parameters
reader_rank_in_group (term(), default: 0)
blocking (term(), default: False)
Returns
term()
@spec destroy( SnakeBridge.Ref.t(), keyword() ) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.destroy.
Returns
term()
@spec dispatch(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.dispatch.
Parameters
hidden_states (term())
router_logits (term())
is_sequence_parallel (boolean(), default: False)
extra_tensors (term(), default: None)
Returns
term()
@spec first_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: dst is the local rank of the destination rank.
Parameters
input_ (term())
dst (integer(), default: 0)
dim (integer(), default: -1)
Returns
term()
@spec graph_capture(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.graph_capture.
Parameters
graph_capture_context (term(), default: None)
Returns
term()
@spec is_first_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec is_last_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec last_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec new([[integer()]], integer(), term(), boolean(), [term()], keyword()) :: {:ok, SnakeBridge.Ref.t()} | {:error, Snakepit.Error.t()}
Initialize self. See help(type(self)) for accurate signature.
Parameters
group_ranks (list(list(integer())))
local_rank (integer())
torch_distributed_backend (term())
use_device_communicator (boolean())
use_message_queue_broadcaster (boolean(), default: False)
group_name (term(), default: None)
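group_ranks is a list of disjoint rank lists, and each process belongs to exactly one of them. A hedged sketch of how an initializer typically locates its own subgroup; the variable names here are illustrative, not the binding's actual fields:

```python
# Each inner list is one group of global ranks; find the group
# containing this process and compute its rank inside that group.
group_ranks = [[0, 1], [2, 3]]  # e.g. two tensor-parallel groups
my_global_rank = 3

ranks = None
rank_in_group = None
for candidate in group_ranks:
    if my_global_rank in candidate:
        ranks = candidate
        rank_in_group = candidate.index(my_global_rank)
        break

world_size = len(ranks)  # size of this process's group
```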
@spec next_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec prepare_communication_buffer_for_model(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.prepare_communication_buffer_for_model.
Parameters
model (term())
Returns
term()
@spec prev_rank(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec recv(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Receives a tensor from the source rank.
Parameters
size (term())
dtype (term())
src (term(), default: None)
Returns
term()
@spec recv_object(SnakeBridge.Ref.t(), integer(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Receive the input object list from the source rank.
Parameters
src (integer())
Returns
term()
@spec recv_tensor_dict(SnakeBridge.Ref.t(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Receive the input tensor dictionary.
NOTE: src is the local rank of the source rank.
all_gather_group: The group for the all-gather operation. If provided, an optimization is enabled where each rank in the group sends a slice of a tensor and the receiver reconstructs it using an all-gather, which can improve performance. This is typically the tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use the all-gather optimization, which is only effective when `all_gather_group` is provided. By default, this optimization is on for any tensor whose size is divisible by the `all_gather_group`'s world size. However, it should be disabled for tensors that are not fully replicated across the group (e.g., the residual tensor when sequence parallelism is enabled). This dictionary allows overriding the default behavior on a per-tensor basis.
Parameters
src (term(), default: None)
all_gather_group (term() | nil, default: None)
all_gather_tensors (term(), default: None)
Returns
term()
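The default on/off rule for the all-gather optimization can be sketched without torch. The function and tensor names below are illustrative stand-ins, not the binding's actual API:

```python
# Default rule: a tensor uses the all-gather optimization iff its
# element count is divisible by the all-gather group's world size;
# the all_gather_tensors dict overrides this per tensor name.
def use_all_gather(name, numel, world_size, all_gather_tensors=None):
    if all_gather_tensors is not None and name in all_gather_tensors:
        return all_gather_tensors[name]  # explicit per-tensor override
    return numel % world_size == 0       # default divisibility rule

ws = 4
a = use_all_gather("hidden_states", 4096, ws)   # divisible by 4
b = use_all_gather("residual", 4096, ws,
                   all_gather_tensors={"residual": False})  # disabled
c = use_all_gather("odd", 4097, ws)             # not divisible
```

The override path matters for tensors like the sequence-parallel residual: they are divisible by the world size but not fully replicated, so gathering slices would reconstruct the wrong values.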
@spec reduce_scatter(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.reduce_scatter.
Parameters
input_ (term())
dim (integer(), default: -1)
Returns
term()
@spec reduce_scatterv(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method GroupCoordinator.reduce_scatterv.
Parameters
input_ (term())
dim (integer(), default: -1)
sizes (term(), default: None)
Returns
term()
@spec send(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, nil} | {:error, Snakepit.Error.t()}
Sends a tensor to the destination rank in a blocking way
Parameters
tensor (term())
dst (term(), default: None)
Returns
nil
@spec send_object(SnakeBridge.Ref.t(), term(), integer(), keyword()) :: {:ok, nil} | {:error, Snakepit.Error.t()}
Send the input object list to the destination rank.
Parameters
obj (term())
dst (integer())
Returns
nil
@spec send_tensor_dict( SnakeBridge.Ref.t(), %{optional(String.t()) => term()}, [term()], keyword() ) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Send the input tensor dictionary.
NOTE: dst is the local rank of the destination rank.
all_gather_group: The group for the all-gather operation. If provided, an optimization is enabled where each rank in the group sends a slice of a tensor and the receiver reconstructs it using an all-gather, which can improve performance. This is typically the tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use the all-gather optimization, which is only effective when `all_gather_group` is provided. By default, this optimization is on for any tensor whose size is divisible by the `all_gather_group`'s world size. However, it should be disabled for tensors that are not fully replicated across the group (e.g., the residual tensor when sequence parallelism is enabled). This dictionary allows overriding the default behavior on a per-tensor basis.
Parameters
tensor_dict (%{optional(String.t()) => term()})
dst (term(), default: None)
all_gather_group (term() | nil, default: None)
all_gather_tensors (term(), default: None)
Returns
term()