Nx.Defn.Graph (Nx v0.10.0)
A module for splitting Nx.Defn.Expr into stages.
This module is used to split an Nx.Defn.Expr into stages, which are then executed in a chain.
split/2 and t:Stage.t() describe how the graph is split and what each resulting stage contains.
run/2 executes the resulting chain of stages against the provided arguments, one stage after another.
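To make the split-then-run idea concrete, here is a toy model in plain Elixir with no Nx dependency. It is a hypothetical sketch, not Nx's implementation: the module ToyGraph and its functions are illustrative names, an "expression" is reduced to a linear list of unary ops, and the second operand of add is baked in as a constant rather than being a parameter. It mirrors the behaviour shown in the split/2 example: the op that triggers a split becomes the first op of the next stage, and stages run in order, each consuming the previous stage's output.

```elixir
defmodule ToyGraph do
  # Hypothetical toy model: an "expression" is a list of unary ops applied
  # left to right. This is NOT how Nx.Defn.Expr is represented.

  # Splits the op list into stages. A new stage starts at every op for which
  # split_fn returns true (unless that op is already first in its stage).
  def split(ops, split_fn) do
    Enum.chunk_while(
      ops,
      [],
      fn op, acc ->
        if split_fn.(op) and acc != [] do
          # Close the current stage; the triggering op opens the next one.
          {:cont, Enum.reverse(acc), [op]}
        else
          {:cont, [op | acc]}
        end
      end,
      fn
        [] -> {:cont, []}
        acc -> {:cont, Enum.reverse(acc), []}
      end
    )
  end

  # Runs the stages sequentially, feeding each stage's output into the next.
  def run(stages, input) do
    Enum.reduce(stages, input, fn stage, acc ->
      Enum.reduce(stage, acc, fn op, value -> apply_op(op, value) end)
    end)
  end

  defp apply_op(:negate, x), do: -x
  defp apply_op(:sin, x), do: :math.sin(x)
  defp apply_op(:cos, x), do: :math.cos(x)
  # Simplification: the add operand is a constant, not a graph parameter.
  defp apply_op({:add, y}, x), do: x + y
end
```

For example, ToyGraph.split([:negate, :sin, :cos, {:add, 2}], &(&1 == :cos)) returns [[:negate, :sin], [:cos, {:add, 2}]], and running those two stages yields the same value as running the unsplit list as a single stage.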
Functions
run/2
Executes the stage chain with the given arguments.
split/2
Splits the received Nx.Defn.Expr into stages according to expr_split_fn.
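expr_split_fn is a function that receives an Nx.Tensor containing an Nx.Defn.Expr and returns true when a split must happen, and false otherwise. As the example below shows, the node for which it returns true becomes the first operation of the next stage.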
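When more than one kind of operation should trigger a split, the predicate can be built from a set of ops. The sketch below is a hypothetical helper (SplitRules.split_before/1 is not part of Nx). It matches on plain maps, which also matches %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} because Elixir structs are maps, so it runs and can be tested without Nx installed.

```elixir
defmodule SplitRules do
  # Hypothetical helper: returns an expr_split_fn-style predicate that is
  # true for any node whose op is in split_ops, and false otherwise.
  def split_before(split_ops) do
    split_ops = MapSet.new(split_ops)

    fn
      # Matches %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} as well as plain maps.
      %{data: %{op: op}} -> MapSet.member?(split_ops, op)
      # Anything without that shape never triggers a split.
      _other -> false
    end
  end
end
```

Under these assumptions it would be passed as the second argument, as in Nx.Defn.Graph.split(expr, SplitRules.split_before([:cos, :dot])).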
Examples
iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
iex> {out0} = stage0.expr
iex> out0
#Nx.Tensor<
  f32

  Nx.Defn.Expr
  parameter a:0       s32
  b = negate a        s32
  c = sin b           f32
>
iex> stage1.expr
#Nx.Tensor<
  f32

  Nx.Defn.Expr
  parameter a:1       f32
  parameter c:0       s32
  b = cos a           f32
  d = add b, c        f32
>