Nx.Defn.Graph (Nx v0.10.0)
A module for splitting Nx.Defn.Expr into stages.
The stages produced by a split are executed one after another, as a chain.
split/2 and the Stage.t() type describe how the graph is split and what each
resulting stage looks like.
run/2 executes a stage chain sequentially against the provided arguments.
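The split-then-run workflow described above can be sketched as follows. This is a minimal sketch, not the canonical usage: it assumes a standalone script (hence Mix.install/1), that run/2 accepts the original arguments as a list in parameter order, and that the default evaluator is used. The names fun, chain, split_result and direct_result are illustrative only.

```elixir
# Assumption: standalone script. Inside a Mix project that already depends
# on Nx, drop the Mix.install/1 call.
Mix.install([{:nx, "~> 0.10"}])

fun = fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end

x = Nx.tensor(1)
y = Nx.tensor(2)

# Build the expression graph for the function.
expr = Nx.Defn.debug_expr(fun).(x, y)

# Ask for a split at every :cos node, so :cos starts a new stage.
chain =
  Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} ->
    op == :cos
  end)

# Assumption: run/2 takes the arguments as a list, in parameter order.
split_result = Nx.Defn.Graph.run(chain, [x, y])

# Reference value computed without splitting, for comparison.
direct_result = Nx.Defn.jit(fun).(x, y)

IO.inspect(Nx.all_close(split_result, direct_result), label: "results match")
```

Comparing against the unsplit function is a cheap sanity check that a split rule has not changed what is computed.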
Summary
Functions
run/2: Executes the stage chain with the given arguments.
split/2: Splits the received Nx.Defn.Expr into stages given the rules.
Functions
run/2
Executes the stage chain with the given arguments.
split/2
Splits the received Nx.Defn.Expr into stages according to expr_split_fn.
expr_split_fn is a function that receives an Nx.Tensor whose data is an Nx.Defn.Expr
and returns true when a split must happen at that node, and false otherwise.
Examples
iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
iex> {out0} = stage0.expr
iex> out0
#Nx.Tensor<
  f32
  Nx.Defn.Expr
  parameter a:0   s32
  b = negate a    s32
  c = sin b       f32
>
iex> stage1.expr
#Nx.Tensor<
  f32
  Nx.Defn.Expr
  parameter a:1   f32
  parameter c:0   s32
  b = cos a       f32
  d = add b, c    f32
>
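Because expr_split_fn is an ordinary function over the expression node, a rule that splits on several operations can be built from a list of op names. The sketch below assumes a standalone script. SplitRules and on_ops/1 are hypothetical helpers introduced here for illustration and are not part of Nx. The number of stages produced is printed rather than assumed.

```elixir
# Assumption: standalone script. Inside a Mix project that already depends
# on Nx, drop the Mix.install/1 call.
Mix.install([{:nx, "~> 0.10"}])

defmodule SplitRules do
  # Hypothetical helper (not part of Nx): returns a predicate suitable as
  # expr_split_fn that is true for any node whose op is in `ops`.
  def on_ops(ops) when is_list(ops) do
    fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op in ops end
  end
end

expr =
  Nx.Defn.debug_expr(fn x, y ->
    x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y)
  end).(1, 2)

# Request a split at both :sin and :cos.
chain = Nx.Defn.Graph.split(expr, SplitRules.on_ops([:sin, :cos]))

IO.puts("number of stages: #{length(chain)}")
```

Keeping the rule in a small helper makes it easy to reuse the same split points across several expressions.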