# `EMLX`
[🔗](https://github.com/elixir-nx/emlx/blob/v0.3.0/emlx/lib/emlx.ex#L136)

# `abs`

# `acos`

# `acosh`

# `add`

# `all`

# `allclose`

# `any`

# `arange`

# `argmax`

# `argmax`

# `argmin`

# `argmin`

# `argsort`

# `as_strided`

# `asin`

# `asinh`

# `astype`

# `atan2`

# `atan`

# `atanh`

# `bitwise_and`

# `bitwise_not`

# `bitwise_or`

# `bitwise_xor`

# `broadcast_to`

# `ceil`

# `clear_cache`

```elixir
@spec clear_cache() :: :ok
```

Clears the MLX memory cache, releasing unused GPU memory back to the system.

Useful after inference batches to prevent memory growth. Does not affect
tensors that are still referenced.

## Examples

    EMLX.clear_cache()
    #=> :ok

# `clip`

# `concatenate`

# `conjugate`

# `conv_general`

# `cos`

# `cosh`

# `cumulative_max`

# `cumulative_min`

# `cumulative_product`

# `cumulative_sum`

# `deallocate`

# `default_device`

Returns the default MLX device for this process.

Reads `:default_device` from the `:emlx` application environment, falling
back to `:gpu`. Override in tests or config via:

    Application.put_env(:emlx, :default_device, :cpu)
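A minimal sketch of the documented lookup, using only `Application.get_env/3`. The module name `DeviceLookup` is hypothetical, and EMLX's own implementation may differ in detail.

```elixir
defmodule DeviceLookup do
  # Hypothetical helper mirroring the documented behaviour: read
  # :default_device from the :emlx application env, falling back to :gpu.
  def default_device, do: Application.get_env(:emlx, :default_device, :gpu)
end

# Nothing configured: the fallback applies.
Application.delete_env(:emlx, :default_device)
DeviceLookup.default_device()
# => :gpu

# Pin execution to the CPU, e.g. in a test setup.
Application.put_env(:emlx, :default_device, :cpu)
DeviceLookup.default_device()
# => :cpu
```

In a config file, the equivalent setting would be `config :emlx, default_device: :cpu` (for example in `config/test.exs`).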

# `dequantize`

```elixir
@spec dequantize(Nx.Tensor.t()) :: Nx.Tensor.t()
```

Dequantize a quantized `Nx.Tensor` (created by `EMLX.quantize/2`) to a
dense float tensor by calling `mx::dequantize`.

# `dequantize`

Dequantizes packed weights to floating point.

Converts quantized weights back to floating point. The result approximates the
original weights up to quantization error. Useful for debugging and verification.

## Parameters
  - `w` - Quantized weights as uint32 (32 / `bits` packed values per element; 8 for 4-bit)
  - `scales` - Per-group scale factors
  - `biases` - Per-group biases (additive offsets applied after scaling)
  - `group_size` - Number of weights per group (default: 64)
  - `bits` - Quantization bits (default: 4)
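The sketch below illustrates the arithmetic in plain Elixir. It unpacks one uint32 word and applies the affine formula `w ≈ scale * q + bias` to a group. This is not how EMLX runs it, since the real work happens in `mx::dequantize`. The low-order-first packing order and the module name `DequantSketch` are assumptions.

```elixir
defmodule DequantSketch do
  import Bitwise

  # Assumption: values are packed low-order first, so the j-th value in a
  # word occupies bits [j * bits, (j + 1) * bits).
  def unpack(word, bits) do
    mask = (1 <<< bits) - 1
    for j <- 0..(div(32, bits) - 1), do: (word >>> (j * bits)) &&& mask
  end

  # Affine dequantization of one group: w ≈ scale * q + bias.
  def dequantize_group(words, scale, bias, bits) do
    words
    |> Enum.flat_map(&unpack(&1, bits))
    |> Enum.map(fn q -> scale * q + bias end)
  end
end

DequantSketch.unpack(0x87654321, 4)
# => [1, 2, 3, 4, 5, 6, 7, 8]
```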

# `divide`

# `einsum`

# `equal`

# `erf`

# `erf_inv`

# `eval`

Evaluates a (possibly lazy) MLX tensor by routing the work through an
`EMLX.CommandQueue`. Blocks the caller until the worker thread has
finished `mlx::core::eval/1` for this tensor.

Resolves the queue via `resolve_worker/1`:

  1. If the calling process has bound a queue with
     `EMLX.CommandQueue.with_queue/2`, that queue is used.
  2. Otherwise the application-default worker for the tensor's device
     (CPU or GPU) is used — see `EMLX.Application`.
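A hypothetical sketch of the two-step resolution order. Purely for illustration, it assumes that a bound queue lives in the process dictionary under an invented key. `ResolveSketch`, its functions (including a `resolve_worker/2` that takes the defaults map explicitly), and the worker names are not part of EMLX. The real `EMLX.CommandQueue.with_queue/2` may store the binding differently.

```elixir
defmodule ResolveSketch do
  # Hypothetical storage for a process-bound queue; EMLX may differ.
  @key :hypothetical_bound_queue

  # Bind `queue` for the duration of `fun`, restoring any previous binding.
  def with_queue(queue, fun) do
    previous = Process.get(@key)
    Process.put(@key, queue)

    try do
      fun.()
    after
      if previous, do: Process.put(@key, previous), else: Process.delete(@key)
    end
  end

  # Step 1: a queue bound by the calling process wins.
  # Step 2: otherwise fall back to the default worker for the device.
  def resolve_worker(device, defaults) do
    Process.get(@key) || Map.fetch!(defaults, device)
  end
end

defaults = %{cpu: :cpu_worker, gpu: :gpu_worker}
ResolveSketch.resolve_worker(:gpu, defaults)
# => :gpu_worker
ResolveSketch.with_queue(:my_queue, fn -> ResolveSketch.resolve_worker(:gpu, defaults) end)
# => :my_queue
```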

# `exp`

# `expm1`

# `eye`

# `fast_layer_norm`

Fused layer normalisation (`mlx::fast::layer_norm`).

- `x`      — input tensor; normalised over the last axis.
- `weight` — `{hidden}` scale vector (gamma).
- `bias`   — `{hidden}` bias vector (beta).
- `eps`    — numerical stability constant (e.g. `1.0e-5`).

Prefer `EMLX.Fast.layer_norm/4` inside `defn`.
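To check outputs on small inputs, here is a single-row reference for what the fused kernel computes, written with plain Elixir lists. It uses the population (biased) variance, the usual layer-norm convention. The module name is hypothetical.

```elixir
defmodule NormSketch do
  # Reference layer norm over one row (the last axis):
  # y_i = (x_i - mean) / sqrt(var + eps) * weight_i + bias_i
  def layer_norm(x, weight, bias, eps) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.sum(Enum.map(x, fn v -> (v - mean) * (v - mean) end)) / n
    inv_std = 1.0 / :math.sqrt(var + eps)

    [x, weight, bias]
    |> Enum.zip()
    |> Enum.map(fn {v, w, b} -> (v - mean) * inv_std * w + b end)
  end
end
```

`fast_layer_norm_no_bias` corresponds to passing an all-zero `bias` here.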

# `fast_layer_norm_no_bias`

Fused layer normalisation without bias (`mlx::fast::layer_norm`, weight-only variant).

Prefer `EMLX.Fast.layer_norm/3` inside `defn`.

# `fast_rms_norm`

Fused RMS normalisation (`mlx::fast::rms_norm`).

Runs as a single Metal shader. Normalises over the last axis of `x` and scales by
`weight`. Output shape and type match `x`.

Prefer `EMLX.Fast.rms_norm/3` inside `defn`; call this directly only from
eager (non-defn) code.
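A matching single-row reference for RMS norm. It assumes the `eps` stabiliser that `mlx::fast::rms_norm` takes, which the description above does not list. It uses plain lists and a hypothetical module name.

```elixir
defmodule RmsNormSketch do
  # Reference RMS norm over one row:
  # y_i = x_i / sqrt(mean(x^2) + eps) * weight_i
  def rms_norm(x, weight, eps) do
    mean_sq = Enum.sum(Enum.map(x, &(&1 * &1))) / length(x)
    inv_rms = 1.0 / :math.sqrt(mean_sq + eps)
    Enum.zip_with(x, weight, fn v, w -> v * inv_rms * w end)
  end
end
```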

# `fast_rope`

Fused rotary position embedding (`mlx::fast::rope`).

Runs as a single Metal shader and applies RoPE with a scalar position `offset`.

- `a`           — input `{B, ..., T, D}`
- `dims`        — number of feature dims to rotate (≤ last-axis size, must be even)
- `traditional` — `false` for split-half (Qwen3); `true` for interleaved
- `base`        — angular frequency base (e.g. 10_000 or 1_000_000)
- `scale`       — position scale (1.0 unless using NTK-aware scaling)
- `offset`      — integer token position (length of KV cache already filled)

Prefer `EMLX.Fast.rope/6` inside `defn`.
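To make `dims`, `traditional`, `base`, and `scale` concrete, here is a single-vector RoPE reference in plain Elixir. It assumes the common formulation in which pair `i` is rotated by `position * scale * base^(-2i / dims)` and features past `dims` pass through unchanged. MLX's kernel may differ in numerical details. The module name is hypothetical.

```elixir
defmodule RopeSketch do
  # Rotates one feature vector `x` (a list) at token `position`.
  # Features beyond `dims` pass through unchanged.
  def rope(x, position, dims, traditional, base, scale) do
    half = div(dims, 2)
    {rot, rest} = Enum.split(x, dims)

    # Angle for pair i: position * scale * base^(-2i / dims)
    angles =
      for i <- Range.new(0, half - 1, 1),
          do: position * scale * :math.pow(base, -2 * i / dims)

    # Interleaved pairs (x0, x1), (x2, x3), ...; split-half pairs (x_i, x_{i+half})
    pairs =
      if traditional do
        rot |> Enum.chunk_every(2) |> Enum.map(fn [a, b] -> {a, b} end)
      else
        {first, second} = Enum.split(rot, half)
        Enum.zip(first, second)
      end

    # Standard 2-D rotation of each pair by its angle
    rotated =
      Enum.zip_with(pairs, angles, fn {a, b}, t ->
        {a * :math.cos(t) - b * :math.sin(t), a * :math.sin(t) + b * :math.cos(t)}
      end)

    # Put rotated pairs back in the layout they were read from
    out =
      if traditional do
        Enum.flat_map(rotated, fn {a, b} -> [a, b] end)
      else
        Enum.map(rotated, &elem(&1, 0)) ++ Enum.map(rotated, &elem(&1, 1))
      end

    out ++ rest
  end
end
```

At position 0 every angle is zero, so the input comes back unchanged. That makes a convenient smoke test.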

# `fast_rope_ids`

Fused RoPE with a per-batch offset array (`mlx::fast::rope`, array-offset overload).

Calls the `const array& offset` overload of `mlx::fast::rope`, where `offset`
has shape `{B}` — one starting position per batch example. Positions within each
example are assumed to be sequential: `[offset[b], offset[b]+1, ..., offset[b]+T-1]`.

Typically you build `offset` by slicing `position_ids[:, 0]` (first token's
position for each batch example) before calling this function.

Prefer `EMLX.Fast.rope_with_positions/6` inside `defn`.
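A small plain-list sketch (hypothetical helper names) spells out the sequential-position assumption. Column 0 of `position_ids` gives the `{B}` offsets, and each example is then assumed to cover consecutive positions starting from its offset.

```elixir
defmodule RopeIdsSketch do
  # Column 0 of a {B, T} position_ids list-of-lists gives the {B} offsets.
  def offsets_from_position_ids(position_ids), do: Enum.map(position_ids, &hd/1)

  # Positions implied by the offsets: example b covers
  # offset[b], offset[b] + 1, ..., offset[b] + t - 1.
  def implied_positions(offsets, t) do
    Enum.map(offsets, fn off -> Enum.to_list(Range.new(off, off + t - 1, 1)) end)
  end
end
```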

# `fast_rope_positions`

Fused RoPE for arbitrary per-token `position_ids`.

Uses full `{B, T}` position IDs (not just an offset) and mirrors Bumblebee's
per-token cos/sin lookup, avoiding the sequential-offset assumption of
`fast_rope_ids`.

- `a`            — input tensor `{B, T, H, D}`.
- `dims`         — number of feature dims to rotate.
- `traditional`  — currently only `false` is supported.
- `base`         — angular frequency base (e.g. `1_000_000`).
- `scale`        — position scale.
- `position_ids` — `{B, T}` integer tensor.
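One way to choose between `fast_rope_ids` and `fast_rope_positions` is to check whether each row of `position_ids` is consecutive. If a row is not, for example a hypothetical left-padded row such as `[0, 0, 0, 1, 2]`, no per-example offset can reproduce it. The sketch uses hypothetical names and assumes every row is non-empty.

```elixir
defmodule PositionCheck do
  # True when every row is first, first + 1, ..., first + T - 1, i.e. when
  # a {B} offset taken from column 0 describes the same positions.
  def sequential?(position_ids) do
    Enum.all?(position_ids, fn [first | _] = row ->
      row == Enum.to_list(Range.new(first, first + length(row) - 1, 1))
    end)
  end
end
```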

# `fast_rope_with_freqs`

Fused RoPE with precomputed inv-frequency vector (`mlx::fast::rope`, freqs overload).

- `a`          — input tensor; shape `{B, T, ..., D}`.
- `dims`       — number of feature dims to rotate.
- `traditional` — `false` for split-half (Qwen3/Bumblebee); `true` for interleaved.
- `scale`      — position scale (typically `1.0` when using precomputed freqs).
- `offset`     — `{B}` per-batch starting position tensor.
- `freqs`      — `{dims/2}` precomputed inverse-frequency tensor.

Prefer `EMLX.Fast.rope_with_freqs/6` inside `defn`.

# `fast_sdpa`

Flash-attention-style scaled dot-product attention (SDPA) without a mask (`mlx::fast::scaled_dot_product_attention`).

GQA-native: `k`/`v` may have fewer heads than `q` — no pre-tiling needed.

- `q`     — `{B, N_q,  T_q,  D}`
- `k`     — `{B, N_kv, T_kv, D}`
- `v`     — `{B, N_kv, T_kv, D}`
- `scale` — scalar (typically `1 / sqrt(D)`)

Prefer `EMLX.Fast.scaled_dot_product_attention/4` inside `defn`.
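A naive, single-query reference for what SDPA computes. It includes the usual GQA head mapping, in which each KV head serves `N_q / N_kv` consecutive query heads. This is an illustrative sketch with hypothetical names, not EMLX code, and it assumes `N_q` is a multiple of `N_kv`.

```elixir
defmodule SdpaSketch do
  # GQA: query head h reads KV head div(h, n_q / n_kv).
  def kv_head(h, n_q, n_kv) when rem(n_q, n_kv) == 0, do: div(h, div(n_q, n_kv))

  # One query vector against T_kv keys/values:
  # weights = softmax(scale * q . k_j); output = sum_j weights_j * v_j
  def attend(q, keys, values, scale) do
    scores = Enum.map(keys, fn k -> scale * dot(q, k) end)
    max = Enum.max(scores)
    exps = Enum.map(scores, &:math.exp(&1 - max))
    total = Enum.sum(exps)
    weights = Enum.map(exps, &(&1 / total))

    values
    |> Enum.zip(weights)
    |> Enum.map(fn {v, w} -> Enum.map(v, &(&1 * w)) end)
    |> Enum.reduce(fn row, acc -> Enum.zip_with(row, acc, &+/2) end)
  end

  defp dot(a, b), do: Enum.zip_with(a, b, &*/2) |> Enum.sum()
end
```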

# `fast_sdpa_causal`

Flash-attention SDPA with built-in causal mask (`mlx::fast::scaled_dot_product_attention`,
`mask_mode="causal"`).

MLX builds the causal mask internally (each query is blocked from attending to
later key positions), so no explicit mask tensor is required. GQA-native: `k`/`v`
may have fewer heads than `q`.

- `q`     — `{B, N_q,  T_q,  D}`
- `k`     — `{B, N_kv, T_kv, D}`
- `v`     — `{B, N_kv, T_kv, D}`
- `scale` — scalar (typically `1 / sqrt(D)`)

Prefer `EMLX.Fast.scaled_dot_product_attention_causal/4` inside `defn`.
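The internal causal mask can be pictured as a boolean matrix. This sketch assumes queries are aligned to the most recent keys (offset `T_kv - T_q`), which is what a decode step with a filled KV cache needs. Treat it as an illustration rather than a specification of MLX's kernel. The module name is hypothetical.

```elixir
defmodule CausalMaskSketch do
  # mask[i][j] is true when query i may attend to key j. Queries are
  # aligned to the end of the key sequence, so a single decode query
  # (t_q = 1) sees every key.
  def causal_mask(t_q, t_kv) do
    offset = t_kv - t_q

    for i <- Range.new(0, t_q - 1, 1) do
      for j <- Range.new(0, t_kv - 1, 1), do: j <= i + offset
    end
  end
end
```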

# `fast_sdpa_causal_key_masked`

Causal SDPA with a runtime `key_mask` check performed at the C++ level.

The NIF first evaluates `all(key_mask == 1)`, which is cheap for small or constant
tensors. If every key is valid, it uses the pure causal Metal kernel and allocates
no mask. Otherwise it builds a combined causal + key-mask additive float mask and
calls the masked kernel.

- `q`        — `{B, N_q,  T_q,  D}`
- `k`        — `{B, N_kv, T_kv, D}`
- `v`        — `{B, N_kv, T_kv, D}`
- `scale`    — scalar (typically `1 / sqrt(D)`)
- `key_mask` — `{B, T_kv}` boolean/int tensor (1 = attend, 0 = padding)

Prefer `EMLX.Fast.scaled_dot_product_attention_causal_key_masked/5` inside `defn`.
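The fast-path check and the slow path's combined mask can be sketched per batch example as follows. Elixir floats cannot represent infinity, so a large negative number stands in for `-∞`. The module name and that constant are assumptions, and the same end-of-sequence query alignment is assumed for the causal part.

```elixir
defmodule CombinedMaskSketch do
  # Stand-in for -inf (Elixir floats have no infinity).
  @neg_inf -1.0e30

  # Fast-path check: is every key valid? Accepts 1/0 or true/false.
  def all_valid?(key_mask_row), do: Enum.all?(key_mask_row, &(&1 in [1, true]))

  # Additive mask for one example: 0.0 where attention is allowed
  # (causal AND key_mask valid), @neg_inf elsewhere.
  def combined_mask(t_q, key_mask_row) do
    offset = length(key_mask_row) - t_q

    for i <- Range.new(0, t_q - 1, 1) do
      key_mask_row
      |> Enum.with_index()
      |> Enum.map(fn {m, j} ->
        if j <= i + offset and m in [1, true], do: 0.0, else: @neg_inf
      end)
    end
  end
end
```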

# `fast_sdpa_masked`

Flash-attention SDPA with an additive or boolean `mask`.

`mask` must be broadcast-compatible with `{B, N_q, T_q, T_kv}`. An additive mask
is added to the scaled scores before the softmax. In a boolean mask, `false`
entries are treated as `-∞` (masked out).

Prefer `EMLX.Fast.scaled_dot_product_attention/5` inside `defn`.
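The sketch below shows how the two mask kinds act on one row of scaled scores. It is hypothetical, and again a large negative float stands in for `-∞`.

```elixir
defmodule MaskApplySketch do
  # Stand-in for -inf (Elixir floats have no infinity).
  @neg_inf -1.0e30

  # Boolean masks keep `true` entries and drop `false` ones;
  # numeric (additive) masks are added to the scores.
  def apply_mask(scores, mask) do
    Enum.zip_with(scores, mask, fn
      s, true -> s
      _s, false -> @neg_inf
      s, m when is_number(m) -> s + m
    end)
  end
end
```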

# `fast_swiglu`

Fused SwiGLU activation: `silu(gate) * up` where `silu(x) = x * sigmoid(x)`.

- `gate` — gate tensor; silu is applied element-wise.
- `up`   — up-projection tensor; same shape as `gate`.

Output has the same shape and dtype as `gate`.

Prefer `EMLX.Fast.swiglu/2` inside `defn`.
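An element-wise reference in plain Elixir, handy for spot-checking the fused kernel on a few values. The module name is hypothetical.

```elixir
defmodule SwigluSketch do
  # silu(x) = x * sigmoid(x) = x / (1 + e^-x)
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # swiglu(gate, up) = silu(gate) * up, element-wise
  def swiglu(gate, up), do: Enum.zip_with(gate, up, fn g, u -> silu(g) * u end)
end
```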

# `fft2`

# `fft`

# `floor`

# `from_blob`

# `full`

# `gather`

# `greater`

# `greater_equal`

# `ifft2`

# `ifft`

# `imag`

# `is_infinity`

# `is_nan`

# `is_tensor`
*macro* 

# `isclose`

# `item`

Returns the scalar value of a 0-d tensor as a number.

Worker-routed: the NIF body calls `mlx::core::eval(*t)` and `t->item<T>()`,
both of which require running on the OS thread that owns the tensor's
stream encoder.

# `kv_cache_attention`

Fused KV cache update + variable-length SDPA in a single Metal command buffer.

Receives tensors in Bumblebee `{B, T, N, D}` convention. Internally transposes
to MLX `{B, N, T, D}` for `mlx::fast::scaled_dot_product_attention`, then
transposes the result back. Returns a 3-tuple of EMLX `{device, ref}` pairs.

- `q`       — `{B, T_q,   N_q,  D}` post-RoPE query
- `new_k`   — `{B, T_new, N_kv, D}` current key projection (post-RoPE)
- `new_v`   — `{B, T_new, N_kv, D}` current value projection
- `k_cache` — `{B, T_max, N_kv, D}` preallocated key buffer
- `v_cache` — `{B, T_max, N_kv, D}` preallocated value buffer
- `offset`  — integer, number of positions already in cache
- `scale`   — float, `1 / sqrt(head_dim)`

Returns `{{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}`.
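The cache-update half of the operation is a slice update along the time axis. Here is a plain-list sketch (hypothetical names) of writing `T_new` entries at `offset` and of the resulting number of valid positions.

```elixir
defmodule KvCacheSketch do
  # Write `new_rows` (T_new entries along the time axis) into `cache`
  # (T_max entries) starting at `offset`, like a slice_update.
  def update(cache, new_rows, offset) do
    t_new = length(new_rows)
    if offset + t_new > length(cache), do: raise(ArgumentError, "cache overflow")
    {before, rest} = Enum.split(cache, offset)
    before ++ new_rows ++ Enum.drop(rest, t_new)
  end

  # Only the first offset + T_new slots hold valid keys/values.
  def valid_len(offset, new_rows), do: offset + length(new_rows)
end
```

On the Nx side, moving a tensor between the `{B, T, N, D}` and `{B, N, T, D}` layouts is `Nx.transpose(t, axes: [0, 2, 1, 3])`. The permutation is its own inverse.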

# `kv_cache_attention_masked`

Like `kv_cache_attention/7` but also applies `key_mask` to exclude padding
positions from attention in both prefill and decode steps.

- `key_mask` — `{B, T_kv}` integer or boolean tensor with `1` = attend,
  `0` = skip (padding). Must cover exactly `valid_len = offset + T_new` positions.

The combined additive mask applies both the causal and the key-mask constraints
without calling `mlx::core::all()`, avoiding Metal sort-kernel compilation issues
for small tensor shapes.

Returns `{{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}`.

# `kv_cache_sdpa_update`

Fused KV cache update + SDPA for the native NIF loop (BNHD layout).

Accepts `q`, `new_k`, `new_v` already transposed to `{B, N, T, D}` and a
pre-allocated cache in `{B, N_kv, T_max, D}` layout.

Internally, the cache arrays are **move-extracted** from their ENIF resources
before `slice_update` so that MLX's donation optimisation fires at eval time:
the existing Metal buffer is reused in-place — no new allocation.

Returns `{{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}`.

# `left_shift`

# `less`

# `less_equal`

# `linalg_cholesky`

# `linalg_eigh`

# `linalg_inv`

# `linalg_lu`

# `linalg_pinv`

# `linalg_qr`

# `linalg_solve`

# `linalg_solve_triangular`

# `linalg_svd`

# `log1p`

# `log`

# `logical_and`

# `logical_not`

# `logical_or`

# `logical_xor`

# `max`

# `maximum`

# `memory_info`

Returns a map with current memory usage information.

Keys:
  * `:active_memory` - bytes currently allocated and in use
  * `:peak_memory` - highest active memory since startup or the last `reset_peak_memory/0` call
  * `:cache_memory` - bytes held in the allocator cache (freed but not yet returned to the OS)

## Examples

    iex> info = EMLX.memory_info()
    iex> is_integer(info.active_memory) and is_integer(info.peak_memory) and is_integer(info.cache_memory)
    true
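Here is a small formatting helper for the returned map. The module is hypothetical, and it only assumes the three keys listed above.

```elixir
defmodule MemoryReport do
  # Formats the map returned by EMLX.memory_info/0 in MiB.
  def format(%{active_memory: a, peak_memory: p, cache_memory: c}) do
    "active=#{mib(a)} MiB peak=#{mib(p)} MiB cache=#{mib(c)} MiB"
  end

  defp mib(bytes), do: :erlang.float_to_binary(bytes / 1_048_576, decimals: 1)
end
```

For example, `EMLX.memory_info() |> MemoryReport.format() |> IO.puts()` would print the current figures.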

# `min`

# `minimum`

# `multiply`

# `negate`

# `not_equal`

# `ones`

# `pad`

# `pow`

# `product`

# `quantize`

```elixir
@spec quantize(
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()
```

Quantize a dense 2-D `Nx.Tensor` and return an annotated quantized tensor.

The returned tensor carries the original logical shape and type (e.g.
`{:s, 4}`). Its backend stores the packed uint32 data and an
`EMLX.Quantization.Config` with scales, biases, `group_size`, and `bits`.

## Options

* `:type` — storage type: `{:s, 2}`, `{:s, 4}` (default), or `{:s, 8}`.
* `:group_size` — 32, 64, or 128 (default 64). Must evenly divide the last
  dimension of `tensor`.

# `quantize`

Quantizes a floating point tensor to packed format.

Returns a tuple of `{quantized_weights, scales, biases}` where:
  - `quantized_weights` - Packed uint32 tensor (32 / `bits` values per uint32; 8 for 4-bit)
  - `scales` - Per-group scale factors
  - `biases` - Per-group biases (additive offsets applied after scaling)

## Parameters
  - `w` - Float tensor to quantize
  - `group_size` - Number of weights per group (default: 64)
  - `bits` - Quantization bits (default: 4)
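To illustrate the group-wise scheme, the sketch below checks the documented `group_size` constraint. It then quantizes one group with a simple min/max affine rule, `q = round((w - bias) / scale)`. This is a common textbook variant and not necessarily MLX's exact rounding. The module name is hypothetical.

```elixir
defmodule QuantizeSketch do
  # Mirrors the documented :group_size constraint of EMLX.quantize/2.
  def check_group_size!(last_dim, group_size) do
    unless group_size in [32, 64, 128] and rem(last_dim, group_size) == 0 do
      raise ArgumentError,
            "group_size #{group_size} must be 32, 64 or 128 and divide #{last_dim}"
    end

    :ok
  end

  # Simple min/max affine quantization of one group; returns {qs, scale, bias}
  # so that w ≈ scale * q + bias.
  def quantize_group(ws, bits) do
    {lo, hi} = Enum.min_max(ws)
    levels = Integer.pow(2, bits) - 1
    scale = if hi == lo, do: 1.0, else: (hi - lo) / levels
    qs = Enum.map(ws, fn w -> round((w - lo) / scale) end)
    {qs, scale, lo}
  end
end
```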

# `quantized_matmul`

```elixir
@spec quantized_matmul(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Run `activation @ dequantize(qw)` using `mx::quantized_matmul`.

`qw` must be a quantized tensor produced by `EMLX.quantize/2`. Raises
`ArgumentError` if both arguments are quantized.

# `quantized_matmul`

Performs quantized matrix multiplication.

This is the key operation for efficient 4-bit inference. It multiplies `x` with
quantized weights `w` (packed as uint32), using scales and biases for
dequantization during the computation.

## Parameters
  - `x` - Input tensor (e.g. `{batch, seq, hidden}`)
  - `w` - Quantized weights as uint32 (32 / `bits` values packed per uint32; 8 for 4-bit)
  - `scales` - Per-group scale factors (bfloat16)
  - `biases` - Per-group biases, i.e. additive offsets (bfloat16)
  - `transpose` - Whether to transpose weights (default: true)
  - `group_size` - Number of weights per scale/bias group (default: 64)
  - `bits` - Quantization bits (default: 4)
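The expected tensor shapes follow from the packing arithmetic. This sketch assumes MLX's layout for an `{out, in}` weight with `transpose` set to true: packed width `in * bits / 32`, and one scale and one bias per `group_size` inputs. The helper is hypothetical.

```elixir
defmodule QmmShapes do
  # Expected shapes for an {out, in} weight quantized with `bits` and
  # `group_size`, in the transpose: true layout (y = x @ w^T).
  def shapes(out_features, in_features, bits, group_size) do
    %{
      w: {out_features, div(in_features * bits, 32)},
      scales: {out_features, div(in_features, group_size)},
      biases: {out_features, div(in_features, group_size)}
    }
  end
end
```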

# `quotient`

# `real`

# `remainder`

# `reset_peak_memory`

```elixir
@spec reset_peak_memory() :: :ok
```

Resets the peak memory counter to zero.

## Examples

    EMLX.reset_peak_memory()
    #=> :ok

# `reshape`

# `right_shift`

# `round`

# `rsqrt`

# `scalar_tensor`

# `scalar_type`

# `scatter`

# `scatter_add`

# `set_cache_limit`

Sets the cache limit in bytes. Returns the previous limit.

When cached memory exceeds this limit, it will be reclaimed on the next
allocation. Set to 0 to disable caching entirely.

## Examples

    prev = EMLX.set_cache_limit(500_000_000)
    EMLX.set_cache_limit(prev)
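Both `set_cache_limit` and `set_memory_limit` return the previous limit, so a temporary limit can be scoped with `try`/`after`. That way the old limit is restored even if the work raises. Here is a sketch with a hypothetical helper name:

```elixir
defmodule LimitScope do
  # Runs `fun` under a temporary limit and restores the previous one
  # afterwards, even if `fun` raises. `setter` must return the previous limit.
  def with_limit(setter, limit, fun) do
    previous = setter.(limit)

    try do
      fun.()
    after
      setter.(previous)
    end
  end
end
```

For example, `LimitScope.with_limit(&EMLX.set_cache_limit/1, 0, fn -> run_batch() end)`, where `run_batch/0` is a placeholder for your own code.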

# `set_memory_limit`

Sets the memory limit in bytes. Returns the previous limit.

The memory limit is a guideline for maximum memory usage during graph
evaluation. Defaults to 1.5× the device's recommended working set size.

## Examples

    prev = EMLX.set_memory_limit(8_000_000_000)
    EMLX.set_memory_limit(prev)

# `shape`

# `shm_unlink`

Unlinks a POSIX shared-memory segment by its handle name.

Call this if the receiver never opens the `%Nx.Pointer{kind: :ipc}` returned
by `Nx.to_pointer/2` — otherwise the shm name persists until the next reboot.
Safe to call even if the segment has already been unlinked (ENOENT is ignored).

# `sigmoid`

# `sign`

# `sin`

# `sinh`

# `slice`

# `slice_update`

# `sort`

# `sqrt`

# `squeeze`

# `stack`

# `strides`

# `subtract`

# `sum`

# `take`

# `take_along_axis`

# `tan`

# `tanh`

# `tensor_data_ptr`

Returns `{address, byte_size}` for the tensor's raw GPU buffer.

Evaluates the tensor first (same pattern as `to_blob/1`). The pointer is valid
as long as no further MLX evaluation is triggered on the array and the
Elixir tensor term is kept alive. On Apple Silicon the address is accessible
from both CPU and GPU due to unified memory.

# `tensor_to_shm`

Copies tensor data into a new POSIX shared-memory segment and returns
`{shm_name, byte_size}`.

Note: this involves a **memcpy**. MLX arrays are immutable, so zero-copy
cross-process sharing is not possible. `permissions` is a Unix mode integer
(e.g. `0o400` for owner-read-only).

The shm name persists until the receiver opens and unlinks it (which
`EMLX.NIF.array_from_shm/4` does automatically).

# `tensordot`

# `to_blob`

# `to_blob`

# `to_nx`

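Converts an EMLX device ref (a `{device, ref}` pair) back to an `Nx.Tensor`.

## Example

    # `some_operation/1` is a placeholder for any EMLX function that
    # returns a device ref.
    result_ref = EMLX.some_operation(input)
    result_tensor = EMLX.to_nx(result_ref)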

# `transpose`

# `tri_inv`

# `view`

# `where`

# `window_scatter_max`

# `window_scatter_min`

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
