Summary

Functions

clear_cache() — Clears the MLX memory cache, releasing unused GPU memory back to the system.

default_device() — Returns the default MLX device for this process.

dequantize(qw) — Dequantize a quantized Nx.Tensor (created by EMLX.quantize/2) to a dense float tensor by calling mx::dequantize.

dequantize(tensor_w, tensor_scales, tensor_biases, group_size, bits) — Dequantizes packed weights to floating point.

eval(arg) — Evaluates a (possibly lazy) MLX tensor by routing the work through an EMLX.CommandQueue. Blocks the caller until the worker thread has finished mlx::core::eval/1 for this tensor.

fast_layer_norm(arg1, arg2, arg3, eps) — Fused layer normalisation (mlx::fast::layer_norm).

fast_layer_norm_no_bias(arg1, arg2, eps) — Fused layer normalisation without bias (mlx::fast::layer_norm, weight-only variant).

fast_rms_norm(arg1, arg2, eps) — Fused RMS normalisation (mlx::fast::rms_norm).

fast_rope(arg, dims, traditional, base, scale, offset) — Fused rotary position embedding (mlx::fast::rope).

fast_rope_ids(arg1, dims, traditional, base, scale, arg2) — Fused RoPE with a per-batch offset array (mlx::fast::rope, array-offset overload).

fast_rope_positions(arg1, dims, traditional, base, scale, arg2) — Fused RoPE for arbitrary per-token position_ids.

fast_rope_with_freqs(arg1, dims, traditional, scale, arg2, arg3) — Fused RoPE with precomputed inv-frequency vector (mlx::fast::rope, freqs overload).

fast_sdpa(arg1, arg2, arg3, scale) — Flash-attention style SDPA, no mask (mlx::fast::scaled_dot_product_attention).

fast_sdpa_causal(arg1, arg2, arg3, scale) — Flash-attention SDPA with built-in causal mask (mlx::fast::scaled_dot_product_attention, mask_mode="causal").

fast_sdpa_causal_key_masked(arg1, arg2, arg3, scale, arg4, kv_offset) — Causal SDPA with a runtime key_mask check performed at the C++ level.

fast_sdpa_masked(arg1, arg2, arg3, arg4, scale) — Flash-attention SDPA with an additive or boolean mask.

fast_swiglu(arg1, arg2) — Fused SwiGLU activation: silu(gate) * up where silu(x) = x * sigmoid(x).

item(arg) — Returns the scalar value of a 0-d tensor as a number.

kv_cache_attention(arg1, arg2, arg3, arg4, arg5, offset, scale) — Fused KV cache update + variable-length SDPA in a single Metal command buffer.

kv_cache_attention_masked(arg1, arg2, arg3, arg4, arg5, offset, scale, arg6) — Like kv_cache_attention/7 but also applies key_mask to exclude padding positions from attention in both prefill and decode steps.

kv_cache_sdpa_update(arg1, arg2, arg3, arg4, arg5, offset, scale) — Fused KV cache update + SDPA for the native NIF loop (BNHD layout).

memory_info() — Returns a map with current memory usage information.

quantize(tensor, opts) — Quantize a dense 2-D Nx.Tensor and return an annotated quantized tensor.

quantize(arg, group_size, bits) — Quantizes a floating point tensor to packed format.

quantized_matmul(activation, qw) — Run activation @ dequantize(qw) using mx::quantized_matmul.

reset_peak_memory() — Resets the peak memory counter to zero.

set_cache_limit(limit) — Sets the cache limit in bytes. Returns the previous limit.

set_memory_limit(limit) — Sets the memory limit in bytes. Returns the previous limit.

shm_unlink(name) — Unlinks a POSIX shared-memory segment by its handle name.

tensor_data_ptr(tensor) — Returns {address, byte_size} for the tensor's raw GPU buffer.

tensor_to_shm(tensor, permissions) — Copies tensor data into a new POSIX shared-memory segment and returns {shm_name, byte_size}.

to_nx(device_ref) — Converts an EMLX device ref back to an Nx.Tensor.

Functions

abs(tensor)

acos(tensor)

acosh(tensor)

add(tensorA, tensorB)

all(tensor, axes, keep_axes)

allclose(tensorA, tensorB, rtol, atol, equal_nan)

any(tensor, axes, keep_axes)

arange(start, stop, step, integer?, device)

argmax(tensor, keep_axes)

argmax(tensor, axes, keep_axes)

argmin(tensor, keep_axes)

argmin(tensor, axes, keep_axes)

argsort(tensor, axis)

as_strided(tensor, shape, strides, offset)

asin(tensor)

asinh(tensor)

astype(tensor, type)

atan2(tensorA, tensorB)

atan(tensor)

atanh(tensor)

bitwise_and(tensorA, tensorB)

bitwise_not(tensor)

bitwise_or(tensorA, tensorB)

bitwise_xor(tensorA, tensorB)

broadcast_to(tensor, shape)

ceil(tensor)

clear_cache()

@spec clear_cache() :: :ok

Clears the MLX memory cache, releasing unused GPU memory back to the system.

Useful after inference batches to prevent memory growth. Does not affect tensors that are still referenced.

Examples

EMLX.clear_cache()
#=> :ok

clip(tensor, tensor_min, tensor_max)

concatenate(tensors, axis)

conjugate(tensor)

conv_general(tensor_input, tensor_kernel, strides, padding_low, padding_high, kernel_dilation, input_dilation, feature_group_count)

cos(tensor)

cosh(tensor)

cumulative_max(tensor, axis, reverse, inclusive)

cumulative_min(tensor, axis, reverse, inclusive)

cumulative_product(tensor, axis, reverse, inclusive)

cumulative_sum(tensor, axis, reverse, inclusive)

deallocate(tensor_ref)

default_device()

Returns the default MLX device for this process.

Reads :default_device from the :emlx application environment, falling back to :gpu. Override in tests or config via:

Application.put_env(:emlx, :default_device, :cpu)

dequantize(qw)

@spec dequantize(Nx.Tensor.t()) :: Nx.Tensor.t()

Dequantize a quantized Nx.Tensor (created by EMLX.quantize/2) to a dense float tensor by calling mx::dequantize.

dequantize(tensor_w, tensor_scales, tensor_biases, group_size, bits)

Dequantizes packed weights to floating point.

Converts quantized weights back to their original floating point representation. Useful for debugging and verification.

Parameters

  • w - Quantized weights as uint32 (packed int4 values)
  • scales - Per-group scale factors
  • biases - Per-group zero points
  • group_size - Number of weights per group (default: 64)
  • bits - Quantization bits (default: 4)
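
A minimal round-trip sketch pairing this call with quantize/3; w is an assumed 2-D float weight tensor, and 4-bit reconstruction is approximate rather than exact:

{qw, scales, biases} = EMLX.quantize(w, 64, 4)
w_approx = EMLX.dequantize(qw, scales, biases, 64, 4)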

divide(tensorA, tensorB)

einsum(tensorA, tensorB, spec_string)

equal(tensorA, tensorB)

erf(tensor)

erf_inv(tensor)

eval(arg)

Evaluates a (possibly lazy) MLX tensor by routing the work through an EMLX.CommandQueue. Blocks the caller until the worker thread has finished mlx::core::eval/1 for this tensor.

Resolves the queue via resolve_worker/1:

  1. If the calling process has bound a queue with EMLX.CommandQueue.with_queue/2, that queue is used.
  2. Otherwise the application-default worker for the tensor's device (CPU or GPU) is used — see EMLX.Application.
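
A minimal sketch of pinning evaluation to a specific queue; how queue itself is started depends on your supervision tree and is assumed here:

EMLX.CommandQueue.with_queue(queue, fn ->
  EMLX.eval(tensor)
end)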

exp(tensor)

expm1(tensor)

eye(m, n, type, device)

fast_layer_norm(arg1, arg2, arg3, eps)

Fused layer normalisation (mlx::fast::layer_norm).

  • x — input tensor; normalised over the last axis.
  • weight — {hidden} scale vector (gamma).
  • bias — {hidden} bias vector (beta).
  • eps — numerical stability constant (e.g. 1.0e-5).

Prefer EMLX.Fast.layer_norm/4 inside defn.
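
A minimal defn sketch, assuming EMLX.Fast.layer_norm/4 mirrors the (x, weight, bias, eps) argument order of this NIF:

import Nx.Defn

defn norm(x, weight, bias) do
  EMLX.Fast.layer_norm(x, weight, bias, 1.0e-5)
end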

fast_layer_norm_no_bias(arg1, arg2, eps)

Fused layer normalisation without bias (mlx::fast::layer_norm, weight-only variant).

Prefer EMLX.Fast.layer_norm/3 inside defn.

fast_rms_norm(arg1, arg2, eps)

Fused RMS normalisation (mlx::fast::rms_norm).

Single Metal shader. Normalises over the last axis of x and scales by weight. Output shape and type match x.

Prefer EMLX.Fast.rms_norm/3 inside defn; call this directly only from eager (non-defn) code.

fast_rope(arg, dims, traditional, base, scale, offset)

Fused rotary position embedding (mlx::fast::rope).

Single Metal shader. Applies RoPE with a scalar position offset.

  • a — input {B, ..., T, D}
  • dims — number of feature dims to rotate (≤ last-axis size, must be even)
  • traditional — false for split-half (Qwen3); true for interleaved
  • base — angular frequency base (e.g. 10_000 or 1_000_000)
  • scale — position scale (1.0 unless using NTK-aware scaling)
  • offset — integer token position (length of KV cache already filled)

Prefer EMLX.Fast.rope/6 inside defn.

fast_rope_ids(arg1, dims, traditional, base, scale, arg2)

Fused RoPE with a per-batch offset array (mlx::fast::rope, array-offset overload).

Calls the const array& offset overload of mlx::fast::rope, where offset has shape {B} — one starting position per batch example. Positions within each example are assumed to be sequential: [offset[b], offset[b]+1, ..., offset[b]+T-1].

Typically you build offset by slicing position_ids[:, 0] (first token's position for each batch example) before calling this function.

Prefer EMLX.Fast.rope_with_positions/6 inside defn.
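
A minimal sketch of that slice in Nx, assuming position_ids is a {B, T} integer tensor:

offset = position_ids[[.., 0]]   # first position of each batch example, shape {B}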

fast_rope_positions(arg1, dims, traditional, base, scale, arg2)

Fused RoPE for arbitrary per-token position_ids.

Uses full {B, T} position IDs (not just an offset) and mirrors Bumblebee's per-token cos/sin lookup, avoiding the sequential-offset assumption of fast_rope_ids.

  • a — input tensor {B, T, H, D}.
  • dims — number of feature dims to rotate.
  • traditional — currently only false is supported.
  • base — angular frequency base (e.g. 1_000_000).
  • scale — position scale.
  • position_ids — {B, T} integer tensor.

fast_rope_with_freqs(arg1, dims, traditional, scale, arg2, arg3)

Fused RoPE with precomputed inv-frequency vector (mlx::fast::rope, freqs overload).

  • a — input tensor; shape {B, T, ..., D}.
  • dims — number of feature dims to rotate.
  • traditional — false for split-half (Qwen3/Bumblebee); true for interleaved.
  • scale — position scale (typically 1.0 when using precomputed freqs).
  • offset — {B} per-batch starting position tensor.
  • freqs — {dims/2} precomputed inverse-frequency tensor.

Prefer EMLX.Fast.rope_with_freqs/6 inside defn.

fast_sdpa(arg1, arg2, arg3, scale)

Flash-attention style SDPA, no mask (mlx::fast::scaled_dot_product_attention).

GQA-native: k/v may have fewer heads than q — no pre-tiling needed.

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — scalar (typically 1 / sqrt(D))

Prefer EMLX.Fast.scaled_dot_product_attention/4 inside defn.
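
A minimal defn sketch via the preferred wrapper, with illustrative GQA shapes (q {1, 8, T, 64}, k/v {1, 2, T, 64}):

import Nx.Defn

defn attend(q, k, v) do
  # 0.125 = 1 / sqrt(64) for head_dim D = 64
  EMLX.Fast.scaled_dot_product_attention(q, k, v, 0.125)
end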

fast_sdpa_causal(arg1, arg2, arg3, scale)

Flash-attention SDPA with built-in causal mask (mlx::fast::scaled_dot_product_attention, mask_mode="causal").

MLX constructs the upper-triangular causal mask internally — no explicit mask tensor required. GQA-native: k/v may have fewer heads than q.

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — scalar (typically 1 / sqrt(D))

Prefer EMLX.Fast.scaled_dot_product_attention_causal/4 inside defn.

fast_sdpa_causal_key_masked(arg1, arg2, arg3, scale, arg4, kv_offset)

Causal SDPA with a runtime key_mask check performed at the C++ level.

At the NIF level: evaluates all(key_mask == 1) (cheap for small/constant tensors). If true → uses the pure causal Metal kernel (no mask allocation). If false → builds a combined causal + key_mask additive float mask and calls the masked kernel.

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — scalar (typically 1 / sqrt(D))
  • key_mask — {B, T_kv} boolean/int tensor (1 = attend, 0 = padding)

Prefer EMLX.Fast.scaled_dot_product_attention_causal_key_masked/5 inside defn.
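
A minimal sketch for building key_mask from per-example valid lengths; lengths is an assumed {B} integer tensor, and b/t_kv are the corresponding dimensions:

# 1 for positions before each example's length, 0 for padding
key_mask = Nx.less(Nx.iota({b, t_kv}, axis: 1), Nx.reshape(lengths, {b, 1}))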

fast_sdpa_masked(arg1, arg2, arg3, arg4, scale)

Flash-attention SDPA with an additive or boolean mask.

mask must be broadcast-compatible with {B, N_q, T_q, T_kv}. Boolean false entries are treated as -∞.

Prefer EMLX.Fast.scaled_dot_product_attention/5 inside defn.

fast_swiglu(arg1, arg2)

Fused SwiGLU activation: silu(gate) * up where silu(x) = x * sigmoid(x).

  • gate — gate tensor; silu is applied element-wise.
  • up — up-projection tensor; same shape as gate.

Output has the same shape and dtype as gate.

Prefer EMLX.Fast.swiglu/2 inside defn.
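
For reference, an unfused plain-Nx equivalent of what the kernel computes (eager sketch, not a replacement for the fused path):

silu = Nx.multiply(gate, Nx.sigmoid(gate))
out = Nx.multiply(silu, up)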

fft2(tensor, s, axes)

fft(tensor, n, axis)

floor(tensor)

from_blob(blob, shape, type, device)

full(value, shape, type, device)

gather(tensor, indices, axes, slice_sizes)

greater(tensorA, tensorB)

greater_equal(tensorA, tensorB)

ifft2(tensor, s, axes)

ifft(tensor, n, axis)

imag(tensor)

is_infinity(tensor)

is_nan(tensor)

is_tensor(device, ref) (macro)

isclose(tensorA, tensorB, rtol, atol, equal_nan)

item(arg)

Returns the scalar value of a 0-d tensor as a number.

Worker-routed: the NIF body calls mlx::core::eval(*t) and t->item<T>(), both of which require running on the OS thread that owns the tensor's stream encoder.

kv_cache_attention(arg1, arg2, arg3, arg4, arg5, offset, scale)

Fused KV cache update + variable-length SDPA in a single Metal command buffer.

Receives tensors in Bumblebee {B, T, N, D} convention. Internally transposes to MLX {B, N, T, D} for mlx::fast::scaled_dot_product_attention, then transposes the result back. Returns a 3-tuple of EMLX {device, ref} pairs.

  • q — {B, T_q, N_q, D} post-RoPE query
  • new_k — {B, T_new, N_kv, D} current key projection (post-RoPE)
  • new_v — {B, T_new, N_kv, D} current value projection
  • k_cache — {B, T_max, N_kv, D} preallocated key buffer
  • v_cache — {B, T_max, N_kv, D} preallocated value buffer
  • offset — integer, number of positions already in cache
  • scale — float, 1 / sqrt(head_dim)

Returns {{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}.
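
A minimal decode-step sketch; all inputs are {device, ref} pairs in the Bumblebee layout above, the variable names are illustrative, and the offset increment assumes T_new = 1:

{attn, k_cache, v_cache} =
  EMLX.kv_cache_attention(q, new_k, new_v, k_cache, v_cache, offset, scale)

offset = offset + 1  # one new token appended per decode step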

kv_cache_attention_masked(arg1, arg2, arg3, arg4, arg5, offset, scale, arg6)

Like kv_cache_attention/7 but also applies key_mask to exclude padding positions from attention in both prefill and decode steps.

  • key_mask — {B, T_kv} integer or boolean tensor with 1 = attend, 0 = skip (padding). Must cover exactly valid_len = offset + T_new positions.

The combined additive mask applies causal AND key_mask constraints without calling mlx::core::all(), avoiding Metal sort-kernel compilation issues for small tensor shapes.

Returns {{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}.

kv_cache_sdpa_update(arg1, arg2, arg3, arg4, arg5, offset, scale)

Fused KV cache update + SDPA for the native NIF loop (BNHD layout).

Accepts q, new_k, new_v already transposed to {B, N, T, D} and a pre-allocated cache in {B, N_kv, T_max, D} layout.

Internally, the cache arrays are move-extracted from their ENIF resources before slice_update so that MLX's donation optimisation fires at eval time: the existing Metal buffer is reused in-place — no new allocation.

Returns {{dev, attn_ref}, {dev, k_upd_ref}, {dev, v_upd_ref}}.

left_shift(tensorA, tensorB)

less(tensorA, tensorB)

less_equal(tensorA, tensorB)

linalg_cholesky(tensor, upper)

linalg_eigh(tensor, uplo)

linalg_inv(tensor)

linalg_lu(tensor)

linalg_pinv(tensor)

linalg_qr(tensor)

linalg_solve(tensorA, tensorB)

linalg_solve_triangular(tensorA, tensorB, upper)

linalg_svd(tensor, compute_uv)

log1p(tensor)

log(tensor)

logical_and(tensorA, tensorB)

logical_not(tensor)

logical_or(tensorA, tensorB)

logical_xor(tensorA, tensorB)

max(tensor, axes, keep_axes)

maximum(tensorA, tensorB)

memory_info()

Returns a map with current memory usage information.

Keys:

  • :active_memory - bytes currently allocated and in use
  • :peak_memory - highest active memory since last reset
  • :cache_memory - bytes in the allocator cache (freed but not returned to OS)

Examples

iex> info = EMLX.memory_info()
iex> is_integer(info.active_memory) and is_integer(info.peak_memory) and is_integer(info.cache_memory)
true

min(tensor, axes, keep_axes)

minimum(tensorA, tensorB)

multiply(tensorA, tensorB)

negate(tensor)

not_equal(tensorA, tensorB)

ones(shape, type, device)

pad(tensor, axes, low_pad_size, high_pad_size, tensor_pad_value)

pow(tensorA, tensorB)

product(tensor, axes, keep_axes)

quantize(tensor, opts)

@spec quantize(Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Quantize a dense 2-D Nx.Tensor and return an annotated quantized tensor.

The returned tensor carries the original logical shape and type (e.g. {:s, 4}). Its backend stores the packed uint32 data and an EMLX.Quantization.Config with scales, biases, group_size, and bits.

Options

  • :type — storage type: {:s, 2}, {:s, 4} (default), or {:s, 8}.
  • :group_size — 32, 64, or 128 (default 64). Must evenly divide the last dimension of tensor.
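
A minimal sketch; the weight shape is illustrative, and its last dimension (256) is evenly divisible by the group size:

w = Nx.broadcast(Nx.tensor(0.5, type: :f32), {128, 256})
qw = EMLX.quantize(w, type: {:s, 4}, group_size: 64)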

quantize(arg, group_size, bits)

Quantizes a floating point tensor to packed format.

Returns a tuple of {quantized_weights, scales, biases} where:

  • quantized_weights - Packed uint32 tensor (8 int4 values per uint32)
  • scales - Per-group scale factors
  • biases - Per-group zero points

Parameters

  • w - Float tensor to quantize
  • group_size - Number of weights per group (default: 64)
  • bits - Quantization bits (default: 4)

quantized_matmul(activation, qw)

@spec quantized_matmul(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Run activation @ dequantize(qw) using mx::quantized_matmul.

qw must be a quantized tensor produced by EMLX.quantize/2. Raises ArgumentError if both arguments are quantized.
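
Continuing the quantize/2 sketch above, with x assumed shape-compatible with dequantize(qw):

out = EMLX.quantized_matmul(x, qw)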

quantized_matmul(tensor_x, tensor_w, tensor_scales, tensor_biases, transpose \\ true, group_size \\ 64, bits \\ 4)

Performs quantized matrix multiplication.

This is the key operation for efficient 4-bit inference. It multiplies x with quantized weights w (packed as uint32), using scales and biases for dequantization during the computation.

Parameters

  • x - Input tensor (e.g., {batch, seq, hidden})
  • w - Quantized weights as uint32 (8 int4 values packed per uint32)
  • scales - Per-group scale factors (bfloat16)
  • biases - Per-group zero points (bfloat16)
  • transpose - Whether to transpose weights (default: true)
  • group_size - Number of weights per scale/bias group (default: 64)
  • bits - Quantization bits (default: 4)
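
A minimal sketch wiring quantize/3 output straight into this call, relying on the documented defaults for transpose, group_size, and bits:

{qw, scales, biases} = EMLX.quantize(w, 64, 4)
out = EMLX.quantized_matmul(x, qw, scales, biases)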

quotient(tensorA, tensorB)

real(tensor)

remainder(tensorA, tensorB)

reset_peak_memory()

@spec reset_peak_memory() :: :ok

Resets the peak memory counter to zero.

Examples

EMLX.reset_peak_memory()
#=> :ok

reshape(tensor, shape)

right_shift(tensorA, tensorB)

round(tensor)

rsqrt(tensor)

scalar_tensor(scalar, type, device)

scalar_type(tensor)

scatter(tensor, indices, tensor_updates, axes)

scatter_add(tensor, indices, tensor_updates, axes)

set_cache_limit(limit)

Sets the cache limit in bytes. Returns the previous limit.

When cached memory exceeds this limit, it will be reclaimed on the next allocation. Set to 0 to disable caching entirely.

Examples

prev = EMLX.set_cache_limit(500_000_000)
EMLX.set_cache_limit(prev)

set_memory_limit(limit)

Sets the memory limit in bytes. Returns the previous limit.

The memory limit is a guideline for maximum memory usage during graph evaluation. Defaults to 1.5× the device's recommended working set size.

Examples

prev = EMLX.set_memory_limit(8_000_000_000)
EMLX.set_memory_limit(prev)

shape(tensor)

shm_unlink(name)

Unlinks a POSIX shared-memory segment by its handle name.

Call this if the receiver never opens the %Nx.Pointer{kind: :ipc} returned by Nx.to_pointer/2 — otherwise the shm name persists until the next reboot. Safe to call even if the segment has already been unlinked (ENOENT is ignored).
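
A minimal cleanup sketch, assuming Nx.to_pointer/2 returns {:ok, %Nx.Pointer{}} and that the shm name lives in its handle field:

{:ok, pointer} = Nx.to_pointer(tensor, mode: :ipc)
# the receiver never opened the segment, so release it ourselves:
EMLX.shm_unlink(pointer.handle)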

sigmoid(tensor)

sign(tensor)

sin(tensor)

sinh(tensor)

slice(tensor, starts, stops, strides)

slice_update(tensor, tensor_updates, starts, stops)

sort(tensor, axis)

sqrt(tensor)

squeeze(tensor, axes)

stack(tensors, axis)

strides(tensor)

subtract(tensorA, tensorB)

sum(tensor, axes, keep_axes)

take(tensor, tensorIndices, axis)

take_along_axis(tensor, tensorIndices, axis)

tan(tensor)

tanh(tensor)

tensor_data_ptr(tensor)

Returns {address, byte_size} for the tensor's raw GPU buffer.

Evaluates the tensor first (same pattern as to_blob/1). The pointer is valid as long as no further MLX evaluation is triggered on the array and the Elixir tensor term is kept alive. On Apple Silicon the address is accessible from both CPU and GPU due to unified memory.

tensor_to_shm(tensor, permissions)

Copies tensor data into a new POSIX shared-memory segment and returns {shm_name, byte_size}.

Note: this involves a memcpy — MLX arrays are immutable so zero-copy cross-process sharing is not possible. permissions is a Unix mode integer (e.g. 0o400 for owner-read-only).

The shm name persists until the receiver opens and unlinks it (which EMLX.NIF.array_from_shm/4 does automatically).

tensordot(tensorA, tensorB, axesA, axesB)

to_blob(tensor)

to_blob(tensor, limit)

to_nx(device_ref)

Converts an EMLX device ref back to an Nx.Tensor.

Example

result_ref = EMLX.some_operation(input)
result_tensor = EMLX.to_nx(result_ref)

transpose(tensor, axes)

tri_inv(tensor, upper)

view(tensor, type)

where(tensorPred, tensorTrue, tensorFalse)

window_scatter_max(tensor_t, tensor_source, tensor_init_value, window, low_pad, high_pad, strides)

window_scatter_min(tensor_t, tensor_source, tensor_init_value, window, low_pad, high_pad, strides)