Nx.Serving behaviour (Nx v0.4.1)
Serving encapsulates client and server work to perform batched requests.
Serving can be executed on the fly, without starting a server, but most often it is used to run servers that batch requests until a given size or timeout is reached.
Inline/serverless workflow
First, let's define a simple numerical definition function:
```elixir
defmodule MyDefn do
  import Nx.Defn

  defnp print_and_multiply(x) do
    print_value({:debug, x})
    x * 2
  end
end
```
The function prints the given tensor and doubles its contents. We can use new/1 to create a serving that will return a JITted or compiled function to execute on batches of tensors:
```elixir
iex> serving = Nx.Serving.new(fn -> Nx.Defn.jit(&print_and_multiply/1) end)
iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
iex> Nx.Serving.run(serving, batch)
{:debug, #Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>}
#Nx.Tensor<
  s64[1][3]
  [
    [2, 4, 6]
  ]
>
```
You should see two values printed. The first is the result of Nx.Defn.Kernel.print_value/1, which shows the tensor that was actually part of the computation and how it was batched. Then we see the result of the computation.

When defining a Nx.Serving, we can also customize how the data is batched by using the client_preprocessing hook, and how the result is returned by using the client_postprocessing hook. Let's give it another try:
```elixir
iex> serving = (
...>   Nx.Serving.new(fn -> Nx.Defn.jit(&print_and_multiply/1) end)
...>   |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end)
...>   |> Nx.Serving.client_postprocessing(&{&1, &2, &3})
...> )
iex> Nx.Serving.run(serving, [Nx.tensor([1, 2]), Nx.tensor([3, 4])])
{:debug, #Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>}
{#Nx.Tensor<
   s64[2][2]
   [
     [2, 4],
     [6, 8]
   ]
 >,
 :server_info,
 :client_info}
```
You can see the results are a bit different now. First of all, notice we were able to run the serving passing a list of tensors. Our custom client_preprocessing function stacks those tensors into a batch of two entries and returns a tuple with a Nx.Batch struct and additional client information, which we represent as the atom :client_info. The default client preprocessing simply enforces that a batch was given and returns no client information.

Then the result is a {..., ..., ...} tuple, returned by the client postprocessing function, containing the result, the server information (we will learn how to customize it later), and the client information. From this, we can infer that the default implementation of client_postprocessing simply returns the result, discarding the server and client information.
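As a further illustration of the two hooks, here is a minimal sketch in which the client information carries the number of inputs and the postprocessing step turns the batched tensor into plain lists. It assumes Nx ~> 0.4.1 is available through Mix.install/1; the Doubler module is a hypothetical stand-in for MyDefn without the print_value/1 call.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

serving =
  Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end)
  # Stack the list of input tensors into one batch; the client information
  # is the number of inputs, which is handed to the postprocessing step.
  |> Nx.Serving.client_preprocessing(fn tensors ->
    {Nx.Batch.stack(tensors), length(tensors)}
  end)
  # Ignore the server information and return plain Elixir lists.
  |> Nx.Serving.client_postprocessing(fn result, _server_info, count ->
    %{inputs: count, rows: Nx.to_list(result)}
  end)

output = Nx.Serving.run(serving, [Nx.tensor([1, 2]), Nx.tensor([3, 4])])
IO.inspect(output)
```

Because the postprocessing function receives the client information as its third argument, anything computed before batching (here, a count) is available again once the result comes back.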
So far, Nx.Serving has not given us much. It has simply encapsulated the execution of a function. Its full power comes when we start running our own Nx.Serving process. That's when we will also learn why we have a client_ prefix in some of the function names.
Stateful/process workflow
Nx.Serving allows us to define a process that will batch requests up to a certain size or within a certain time. To do so, we need to start a Nx.Serving process with a serving inside a supervision tree:
```elixir
children = [
  {Nx.Serving,
   serving: Nx.Serving.new(fn -> Nx.Defn.jit(&print_and_multiply/1) end),
   name: MyServing,
   batch_size: 10,
   batch_timeout: 100}
]

Supervisor.start_link(children, strategy: :one_for_one)
```
Now you can send batched runs to said process:
```elixir
iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])])
iex> Nx.Serving.batched_run(MyServing, batch)
{:debug, #Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>}
#Nx.Tensor<
  s64[2][3]
  [
    [2, 4, 6],
    [8, 10, 12]
  ]
>
```
In the example, we pushed a batch of 2 and eventually got a reply. The process will wait for requests from other processes, for up to 100 milliseconds or until it gets 10 entries. Then it merges all batches together and, once the result is computed, it slices and distributes those responses to each caller.

If there is any client_preprocessing function, it will be executed before the batch is sent to the server. If there is any client_postprocessing function, it will be executed after getting the response from the server.
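To see requests from several callers being batched together, the sketch below starts a serving under a supervisor and calls batched_run/2 from three tasks, each sending a single-entry batch. It assumes Nx ~> 0.4.1 via Mix.install/1; Doubler and DoublerServing are hypothetical names. Whether the three requests end up in one merged batch depends on timing, but each caller should get back only its own slice.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

children = [
  {Nx.Serving,
   serving: Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end),
   name: DoublerServing,
   batch_size: 10,
   batch_timeout: 100}
]

{:ok, _supervisor} = Supervisor.start_link(children, strategy: :one_for_one)

# Three independent callers, each sending a batch with one entry.
# Requests arriving within the 100 ms window may be merged by the server.
results =
  1..3
  |> Enum.map(fn i ->
    Task.async(fn ->
      batch = Nx.Batch.stack([Nx.tensor([i, i * 10])])
      Nx.Serving.batched_run(DoublerServing, batch)
    end)
  end)
  |> Task.await_many()
  # Each caller receives a 1x2 slice corresponding to its own entry.
  |> Enum.map(&Nx.to_list/1)

IO.inspect(results)
```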
Module-based serving
In the examples so far, we have been using the default version of Nx.Serving, which executes the given function for each batch. However, we can also use new/2 to start a module-based version of Nx.Serving, which acts similarly to an Elixir GenServer and gives us more control over both inline and process workflows. A simple module implementation of a Nx.Serving could look like this:
```elixir
defmodule MyServing do
  @behaviour Nx.Serving
  # Required for defnp and print_value/1.
  import Nx.Defn

  defnp print_and_multiply(x) do
    print_value({:debug, x})
    x * 2
  end

  @impl true
  def init(_inline_or_process, :unused_arg) do
    {:ok, Nx.Defn.jit(&print_and_multiply/1)}
  end

  @impl true
  def handle_batch(batch, function) do
    {:execute, fn -> {function.(batch), :server_info} end, function}
  end
end
```
It has two functions. The first, init/2, receives the type of serving (:inline or :process) and the serving argument. In this step, we capture print_and_multiply/1 as a jitted function.

The second function is called handle_batch/2. This function receives a Nx.Batch and it must return a function to execute, wrapped in an {:execute, function, state} tuple. The function itself must return a two-element tuple: the batched results and some server information. The server information can be any value and we set it to the atom :server_info.
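To show that both the state and the server information are free-form, here is a sketch of a module whose init/2 keeps the serving type in its state and whose handle_batch/2 reports it, together with the number of computed rows, as server information. It assumes Nx ~> 0.4.1 via Mix.install/1; RowCountServing is a hypothetical name.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule RowCountServing do
  @behaviour Nx.Serving
  import Nx.Defn

  defnp double(x) do
    x * 2
  end

  @impl true
  def init(type, :unused_arg) do
    # type is :inline or :process; keep it next to the jitted function.
    {:ok, {type, Nx.Defn.jit(&double/1)}}
  end

  @impl true
  def handle_batch(batch, {type, function} = state) do
    execute = fn ->
      result = function.(batch)
      # Server information: how the serving runs and how many rows were computed.
      {result, %{type: type, rows: Nx.axis_size(result, 0)}}
    end

    {:execute, execute, state}
  end
end

serving =
  Nx.Serving.new(RowCountServing, :unused_arg)
  # Return the result as a list together with the server information.
  |> Nx.Serving.client_postprocessing(fn result, server_info, _client_info ->
    {Nx.to_list(result), server_info}
  end)

batch = Nx.Batch.stack([Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])])
{rows, info} = Nx.Serving.run(serving, batch)
IO.inspect({rows, info})
```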
Now let's give it a try by defining a serving with our module and then running it on a batch:
```elixir
iex> serving = Nx.Serving.new(MyServing, :unused_arg)
iex> batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
iex> Nx.Serving.run(serving, batch)
{:debug, #Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>}
#Nx.Tensor<
  s64[1][3]
  [
    [2, 4, 6]
  ]
>
```
From here on, you can use start_link/1 to start this serving in your supervision tree and even customize its client preprocessing and postprocessing with client_preprocessing/2 and client_postprocessing/2, as seen in the previous sections.
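Putting the pieces together, the sketch below starts a module-based serving under a supervisor with custom client hooks and calls it with batched_run/2. It assumes Nx ~> 0.4.1 via Mix.install/1; DoubleServing and DoubleServingProcess are hypothetical names, and the module mirrors MyServing without the print_value/1 call.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule DoubleServing do
  @behaviour Nx.Serving
  import Nx.Defn

  defnp double(x) do
    x * 2
  end

  @impl true
  def init(_inline_or_process, :unused_arg) do
    {:ok, Nx.Defn.jit(&double/1)}
  end

  @impl true
  def handle_batch(batch, function) do
    {:execute, fn -> {function.(batch), :server_info} end, function}
  end
end

serving =
  Nx.Serving.new(DoubleServing, :unused_arg)
  # Client hooks run in the caller, before and after talking to the server.
  |> Nx.Serving.client_preprocessing(fn tensors -> {Nx.Batch.stack(tensors), :client_info} end)
  |> Nx.Serving.client_postprocessing(fn result, _server_info, _client_info ->
    Nx.to_list(result)
  end)

children = [
  {Nx.Serving, serving: serving, name: DoubleServingProcess, batch_size: 10, batch_timeout: 100}
]

{:ok, _supervisor} = Supervisor.start_link(children, strategy: :one_for_one)

rows = Nx.Serving.batched_run(DoubleServingProcess, [Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])])
IO.inspect(rows)
```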
Summary
Callbacks
handle_batch/2 - Receives a batch and returns a function to execute the batch.
init/2 - The callback used to initialize the serving.
Functions
batched_run/2 - Runs the given input on the process given by name.
client_postprocessing/2 - Sets the client postprocessing function.
client_preprocessing/2 - Sets the client preprocessing function.
new/2 - Creates a new function serving.
new/3 - Creates a new module-based serving.
run/2 - Runs serving with the given input inline with the current process.
start_link/1 - Starts a Nx.Serving process to batch requests to a given serving.
Types
@type client_info() :: term()
@type client_postprocessing() :: (Nx.Container.t(), metadata(), client_info() -> term())
@type client_preprocessing() :: (term() -> {Nx.Batch.t(), client_info()})
@type metadata() :: term()
@type t() :: %Nx.Serving{ arg: term(), client_postprocessing: client_postprocessing(), client_preprocessing: client_preprocessing(), module: atom(), process_options: term() }
Callbacks
@callback handle_batch(Nx.Batch.t(), state) :: {:execute, (-> {Nx.Container.t(), metadata()}), state} when state: term()
Receives a batch and returns a function to execute the batch.
In case of serving processes, the function is executed in a separate process.
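The state returned in the {:execute, function, state} tuple is what the next handle_batch/2 call receives. The sketch below counts handled batches in that state and reports the count as server information. It assumes Nx ~> 0.4.1 via Mix.install/1, and assumes that a serving process keeps the returned state between batches and passes the server information to client_postprocessing; CountingServing and CountingProcess are hypothetical names.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule CountingServing do
  @behaviour Nx.Serving
  import Nx.Defn

  defnp double(x) do
    x * 2
  end

  @impl true
  def init(_inline_or_process, :unused_arg) do
    # State: the jitted function plus how many batches were handled so far.
    {:ok, {Nx.Defn.jit(&double/1), 0}}
  end

  @impl true
  def handle_batch(batch, {function, count}) do
    count = count + 1
    # The returned function may run in another process, so it only
    # captures plain values (function, batch and count).
    {:execute, fn -> {function.(batch), count} end, {function, count}}
  end
end

serving =
  Nx.Serving.new(CountingServing, :unused_arg)
  # Return only the server information (the batch counter) to the caller.
  |> Nx.Serving.client_postprocessing(fn _result, batch_number, _client_info -> batch_number end)

{:ok, _supervisor} =
  Supervisor.start_link(
    [{Nx.Serving, serving: serving, name: CountingProcess, batch_size: 1, batch_timeout: 50}],
    strategy: :one_for_one
  )

batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])

# Sequential calls land in separate batches, so the counter keeps increasing.
numbers = for _ <- 1..3, do: Nx.Serving.batched_run(CountingProcess, batch)
IO.inspect(numbers)
```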
init/2
The callback used to initialize the serving.
The first argument reveals if the serving is executed inline (:inline), such as by calling run/2, or started as a process (:process).
The second argument is the serving argument given to new/2.
It must return {:ok, state}, where the state can be any term.
Functions
batched_run/2
Runs the given input on the process given by name.
The process name will batch requests and send a response either when the batch is full or on timeout. See the module documentation for more information.
Note you cannot batch an input larger than the configured :batch_size in the server.
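When there are more inputs than the configured :batch_size, one option is to split them on the client side. The sketch below chunks the inputs so that no single call exceeds the limit and concatenates the per-chunk results. It assumes Nx ~> 0.4.1 via Mix.install/1; Doubler and SmallBatchServing are hypothetical names, and client-side chunking is one possible approach rather than a built-in feature.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

batch_size = 2

children = [
  {Nx.Serving,
   serving: Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end),
   name: SmallBatchServing,
   batch_size: batch_size,
   batch_timeout: 50}
]

{:ok, _supervisor} = Supervisor.start_link(children, strategy: :one_for_one)

inputs = Enum.map(1..5, fn i -> Nx.tensor([i]) end)

# Split the inputs so that no call exceeds :batch_size,
# then join the per-chunk results along the batch axis.
result =
  inputs
  |> Enum.chunk_every(batch_size)
  |> Enum.map(fn chunk -> Nx.Serving.batched_run(SmallBatchServing, Nx.Batch.stack(chunk)) end)
  |> Nx.concatenate()

IO.inspect(Nx.to_list(result))
```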
client_postprocessing/2
Sets the client postprocessing function.
The default implementation returns the first element given to the function.
client_preprocessing/2
Sets the client preprocessing function.
The default implementation simply enforces that a Nx.Batch was given and returns no client information.
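If callers should pass a single tensor instead of a batch, a custom preprocessing function can wrap it in a one-entry batch, and a postprocessing function can drop the batch axis again. The following is a sketch assuming Nx ~> 0.4.1 via Mix.install/1; Doubler is a hypothetical module.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

serving =
  Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end)
  # Wrap a single tensor into a batch with one entry.
  |> Nx.Serving.client_preprocessing(fn tensor -> {Nx.Batch.stack([tensor]), :client_info} end)
  # Remove the leading batch axis so the result has the input's shape.
  |> Nx.Serving.client_postprocessing(fn result, _server_info, _client_info ->
    Nx.squeeze(result, axes: [0])
  end)

result = Nx.Serving.run(serving, Nx.tensor([1, 2, 3]))
IO.inspect(result)
```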
new/2
Creates a new function serving.
It expects as argument a function that returns a JITted (via Nx.Defn.jit/2) or compiled (via Nx.Defn.compile/3) one-arity function. The JITted/compiled function will be called with the batch returned by the client_preprocessing function.
A second argument called process_options, which is optional, can be given to customize the options when starting the serving under a process.
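The process options can live in the serving itself, so the child specification only needs :serving and :name. This sketch relies on the start_link/1 rule that :batch_size and :batch_timeout are read from the Nx.Serving struct first; sending a batch of 4 would exceed the default :batch_size of 1 if the options were not picked up. It assumes Nx ~> 0.4.1 via Mix.install/1; Doubler and ConfiguredServing are hypothetical names.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

# The second argument holds the process options stored in the serving struct.
serving = Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end, batch_size: 4, batch_timeout: 20)

children = [{Nx.Serving, serving: serving, name: ConfiguredServing}]
{:ok, _supervisor} = Supervisor.start_link(children, strategy: :one_for_one)

batch =
  Nx.Batch.stack([
    Nx.tensor([1, 2]),
    Nx.tensor([3, 4]),
    Nx.tensor([5, 6]),
    Nx.tensor([7, 8])
  ])

result = Nx.Serving.batched_run(ConfiguredServing, batch)
IO.inspect(result)
```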
new/3
Creates a new module-based serving.
It expects a module and an argument that is given to its init callback.
A third argument called process_options, which is optional, can be given to customize the options when starting the serving under a process.
run/2
Runs serving with the given input inline with the current process.
start_link/1
Starts a Nx.Serving process to batch requests to a given serving.
Options
- :name - an atom with the name of the process
- :serving - a Nx.Serving struct with the serving configuration
- :batch_size - the maximum size to batch for. A value is first read from the Nx.Serving struct and then it falls back to this option (which defaults to 1)
- :batch_timeout - the maximum time to wait, in milliseconds, before executing the batch. A value is first read from the Nx.Serving struct and then it falls back to this option (which defaults to 100 ms)
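Outside of a supervision tree, start_link/1 can also be called directly with these options, which links the serving process to the caller. A minimal sketch, assuming Nx ~> 0.4.1 via Mix.install/1; Doubler and DirectServing are hypothetical names.

```elixir
Mix.install([{:nx, "~> 0.4.1"}])

defmodule Doubler do
  import Nx.Defn

  # Hypothetical numerical function: doubles every element.
  defn double(x) do
    x * 2
  end
end

# Start the serving process directly, passing all four options.
{:ok, _pid} =
  Nx.Serving.start_link(
    name: DirectServing,
    serving: Nx.Serving.new(fn -> Nx.Defn.jit(&Doubler.double/1) end),
    batch_size: 4,
    batch_timeout: 10
  )

batch = Nx.Batch.stack([Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0])])
result = Nx.Serving.batched_run(DirectServing, batch)
IO.inspect(result)
```

In applications, the supervised form shown in the module documentation is usually preferable, since the supervisor restarts the serving if it crashes.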