Nx.Defn.Graph (Nx v0.10.0)

A module for splitting Nx.Defn.Expr into stages.

This module splits an Nx.Defn.Expr into stages, which are then executed in a chain: outputs of earlier stages become parameters of the stages that consume them.

split/2 and t:Stage.t() describe how the graph is split and what result to expect.

run/2 executes the resulting chain of stages against the provided arguments, one stage after another.

Summary

Functions

run/2

Executes the stage chain with the given arguments.

split/2

Splits the received Nx.Defn.Expr into stages given the rules.

Functions

run/2

Executes the stage chain (such as the list of stages returned by split/2) with the given arguments, running the stages sequentially.
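To make the sequential execution concrete, here is a toy model in plain Elixir with no Nx dependency. It is only an illustration under stated assumptions: ToyChain is a hypothetical module, each stage is modeled as an ordinary function of the previous stage's output and the original arguments, and the real run/2 works on stages and Nx tensors rather than closures.

```elixir
defmodule ToyChain do
  @moduledoc """
  Hypothetical toy model of running a chain of stages in order.
  This is not Nx.Defn.Graph.run/2; stages here are plain closures.
  """

  # Runs the stages in order. Each stage receives the previous stage's
  # output (nil for the first stage) plus the original arguments, and
  # its return value is passed on to the next stage.
  def run(stages, args) do
    Enum.reduce(stages, nil, fn stage, prev -> stage.(prev, args) end)
  end
end

# Mirrors the split/2 example below:
# stage0 computes sin(negate(x)); stage1 computes cos(stage0 output) + y.
stage0 = fn nil, [x, _y] -> :math.sin(-x) end
stage1 = fn prev, [_x, y] -> :math.cos(prev) + y end

result = ToyChain.run([stage0, stage1], [1, 2])
IO.inspect(result)
```

As in the split/2 example, the second stage needs both the first stage's output and an original argument (y), which is why the toy passes the original arguments to every stage.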


split(expr, expr_split_fn)


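Splits the received Nx.Defn.Expr into stages according to the rules given by expr_split_fn.

expr_split_fn is a function that receives an Nx.Tensor containing an Nx.Defn.Expr node and returns true when the graph must be split at that node, and false otherwise. The node for which it returns true starts a new stage: in the example below, splitting on :cos leaves negate and sin in the first stage and puts cos and add in the second.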
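The splitting rule can be illustrated with a second toy model, again in plain Elixir and without Nx. As a simplifying assumption, the expression is modeled as a linear list of operation atoms, whereas a real Nx.Defn.Expr is a graph whose nodes may have several inputs; ToyGraph is a hypothetical module. A new stage starts at every operation for which the predicate returns true, which reproduces the grouping of the negate/sin/cos/add example below.

```elixir
defmodule ToyGraph do
  @moduledoc """
  Hypothetical toy model of stage splitting over a linear pipeline.
  Not Nx.Defn.Graph.split/2: real expressions are graphs, not lists.
  """

  # Groups ops into stages, starting a new stage before every op for
  # which split_fn returns true. An op that is already first in its
  # stage does not create an empty stage.
  def split(ops, split_fn) do
    Enum.chunk_while(
      ops,
      [],
      fn op, acc ->
        if split_fn.(op) and acc != [] do
          # Emit the finished stage and start a new one with this op.
          {:cont, Enum.reverse(acc), [op]}
        else
          {:cont, [op | acc]}
        end
      end,
      fn
        [] -> {:cont, []}
        acc -> {:cont, Enum.reverse(acc), []}
      end
    )
  end
end

IO.inspect(ToyGraph.split([:negate, :sin, :cos, :add], &(&1 == :cos)))
# => [[:negate, :sin], [:cos, :add]]
```

In the real function the predicate receives an Nx.Tensor wrapping an Nx.Defn.Expr node rather than an atom, which is why the example below pattern matches on %Nx.Tensor{data: %Nx.Defn.Expr{op: op}}.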

Examples

iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
iex> {out0} = stage0.expr
iex> out0
#Nx.Tensor<
  f32

  Nx.Defn.Expr
  parameter a:0   s32
  b = negate a    s32
  c = sin b       f32
>
iex> stage1.expr
#Nx.Tensor<
  f32

  Nx.Defn.Expr
  parameter a:1   f32
  parameter c:0   s32
  b = cos a       f32
  d = add b, c    f32
>