Nx.Serving behaviour (Nx v0.4.1)

Serving encapsulates client and server work to perform batched requests.

A serving can be executed on the fly, without starting a server, but most often it is used to run servers that batch requests until a given batch size or timeout is reached.


Inline/serverless workflow

First, let's define a simple numerical definition function:

defmodule MyDefn do
  import Nx.Defn

  defn print_and_multiply(x) do
    print_value({:debug, x})
    x * 2
  end
end

The function prints the given tensor and doubles its contents. We can use new/1 to create a serving from a function that returns a JITted or compiled function to execute on batches of tensors:

iex> serving = Nx.Serving.new(fn -> Nx.Defn.jit(&MyDefn.print_and_multiply/1) end)
iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
iex> Nx.Serving.run(serving, batch)
{:debug, #Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>}
#Nx.Tensor<
  s64[1][3]
  [
    [2, 4, 6]
  ]
>

You should see two values printed. The first is the result of Nx.Defn.Kernel.print_value/1, which shows the tensor that was actually part of the computation and how it was batched. Then we see the result of the computation.

When defining an Nx.Serving, we can also customize how the data is batched via the client_preprocessing hook, as well as the result via the client_postprocessing hook. Let's give it another try:

iex> serving = (
...>   Nx.Serving.new(fn -> Nx.Defn.jit(&MyDefn.print_and_multiply/1) end)
...>   |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end)
...>   |> Nx.Serving.client_postprocessing(&{&1, &2, &3})
...> )
iex> Nx.Serving.run(serving, [Nx.tensor([1, 2]), Nx.tensor([3, 4])])
{:debug, #Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>}
{#Nx.Tensor<
   s64[2][2]
   [
     [2, 4],
     [6, 8]
   ]
 >,
 :server_info,
 :client_info}

You can see the results are a bit different now. First of all, notice we were able to run the serving passing a list of tensors. Our custom client_preprocessing function stacks those tensors into a batch of two entries and returns a tuple with an Nx.Batch struct and additional client information, which we represent as the atom :client_info. The default client preprocessing simply enforces that a batch was given and returns no client information.

Then the result is a {..., ..., ...} tuple, returned by the client postprocessing function, containing the result, the server information (which we will later learn how to customize), and the client information. From this, we can infer that the default implementation of client_postprocessing simply returns the result, discarding the server and client information.
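To make those defaults concrete, here is a sketch that spells them out by hand. The hook bodies below are illustrative approximations of the behavior described above, not the library's literal internals:

```elixir
serving =
  Nx.Serving.new(fn -> Nx.Defn.jit(&MyDefn.print_and_multiply/1) end)
  # Approximate default preprocessing: require an Nx.Batch, no client info.
  |> Nx.Serving.client_preprocessing(fn %Nx.Batch{} = batch -> {batch, nil} end)
  # Approximate default postprocessing: keep the result, drop the rest.
  |> Nx.Serving.client_postprocessing(fn result, _server_info, _client_info -> result end)
```

Running this serving behaves like the plain new/1 serving from the first example: only the computed tensor is returned to the caller.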

So far, Nx.Serving has not given us much. It has simply encapsulated the execution of a function. Its full power comes when we start running our own Nx.Serving process. That's when we will also learn why we have a client_ prefix in some of the function names.


Stateful/process workflow

Nx.Serving allows us to define a process that will batch requests up to a certain size or within a certain time. To do so, we need to start a Nx.Serving process with a serving inside a supervision tree:

children = [
  {Nx.Serving,
   serving: Nx.Serving.new(fn -> Nx.Defn.jit(&MyDefn.print_and_multiply/1) end),
   name: MyServing,
   batch_size: 10,
   batch_timeout: 100}
]

Supervisor.start_link(children, strategy: :one_for_one)

Now you can send batched runs to said process:

iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])])
iex> Nx.Serving.batched_run(MyServing, batch)
{:debug, #Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>}
#Nx.Tensor<
  s64[2][3]
  [
    [2, 4, 6],
    [8, 10, 12]
  ]
>

In the example, we pushed a batch of 2 and eventually got a reply. The process waits for requests from other processes for up to 100 milliseconds or until it gets 10 entries. Then it merges all batches together and, once the result is computed, slices and distributes those responses to each caller.

If there is any client_preprocessing function, it will be executed before the batch is sent to the server. If there is any client_postprocessing function, it will be executed after getting the response from the server.
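For example, requests coming from independent processes are merged into a single batch on the server. A sketch, assuming the MyServing process from the supervision tree above is running:

```elixir
# Two concurrent callers, each submitting a one-entry batch.
# The server merges them (up to :batch_size entries, or after
# :batch_timeout milliseconds), runs the computation once, then
# slices the result and replies to each caller separately.
[Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])]
|> Task.async_stream(fn tensor ->
  Nx.Serving.batched_run(MyServing, Nx.Batch.stack([tensor]))
end)
|> Enum.map(fn {:ok, result} -> result end)
```

Each caller sees only the slice corresponding to its own entries, even though the computation ran over the merged batch.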


Module-based serving

In the examples so far, we have been using the default version of Nx.Serving, which executes the given function for each batch.

However, we can also use new/2 to start a module-based version of Nx.Serving, which acts similar to an Elixir GenServer and gives us more control over both inline and process workflows. A simple module implementation of a Nx.Serving could look like this:

defmodule MyServing do
  @behaviour Nx.Serving

  import Nx.Defn

  defnp print_and_multiply(x) do
    print_value({:debug, x})
    x * 2
  end

  @impl true
  def init(_inline_or_process, :unused_arg) do
    {:ok, Nx.Defn.jit(&print_and_multiply/1)}
  end

  @impl true
  def handle_batch(batch, function) do
    {:execute, fn -> {function.(batch), :server_info} end, function}
  end
end

It has two callbacks. The first is init/2, which receives the type of serving (:inline or :process) and the serving argument. In this step, we capture print_and_multiply/1 as a jitted function.

The second callback is handle_batch/2. This function receives an Nx.Batch and must return a function to execute. The function itself must return a two-element tuple: the batched results and some server information. The server information can be any value; here we set it to the atom :server_info.

Now let's give it a try by defining a serving with our module and then running it on a batch:

iex> serving = Nx.Serving.new(MyServing, :unused_arg)
iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
iex> Nx.Serving.run(serving, batch)
{:debug, #Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>}
#Nx.Tensor<
  s64[1][3]
  [
    [2, 4, 6]
  ]
>

From here on, you can use start_link/1 to start this serving in your supervision tree and even customize it with the client_preprocessing/2 and client_postprocessing/2 hooks, as seen in the previous sections.
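For instance, a supervision tree entry for the module-based serving above might look like the following sketch (MyModuleServing is an illustrative process name):

```elixir
children = [
  {Nx.Serving,
   serving:
     Nx.Serving.new(MyServing, :unused_arg)
     # Optional hook, as in the earlier inline example.
     |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end),
   name: MyModuleServing,
   batch_size: 10,
   batch_timeout: 100}
]

Supervisor.start_link(children, strategy: :one_for_one)
```

Callers then use Nx.Serving.batched_run(MyModuleServing, input) exactly as in the stateful workflow shown earlier.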

Summary

Callbacks

Receives a batch and returns a function to execute the batch.

The callback used to initialize the serving.

Functions

Runs the given input on the process given by name.

Sets the client postprocessing function.

Sets the client preprocessing function.

Creates a new function serving.

Creates a new module-based serving.

Runs serving with the given input inline with the current process.

Starts a Nx.Serving process to batch requests to a given serving.

Types

@type client_info() :: term()
@type client_postprocessing() ::
  (Nx.Container.t(), metadata(), client_info() -> term())
@type client_preprocessing() :: (term() -> {Nx.Batch.t(), client_info()})
@type metadata() :: term()
@type t() :: %Nx.Serving{
  arg: term(),
  client_postprocessing: client_postprocessing(),
  client_preprocessing: client_preprocessing(),
  module: atom(),
  process_options: term()
}

Callbacks

@callback handle_batch(Nx.Batch.t(), state) ::
  {:execute, (-> {Nx.Container.t(), metadata()}), state}
when state: term()

Receives a batch and returns a function to execute the batch.

In the case of serving processes, the function is executed in a separate process.

@callback init(type :: :inline | :process, arg :: term()) :: {:ok, state :: term()}

The callback used to initialize the serving.

The first argument reveals whether the serving is executed inline, such as by calling run/2, or whether it was started as a process with start_link/1. The second argument is the serving argument given to new/2.

It must return {:ok, state}, where the state can be any term.

Functions

batched_run(name, input)

Runs the given input on the process given by name.

The process name will batch requests and send a response either when the batch is full or on timeout. See the module documentation for more information.

Note you cannot batch an input larger than the configured :batch_size in the server.

client_postprocessing(serving, function)

Sets the client postprocessing function.

The default implementation returns the first element given to the function.

client_preprocessing(serving, function)

Sets the client preprocessing function.

The default implementation creates a single element batch with the given argument and is equivalent to &Nx.Batch.stack([&1]).
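Written out as an explicit hook, the default is roughly equivalent to the following sketch (illustrative; note the hook must also return the client information, here nil):

```elixir
Nx.Serving.client_preprocessing(serving, fn input ->
  # Wrap the single argument in a one-entry batch; no client info.
  {Nx.Batch.stack([input]), nil}
end)
```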

new(function, process_options \\ [])

Creates a new function serving.

It expects a function that returns a JITted (via Nx.Defn.jit/2) or compiled (via Nx.Defn.compile/3) one-arity function as argument. The JITted/compiled function will be called with the batch returned by the client_preprocessing callback.

An optional second argument, process_options, can be given to customize the options when starting the serving under a process.

new(module, arg, process_options)

Creates a new module-based serving.

It expects a module and an argument that is given to its init callback.

An optional third argument, process_options, can be given to customize the options when starting the serving under a process.

run(serving, input)

Runs serving with the given input inline with the current process.

start_link(opts)

Starts a Nx.Serving process to batch requests to a given serving.


Options

  • :name - an atom with the name of the process

  • :serving - a Nx.Serving struct with the serving configuration

  • :batch_size - the maximum size to batch for. A value is first read from the Nx.Serving struct and then it falls back to this option (which defaults to 1)

  • :batch_timeout - the maximum time to wait, in milliseconds, before executing the batch. A value is first read from the Nx.Serving struct and then it falls back to this option (which defaults to 100ms)