Base class for device-specific communicators.
It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (that is, PyTorch recognizes its communication backend), the device_group is also provided.
Summary
Functions
all_gather/4: Python method DeviceCommunicatorBase.all_gather.
all_gatherv/4: Python method DeviceCommunicatorBase.all_gatherv.
all_reduce/3: Python method DeviceCommunicatorBase.all_reduce.
combine/4: Combine the hidden states and router logits from the appropriate device.
destroy/2: Python method DeviceCommunicatorBase.destroy.
dispatch/5: Dispatch the hidden states and router logits to the appropriate device.
gather/4: NOTE: We assume that the input tensor is on the same device across all the ranks.
new/3: Initialize self. See help(type(self)) for accurate signature.
prepare_communication_buffer_for_model/3: Prepare the communication buffer for the model.
recv/5: Receives a tensor from the source rank.
reduce_scatter/4: Python method DeviceCommunicatorBase.reduce_scatter.
reduce_scatterv/4: Python method DeviceCommunicatorBase.reduce_scatterv.
send/4: Sends a tensor to the destination rank in a blocking way.
Functions
@spec all_gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.all_gather.
Parameters
input_ (term())
dim (integer(), default: -1)
Returns
term()
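The wrapped class is Python, and this entry has no description beyond its name. The sketch below shows what an all_gather along `dim` is generally expected to produce. It is a minimal simulation under these assumptions:

- Plain Python lists stand in for 2-D tensors.
- One function call sees every rank's input at once. The real method runs on each rank with only that rank's input_ over the device_group.

The name `simulate_all_gather` is hypothetical and is not part of the wrapper API.

```python
from typing import List

# Hypothetical stand-in for a 2-D tensor: a list of rows.
Matrix = List[List[float]]


def simulate_all_gather(per_rank: List[Matrix], dim: int = -1) -> Matrix:
    """Return what every rank would receive from an all_gather along `dim`.

    per_rank[r] is rank r's input; all inputs are assumed to share one shape.
    """
    if dim < 0:
        dim += 2  # normalize a negative dim for 2-D inputs (-1 means the last dim)
    if dim == 0:
        # Stack each rank's rows, in rank order.
        return [list(row) for shard in per_rank for row in shard]
    if dim == 1:
        # Row i of the result joins row i of every rank's shard, in rank order.
        n_rows = len(per_rank[0])
        return [sum((list(shard[i]) for shard in per_rank), []) for i in range(n_rows)]
    raise ValueError("this sketch only handles 2-D inputs")
```

With two ranks holding `[[1, 2], [3, 4]]` and `[[5, 6], [7, 8]]`, the default `dim=-1` gives every rank `[[1, 2, 5, 6], [3, 4, 7, 8]]`.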
@spec all_gatherv(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.all_gatherv.
Parameters
input_ (term())
dim (integer(), default: 0)
sizes (term(), default: None)
Returns
term()
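The "v" variant is generally understood as an all_gather whose per-rank shards may differ in length along `dim`. Here, `sizes` is assumed to list each rank's length. The Python sketch below only covers the default `dim=0` and uses lists of rows in place of tensors. The name `simulate_all_gatherv` is hypothetical, and the size check is an illustrative assumption rather than documented behavior of this method.

```python
from typing import List, Optional, Sequence

# Hypothetical stand-in for a 2-D tensor: a list of rows.
Rows = List[List[float]]


def simulate_all_gatherv(per_rank: List[Rows], sizes: Optional[Sequence[int]] = None) -> Rows:
    """Concatenate shards of possibly different row counts along dim 0.

    sizes[r], if given, is the number of rows expected from rank r (assumed meaning).
    """
    if sizes is not None:
        if len(sizes) != len(per_rank):
            raise ValueError("expected one size per rank")
        for rank, (shard, size) in enumerate(zip(per_rank, sizes)):
            if len(shard) != size:
                raise ValueError(f"rank {rank} has {len(shard)} rows, expected {size}")
    # Unlike all_gather, shards need not have equal length; order follows rank.
    return [list(row) for shard in per_rank for row in shard]
```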
@spec all_reduce(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.all_reduce.
Parameters
input_ (term())
Returns
term()
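An all_reduce leaves every rank holding the same combined result. The Python sketch below assumes a sum reduction, which is the default op of PyTorch's `torch.distributed.all_reduce`. It uses 1-D lists in place of tensors and returns one result per rank to make that symmetry visible. The name `simulate_all_reduce` is hypothetical.

```python
from typing import List


def simulate_all_reduce(per_rank: List[List[float]]) -> List[List[float]]:
    """Return each rank's output of a sum all_reduce over equal-length 1-D inputs."""
    if len({len(x) for x in per_rank}) != 1:
        raise ValueError("all ranks must contribute inputs of the same shape")
    # Elementwise sum across ranks.
    total = [sum(vals) for vals in zip(*per_rank)]
    # Every rank receives an identical copy of the reduced tensor.
    return [list(total) for _ in per_rank]
```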
@spec combine(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
Parameters
hidden_states (term())
is_sequence_parallel (boolean(), default: False)
Returns
term()
@spec destroy(SnakeBridge.Ref.t(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.destroy.
Returns
term()
@spec dispatch(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) :: {:ok, {term(), term()}} | {:error, Snakepit.Error.t()}
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
Parameters
hidden_states (term())
router_logits (term())
is_sequence_parallel (boolean(), default: False)
extra_tensors (term(), default: None)
Returns
{term(), term()}
@spec gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
NOTE: We assume that the input tensor is on the same device across all the ranks.
NOTE: dst is the local rank of the destination rank.
Parameters
input_ (term())
dst (integer(), default: 0)
dim (integer(), default: -1)
Returns
term()
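Unlike all_gather, a gather delivers the concatenated result only to the destination rank. Per the NOTE above, `dst` is a local rank within the group. The Python sketch below models the common convention that non-destination ranks get nothing back, shown here as None. It uses 1-D lists, so concatenating along `dim=-1` is plain list concatenation. The name `simulate_gather` is hypothetical, and the None-for-other-ranks behavior is an assumption for illustration.

```python
from typing import List, Optional


def simulate_gather(per_rank: List[List[float]], dst: int = 0) -> List[Optional[List[float]]]:
    """Return each rank's output of a gather of equal-length 1-D inputs to local rank `dst`."""
    world_size = len(per_rank)
    if not 0 <= dst < world_size:
        raise ValueError("dst must be a local rank in [0, world_size)")
    if len({len(x) for x in per_rank}) != 1:
        raise ValueError("gather expects same-shaped inputs from every rank")
    # Concatenate in rank order; only the destination keeps the result.
    gathered = [x for shard in per_rank for x in shard]
    return [gathered if rank == dst else None for rank in range(world_size)]
```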
@spec new(term(), [term()], keyword()) :: {:ok, SnakeBridge.Ref.t()} | {:error, Snakepit.Error.t()}
Initialize self. See help(type(self)) for accurate signature.
Parameters
cpu_group (term())
device (term(), default: None)
device_group (term(), default: None)
unique_name (String.t(), default: '')
@spec prepare_communication_buffer_for_model(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, nil} | {:error, Snakepit.Error.t()}
Prepare the communication buffer for the model.
Parameters
model (term())
Returns
nil
@spec recv(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Receives a tensor from the source rank.
Parameters
size (term())
dtype (term())
src (term(), default: None)
Returns
term()
@spec reduce_scatter(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.reduce_scatter.
Parameters
input_ (term())
dim (integer(), default: -1)
Returns
term()
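A reduce_scatter combines two steps. First it reduces (here assumed to sum) every rank's input, then it splits the result along `dim` into world_size equal chunks, so that rank r keeps chunk r. The Python sketch below shows this for 1-D lists, where `dim=-1` is the only axis. The name `simulate_reduce_scatter` and the sum reduction are assumptions for illustration.

```python
from typing import List


def simulate_reduce_scatter(per_rank: List[List[float]]) -> List[List[float]]:
    """Return each rank's chunk of the elementwise sum of all 1-D inputs."""
    world_size = len(per_rank)
    if len({len(x) for x in per_rank}) != 1:
        raise ValueError("all ranks must contribute inputs of the same shape")
    total = [sum(vals) for vals in zip(*per_rank)]  # the reduce step
    if len(total) % world_size:
        raise ValueError("length along dim must be divisible by world_size")
    chunk = len(total) // world_size
    # The scatter step: rank r receives the r-th contiguous chunk.
    return [total[r * chunk:(r + 1) * chunk] for r in range(world_size)]
```

For two ranks holding `[1, 2, 3, 4]` and `[10, 20, 30, 40]`, the sum is `[11, 22, 33, 44]`. Rank 0 keeps `[11, 22]` and rank 1 keeps `[33, 44]`.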
@spec reduce_scatterv(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method DeviceCommunicatorBase.reduce_scatterv.
Parameters
input_ (term())
dim (integer(), default: -1)
sizes (term(), default: None)
Returns
term()
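The "v" variant lets ranks keep chunks of different lengths. The Python sketch below assumes that `sizes[r]` is the length of rank r's chunk along `dim` and that the sizes sum to the full length. It handles 1-D lists with a sum reduction. This sketch requires `sizes`; how the real method treats `sizes=None` is not stated here. The name `simulate_reduce_scatterv` is hypothetical.

```python
from typing import List, Sequence


def simulate_reduce_scatterv(per_rank: List[List[float]], sizes: Sequence[int]) -> List[List[float]]:
    """Sum 1-D inputs elementwise, then give rank r a chunk of length sizes[r]."""
    if len({len(x) for x in per_rank}) != 1:
        raise ValueError("all ranks must contribute inputs of the same shape")
    total = [sum(vals) for vals in zip(*per_rank)]
    if len(sizes) != len(per_rank) or sum(sizes) != len(total):
        raise ValueError("sizes must have one entry per rank and cover the whole dim")
    out, start = [], 0
    for size in sizes:
        # Consecutive, possibly uneven, slices in rank order.
        out.append(total[start:start + size])
        start += size
    return out
```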
@spec send(SnakeBridge.Ref.t(), term(), [term()], keyword()) :: {:ok, nil} | {:error, Snakepit.Error.t()}
Sends a tensor to the destination rank in a blocking way.
Parameters
tensor (term())
dst (term(), default: None)
Returns
nil
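Both send and recv accept an optional peer (`dst` or `src`, default None). The Python sketch below assumes a ring convention for those defaults: send targets the next local rank, and recv listens to the previous one. Under that assumption, a default send on every rank is matched by a default recv on its neighbor. The helpers `default_dst`, `default_src`, and `simulate_ring_shift` are hypothetical, and the ring defaults are an assumption, not something this reference states.

```python
from typing import List, Optional


def default_dst(rank_in_group: int, world_size: int) -> int:
    """Assumed default for send: the next local rank, wrapping around."""
    return (rank_in_group + 1) % world_size


def default_src(rank_in_group: int, world_size: int) -> int:
    """Assumed default for recv: the previous local rank, wrapping around."""
    return (rank_in_group - 1) % world_size


def simulate_ring_shift(per_rank: List[List[float]]) -> List[Optional[List[float]]]:
    """Every rank sends to default_dst; return what each rank receives."""
    world_size = len(per_rank)
    received: List[Optional[List[float]]] = [None] * world_size
    for rank, tensor in enumerate(per_rank):
        received[default_dst(rank, world_size)] = tensor
    # Each rank's message came from exactly its default source.
    for rank in range(world_size):
        assert received[rank] is per_rank[default_src(rank, world_size)]
    return received
```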