Nasty.Statistics.Neural.Inference (Nasty v0.3.0)

Efficient inference utilities for neural models.

Provides optimized prediction with:

  • Batch processing for multiple inputs
  • Dynamic batching for variable-length sequences
  • Model warmup and JIT compilation
  • Result caching
  • EXLA acceleration

Example

# Single prediction
{:ok, tags} = Inference.predict(model, state, ["The", "cat", "sat"], [])

# Batch prediction
sentences = [
  ["The", "cat", "sat"],
  ["A", "dog", "ran"],
  ["Birds", "fly"]
]
{:ok, all_tags} = Inference.predict_batch(model, state, sentences, [])

Performance Tips

  1. Use batch prediction when possible for better throughput
  2. Enable EXLA compilation for 10-100x speedup
  3. Warm up the model on first use to trigger JIT compilation
  4. Use consistent batch sizes when possible

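The tips above can be combined in a typical serving setup. A minimal sketch, assuming a trained `model`/`state` pair, a `sample` input with the correct shape, and the `sentences` list from the example above:

```elixir
# 1. Warm up once at startup to trigger JIT compilation (tip 3).
:ok = Inference.warmup(model, state, sample, compiler: :exla)

# 2. Prefer one batched call over many single calls (tips 1 and 4);
#    a fixed :batch_size avoids recompilation for new shapes.
{:ok, all_tags} =
  Inference.predict_batch(model, state, sentences,
    batch_size: 32,
    compiler: :exla
  )
```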
Summary

Functions

Runs inference on a single input.

Runs inference on a batch of inputs efficiently.

Streams predictions for large datasets.

Warms up a model by running a dummy prediction.

Functions

predict(model, state, input, opts \\ [])

@spec predict(Axon.t(), map(), term(), keyword()) :: {:ok, term()} | {:error, term()}

Runs inference on a single input.

Parameters

  • model - Axon model
  • state - Trained model state (parameters)
  • input - Input data (will be batched automatically)
  • opts - Inference options

Options

  • :compiler - Backend compiler: :exla or :blas (default: :exla)
  • :mode - Execution mode: :train or :inference (default: :inference)

Returns

  • {:ok, output} - Model prediction
  • {:error, reason} - Inference error
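A minimal call sketch handling both return shapes, assuming a trained `model`/`state` pair:

```elixir
case Inference.predict(model, state, ["The", "cat", "sat"], compiler: :exla) do
  {:ok, tags} ->
    # One prediction for the single input.
    IO.inspect(tags, label: "predicted tags")

  {:error, reason} ->
    IO.puts("inference failed: #{inspect(reason)}")
end
```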

predict_batch(model, state, inputs, opts \\ [])

@spec predict_batch(Axon.t(), map(), [map()], keyword()) ::
  {:ok, [term()]} | {:error, term()}

Runs inference on a batch of inputs efficiently.

All inputs in the batch must have the same structure (same keys). For variable-length sequences, padding will be applied automatically.

Parameters

  • model - Axon model
  • state - Trained model state
  • inputs - List of input maps
  • opts - Inference options

Options

  • :batch_size - Process in batches of this size (default: 32)
  • :compiler - Backend compiler (default: :exla)
  • :pad_value - Value to use for padding (default: 0)

Returns

  • {:ok, outputs} - List of predictions (one per input)
  • {:error, reason} - Inference error
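A sketch of batched prediction over variable-length inputs, assuming a trained `model`/`state` pair; `:batch_size` and `:pad_value` are the options documented above:

```elixir
sentences = [
  ["The", "cat", "sat"],
  ["A", "dog", "ran"],
  ["Birds", "fly"]
]

# Shorter sequences are padded with :pad_value before batching.
{:ok, outputs} =
  Inference.predict_batch(model, state, sentences, batch_size: 2, pad_value: 0)

# One prediction per input, in the original order:
# length(outputs) == length(sentences)
```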

stream_predict(model, state, input_stream, opts \\ [])

@spec stream_predict(Axon.t(), map(), Enumerable.t(), keyword()) :: Enumerable.t()

Streams predictions for large datasets.

Processes inputs in batches and yields results as a stream, avoiding loading all results into memory at once.

Parameters

  • model - Axon model
  • state - Trained model state
  • input_stream - Stream of input maps
  • opts - Streaming options

Returns

A stream of predictions.

Example

File.stream!("large_dataset.txt")
|> Stream.map(&prepare_input/1)
|> then(&Inference.stream_predict(model, state, &1, batch_size: 64))
|> Stream.map(&postprocess_output/1)
|> Enum.take(100)

Note that piping directly into stream_predict/4 would place the stream in the first argument position, which does not match the model-first signature; then/2 routes it to the input_stream argument.

warmup(model, state, sample_input, opts \\ [])

@spec warmup(Axon.t(), map(), map(), keyword()) :: :ok | {:error, term()}

Warms up a model by running a dummy prediction.

This triggers JIT compilation and caches the compiled function, making subsequent predictions faster.

Parameters

  • model - Axon model
  • state - Trained model state
  • sample_input - Sample input with correct shape
  • opts - Warmup options

Returns

  • :ok - Warmup completed
  • {:error, reason} - Warmup failed
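A sketch of warmup at application startup, assuming a trained `model`/`state` pair; the sample only needs to have the same shape as real inputs:

```elixir
sample = ["dummy", "tokens"]

case Inference.warmup(model, state, sample) do
  :ok ->
    # Subsequent predictions reuse the cached compiled function.
    IO.puts("model warmed up")

  {:error, reason} ->
    IO.puts("warmup failed: #{inspect(reason)}")
end
```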