# `Nx.Defn.Compiler`
[🔗](https://github.com/elixir-nx/nx/blob/v0.11.0/nx/lib/nx/defn/compiler.ex#L1)

The specification and helper functions for custom `defn` compilers.

# `__compile__`

```elixir
@callback __compile__(
  key :: term(),
  vars :: vars,
  fun :: (vars -> Nx.Container.t()),
  opts :: keyword()
) :: ([[Nx.Tensor.t()]] -> [Nx.Container.t()])
when vars: [Nx.Container.t()]
```

Callback for compilation.

It receives an opaque `key` used for caching, the function's
abstract arguments in `vars`, the function `fun` which builds a
defn expression, and the compiler options. It must call `fun`
with `vars` as arguments.

It returns a function that receives a list of argument lists
and returns a list of results, one per argument list.

The callback uses double underscores so it can be defined
in root modules without affecting the module's main API.
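As a minimal sketch, assuming Nx v0.11 is fetched with `Mix.install`, the hypothetical `LoggingCompiler` below satisfies the behaviour by delegating every callback to the built-in `Nx.Defn.Evaluator`, printing the number of `vars` before compiling. Calling `Nx.Defn.compile/3` with the `:compiler` option is expected to reach `__compile__`, and the returned function is then invoked with concrete tensors that match the templates.

```elixir
Mix.install([{:nx, "~> 0.11"}])

defmodule LoggingCompiler do
  # Hypothetical compiler: all real work is delegated to Nx.Defn.Evaluator.
  @behaviour Nx.Defn.Compiler

  @impl true
  def __compile__(key, vars, fun, opts) do
    IO.inspect(length(vars), label: "vars")
    # The evaluator calls fun.(vars) to build the expression and returns
    # the function that runs it.
    Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)
  end

  @impl true
  def __jit__(key, vars, fun, args_list, opts),
    do: Nx.Defn.Evaluator.__jit__(key, vars, fun, args_list, opts)

  @impl true
  def __partitions_options__(opts), do: Nx.Defn.Evaluator.__partitions_options__(opts)

  @impl true
  def __to_backend__(opts), do: Nx.Defn.Evaluator.__to_backend__(opts)
end

# Templates describe the shape and type of each argument ahead of time
templates = [Nx.template({3}, :f32), Nx.template({3}, :f32)]
compiled = Nx.Defn.compile(&Nx.add/2, templates, compiler: LoggingCompiler)
result = compiled.(Nx.tensor([1.0, 2.0, 3.0]), Nx.tensor([10.0, 20.0, 30.0]))
```

A real compiler would replace the delegation with its own lowering of the expression returned by `fun.(vars)`.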

# `__jit__`

```elixir
@callback __jit__(
  key :: term(),
  vars,
  fun :: (vars -> Nx.Container.t()),
  args_list :: [[(-> Nx.Tensor.t())]],
  opts :: keyword()
) :: [Nx.Container.t()]
when vars: [Nx.Container.t()]
```

Callback for just-in-time compilation and invocation.

It receives an opaque `key` used for caching, the function's
abstract arguments in `vars`, the function `fun` which builds a
defn expression, a list of argument lists in `args_list`, and
the compiler options.

It must call `fun` with `vars` as arguments. Note that `key`
does not take `vars` into account. Therefore, if you want to
cache the result of `fun.(vars)`, you likely want to include
`vars` in the cache key. `vars` is a list of container expressions.

Once the expression is built and compiled, it must be invoked
for each list of arguments in `args_list`. In a nutshell, `vars`
are used to build the expression from `fun` which is then
invoked for each list of arguments in `args_list`. All lists
in `args_list` are guaranteed to be flat lists of the same length,
containing zero-arity functions that return tensors of the same type,
shape, and name.

The callback uses double underscores so it can be defined
in root modules without affecting the module's main API.
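The sketch below, assuming Nx v0.11 via `Mix.install`, shows the `__jit__` contract from the compiler's side. The hypothetical `InspectingCompiler` reports how many arguments each entry of `args_list` carries (it only counts the zero-arity functions, it does not call them) and then delegates to `Nx.Defn.Evaluator`, which builds the expression from `fun.(vars)` and runs it for each argument list. `Nx.Defn.jit/2` with the `:compiler` option routes the call through `__jit__`.

```elixir
Mix.install([{:nx, "~> 0.11"}])

defmodule InspectingCompiler do
  # Hypothetical compiler used only to observe what __jit__ receives.
  @behaviour Nx.Defn.Compiler

  @impl true
  def __jit__(key, vars, fun, args_list, opts) do
    # One entry per invocation; each entry is a flat list of zero-arity
    # functions returning tensors. Here they are only counted.
    IO.inspect(Enum.map(args_list, &length/1), label: "arguments per invocation")
    Nx.Defn.Evaluator.__jit__(key, vars, fun, args_list, opts)
  end

  @impl true
  def __compile__(key, vars, fun, opts),
    do: Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)

  @impl true
  def __partitions_options__(opts), do: Nx.Defn.Evaluator.__partitions_options__(opts)

  @impl true
  def __to_backend__(opts), do: Nx.Defn.Evaluator.__to_backend__(opts)
end

jitted = Nx.Defn.jit(&Nx.multiply(&1, 2), compiler: InspectingCompiler)
result = jitted.(Nx.tensor([1, 2, 3]))
```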

# `__partitions_options__`

```elixir
@callback __partitions_options__(keyword()) :: [keyword()]
```

Receives a keyword list of compiler options and
returns a list of compiler options, each to run
on a separate partition/device.
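A minimal sketch of this callback alone, not a complete compiler: the hypothetical `PartitionedCompiler` reads an illustrative `:num_partitions` option and returns one copy of the remaining options per partition, each tagged with an illustrative `:device_id`. Neither option name is defined by Nx; they only show the shape of the return value.

```elixir
defmodule PartitionedCompiler do
  # Hypothetical: :num_partitions and :device_id are illustrative option
  # names chosen for this sketch, not options defined by Nx.
  def __partitions_options__(opts) do
    {count, opts} = Keyword.pop(opts, :num_partitions, 1)

    # The //1 step keeps the range empty when count is 0
    for device_id <- 0..(count - 1)//1 do
      Keyword.put(opts, :device_id, device_id)
    end
  end
end

PartitionedCompiler.__partitions_options__(num_partitions: 2, client: :host)
# => [[device_id: 0, client: :host], [device_id: 1, client: :host]]
```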

# `__shard_jit__`
*optional* 

```elixir
@callback __shard_jit__(
  key :: term(),
  mesh :: Nx.Mesh.t(),
  [vars],
  fun :: (vars -> Nx.Container.t()),
  args_list :: [[(-> Nx.Tensor.t())]],
  opts :: keyword()
) :: [Nx.Container.t()]
when vars: [Nx.Container.t()]
```

Callback for compilation of a parallelizable computation.

Its main purpose is to compile a function for a given `Nx.Mesh`.

Receives an opaque `key` used for caching, a `mesh`, a list of `vars`
in `[vars]`, the function `fun` which builds a defn expression, a list of
argument lists in `args_list`, and the compiler options.

Using `[vars]` instead of a single `vars` allows the compiler to keep one
set of abstract parameters per shard or logical device in the mesh. This is useful
when the tensors are already divided into shards.

# `__to_backend__`

```elixir
@callback __to_backend__(keyword()) :: {module(), keyword()}
```

Receives a keyword list of compiler options and returns a backend,
as a `{module, options}` tuple, that corresponds to the same allocation.

The backend is expected to match what would be returned from a
computation defined by the compiler.
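As a sketch, assuming Nx v0.11 via `Mix.install`, consider a hypothetical compiler whose results always live in host memory. Its `__to_backend__` can then return `Nx.BinaryBackend` regardless of the options, and callers can use the returned tuple to allocate tensors where the compiler's outputs would live. Only this one callback is shown.

```elixir
Mix.install([{:nx, "~> 0.11"}])

defmodule HostOnlyCompiler do
  # Hypothetical: this compiler is assumed to produce results in host
  # memory, so the matching backend ignores the compiler options.
  def __to_backend__(_opts), do: {Nx.BinaryBackend, []}
end

{backend, backend_opts} = HostOnlyCompiler.__to_backend__([])

# Allocate a tensor on the backend the compiler's results would use
t = Nx.tensor([1, 2, 3], backend: {backend, backend_opts})
```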

# `current`

Returns the current compiler.

Returns `nil` if we are not inside `defn`.
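A small sketch, assuming Nx v0.11 via `Mix.install`: the hypothetical helper `CompilerInfo.describe/0` branches on `Nx.Defn.Compiler.current/0`. Called from ordinary Elixir code, outside any `defn`, it takes the `nil` branch.

```elixir
Mix.install([{:nx, "~> 0.11"}])

defmodule CompilerInfo do
  # Hypothetical helper built on Nx.Defn.Compiler.current/0
  def describe do
    case Nx.Defn.Compiler.current() do
      nil -> "not inside defn"
      compiler -> "inside defn, compiling with #{inspect(compiler)}"
    end
  end
end

# Prints "not inside defn" when called from regular code
IO.puts(CompilerInfo.describe())
```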

# `defn?`

Returns whether we are inside `defn` at _compilation time_.

It is meant to be invoked inside a macro that has `defn`-specific logic.
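A sketch of such a macro, assuming Nx v0.11 via `Mix.install` and that `defn?/0` behaves as documented above. The hypothetical `MyMacros.double/1` checks `Nx.Defn.Compiler.defn?/0` at expansion time and emits an explicit `Nx.multiply/2` inside `defn`, or plain arithmetic elsewhere. Both branches compute the same value, so the example works either way.

```elixir
Mix.install([{:nx, "~> 0.11"}])

defmodule MyMacros do
  # Hypothetical macro with defn-specific logic, decided at expansion time
  defmacro double(x) do
    if Nx.Defn.Compiler.defn?() do
      # Expanding inside a defn body: emit an explicit Nx operation
      quote do: Nx.multiply(unquote(x), 2)
    else
      # Expanding in regular Elixir code: emit plain arithmetic
      quote do: unquote(x) * 2
    end
  end
end

defmodule MyDefn do
  import Nx.Defn
  require MyMacros

  defn double_tensor(t), do: MyMacros.double(t)
end

require MyMacros
MyMacros.double(3)
MyDefn.double_tensor(Nx.tensor([1, 2, 3]))
```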

