Vllm.Distributed.DeviceCommunicatorBase (vLLM v0.3.0)


Base class for device-specific communicators.

It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (i.e. PyTorch recognizes its communication backend), the device_group is also provided.

Summary

Functions

Python method DeviceCommunicatorBase.all_gather.

Python method DeviceCommunicatorBase.all_gatherv.

Python method DeviceCommunicatorBase.all_reduce.

Combine the hidden states and router logits from the appropriate device.

Python method DeviceCommunicatorBase.destroy.

Dispatch the hidden states and router logits to the appropriate device.

NOTE: We assume that the input tensor is on the same device across all the ranks.

Initialize self. See help(type(self)) for accurate signature.

Prepare the communication buffer for the model.

Receives a tensor from the source rank.

Python method DeviceCommunicatorBase.reduce_scatter.

Python method DeviceCommunicatorBase.reduce_scatterv.

Sends a tensor to the destination rank in a blocking way.

Types

t()

@opaque t()

Functions

all_gather(ref, input_, args, opts \\ [])

@spec all_gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.all_gather.

Parameters

  • input_ (term())
  • dim (integer() default: -1)

Returns

  • term()
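As a rough illustration of the collective's semantics (not the SnakeBridge or vLLM API): all-gather concatenates each rank's input tensor along `dim`, and every rank receives the full result. The sketch below simulates this with plain Python lists along dimension 0; `all_gather_sim` and its list-based "tensors" are hypothetical stand-ins.

```python
def all_gather_sim(per_rank_inputs):
    """Every rank receives the concatenation of all ranks' inputs (dim=0)."""
    gathered = []
    for chunk in per_rank_inputs:
        gathered.extend(chunk)
    # One independent copy of the full result per rank.
    return [list(gathered) for _ in per_rank_inputs]

# Ranks 0..2 each hold two elements; after all-gather each holds all six.
outputs = all_gather_sim([[0, 1], [2, 3], [4, 5]])
```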

all_gatherv(ref, input_, args, opts \\ [])

@spec all_gatherv(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.all_gatherv.

Parameters

  • input_ (term())
  • dim (integer() default: 0)
  • sizes (term() default: None)

Returns

  • term()
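Illustratively (again a hypothetical sketch, not the bridge API), the `v` variant is the variable-size form of all-gather: each rank may contribute a chunk of a different length, with `sizes` listing the per-rank lengths.

```python
def all_gatherv_sim(per_rank_inputs, sizes=None):
    """Concatenate unequal chunks from all ranks; sizes=None infers lengths."""
    if sizes is None:
        sizes = [len(chunk) for chunk in per_rank_inputs]
    assert sizes == [len(chunk) for chunk in per_rank_inputs]
    gathered = []
    for chunk in per_rank_inputs:
        gathered.extend(chunk)
    return gathered

# Rank 0 contributes three elements, rank 1 just one.
out = all_gatherv_sim([[1, 2, 3], [4]], sizes=[3, 1])
```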

all_reduce(ref, input_, opts \\ [])

@spec all_reduce(SnakeBridge.Ref.t(), term(), keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.all_reduce.

Parameters

  • input_ (term())

Returns

  • term()
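For intuition, an all-reduce combines same-shaped tensors element-wise across ranks (summation here), and every rank receives the same reduced result. The helper below is a hypothetical simulation with Python lists, not the bridge API.

```python
def all_reduce_sim(per_rank_inputs):
    """Element-wise sum across ranks; every rank gets the reduced tensor."""
    reduced = [sum(vals) for vals in zip(*per_rank_inputs)]
    return [list(reduced) for _ in per_rank_inputs]

# Three ranks, each holding a 2-element tensor.
outputs = all_reduce_sim([[1, 2], [3, 4], [5, 6]])
```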

combine(ref, hidden_states, args, opts \\ [])

@spec combine(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Combine the hidden states and router logits from the appropriate device.

This is a no-op in the base class.

Parameters

  • hidden_states (term())
  • is_sequence_parallel (boolean() default: False)

Returns

  • term()

destroy(ref, opts \\ [])

@spec destroy(
  SnakeBridge.Ref.t(),
  keyword()
) :: {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.destroy.

Returns

  • term()

dispatch(ref, hidden_states, router_logits, args, opts \\ [])

@spec dispatch(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) ::
  {:ok, {term(), term()}} | {:error, Snakepit.Error.t()}

Dispatch the hidden states and router logits to the appropriate device.

This is a no-op in the base class.

Parameters

  • hidden_states (term())
  • router_logits (term())
  • is_sequence_parallel (boolean() default: False)
  • extra_tensors (term() default: None)

Returns

  • {term(), term()}

gather(ref, input_, args, opts \\ [])

@spec gather(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

NOTE: We assume that the input tensor is on the same device across all the ranks.

NOTE: dst is the local rank of the destination rank.

Parameters

  • input_ (term())
  • dst (integer() default: 0)
  • dim (integer() default: -1)

Returns

  • term()
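Unlike all-gather, a plain gather delivers the concatenated result only to the destination rank. The following is an illustrative simulation (hypothetical helper, list-based "tensors", not the bridge API), with `None` standing in for "no result on this rank".

```python
def gather_sim(per_rank_inputs, dst=0):
    """Only rank `dst` receives the concatenation; other ranks get None."""
    gathered = []
    for chunk in per_rank_inputs:
        gathered.extend(chunk)
    return [gathered if rank == dst else None
            for rank in range(len(per_rank_inputs))]

# Two ranks; only rank 1 (the destination) ends up with the full tensor.
outputs = gather_sim([[0, 1], [2, 3]], dst=1)
```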

new(cpu_group, args, opts \\ [])

@spec new(term(), [term()], keyword()) ::
  {:ok, SnakeBridge.Ref.t()} | {:error, Snakepit.Error.t()}

Initializes the communicator from the given cpu_group; device, device_group, and unique_name are optional.

Parameters

  • cpu_group (term())
  • device (term() default: None)
  • device_group (term() default: None)
  • unique_name (String.t() default: '')

prepare_communication_buffer_for_model(ref, model, opts \\ [])

@spec prepare_communication_buffer_for_model(SnakeBridge.Ref.t(), term(), keyword()) ::
  {:ok, nil} | {:error, Snakepit.Error.t()}

Prepare the communication buffer for the model.

Parameters

  • model (term())

Returns

  • nil

recv(ref, size, dtype, args, opts \\ [])

@spec recv(SnakeBridge.Ref.t(), term(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Receives a tensor from the source rank.

Parameters

  • size (term())
  • dtype (term())
  • src (term() default: None)

Returns

  • term()

reduce_scatter(ref, input_, args, opts \\ [])

@spec reduce_scatter(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.reduce_scatter.

Parameters

  • input_ (term())
  • dim (integer() default: -1)

Returns

  • term()
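Conceptually, reduce-scatter is the inverse pairing of all-gather: inputs are first reduced element-wise across ranks, then the reduced tensor is split into equal chunks and chunk *i* is delivered to rank *i*. A hypothetical list-based simulation (not the bridge API):

```python
def reduce_scatter_sim(per_rank_inputs):
    """Sum across ranks, then scatter equal chunks: rank r gets chunk r."""
    world_size = len(per_rank_inputs)
    reduced = [sum(vals) for vals in zip(*per_rank_inputs)]
    chunk = len(reduced) // world_size
    return [reduced[r * chunk:(r + 1) * chunk] for r in range(world_size)]

# Two ranks, 4-element inputs: reduced = [11, 22, 33, 44], split in halves.
outputs = reduce_scatter_sim([[1, 2, 3, 4], [10, 20, 30, 40]])
```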

reduce_scatterv(ref, input_, args, opts \\ [])

@spec reduce_scatterv(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, term()} | {:error, Snakepit.Error.t()}

Python method DeviceCommunicatorBase.reduce_scatterv.

Parameters

  • input_ (term())
  • dim (integer() default: -1)
  • sizes (term() default: None)

Returns

  • term()

send(ref, tensor, args, opts \\ [])

@spec send(SnakeBridge.Ref.t(), term(), [term()], keyword()) ::
  {:ok, nil} | {:error, Snakepit.Error.t()}

Sends a tensor to the destination rank in a blocking way.

Parameters

  • tensor (term())
  • dst (term() default: None)

Returns

  • nil
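The blocking send/recv pair can be pictured as a bounded channel between two ranks: `send` blocks while the channel is full, `recv` blocks until a tensor arrives. The sketch below simulates the two ranks with threads and a `queue.Queue`; `send_sim`/`recv_sim` are hypothetical stand-ins, not the bridge API.

```python
import queue
import threading

# Capacity-1 channel between a "sender rank" and a "receiver rank".
channel = queue.Queue(maxsize=1)

def send_sim(tensor):
    channel.put(tensor)      # blocks while the channel is full

def recv_sim():
    return channel.get()     # blocks until a tensor arrives

received = {}
receiver = threading.Thread(target=lambda: received.update(value=recv_sim()))
receiver.start()
send_sim([1.0, 2.0, 3.0])
receiver.join()
```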