Nasty.Statistics.Neural.Model behaviour (Nasty v0.3.0)
Behaviour for neural network models using Axon.
Extends Nasty.Statistics.Model with neural-specific callbacks for
architecture definition, tensor handling, and efficient inference.
Model Lifecycle
- Architecture Definition: Define the Axon model structure
- Training: Train on labeled data with backpropagation
- Serialization: Save model parameters and metadata
- Loading: Restore model from disk
- Inference: Predict on new data with efficient batching
Example
    defmodule MyNeuralTagger do
      @behaviour Nasty.Statistics.Neural.Model

      @impl true
      def model_architecture(opts) do
        vocab_size = Keyword.fetch!(opts, :vocab_size)
        num_tags = Keyword.fetch!(opts, :num_tags)

        # Axon.lstm/3 returns {output_sequence, state}; keep the per-token outputs.
        {sequence, _state} =
          Axon.input("tokens", shape: {nil, nil})
          |> Axon.embedding(vocab_size, 128)
          |> Axon.lstm(256)

        Axon.dense(sequence, num_tags)
      end

      @impl true
      def input_shape(_model), do: {nil, nil}

      @impl true
      def output_shape(model), do: {nil, nil, model.num_tags}
    end

Integration with Existing Models
Neural models implement the standard Nasty.Statistics.Model behaviour,
so they can be used interchangeably with HMM and other statistical models.
Summary
Callbacks
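input_shape(model): Returns the expected input shape for the model.
model_architecture(opts): Returns the Axon model architecture.
output_shape(model): Returns the expected output shape for the model.
postprocess_output(model, output, input, opts): Post-processes model output into predictions.
prepare_input(model, input, opts): Prepares input data for model inference.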
Functions
Validates that a module correctly implements the Neural.Model behaviour.
Callbacks
input_shape(model)
Returns the expected input shape for the model.
Shapes use nil for dynamic dimensions such as batch size and sequence length.
Examples
# Token-level input
iex> input_shape(model)
{nil, nil} # {batch_size, seq_length}
# Character-level input
iex> input_shape(model)
{nil, nil, 50} # {batch_size, seq_length, char_length}
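To illustrate how a nil dimension acts as a wildcard, here is a minimal sketch in plain Elixir. ShapeSketch and compatible?/2 are hypothetical names used only for illustration; they are not part of Nasty, Axon, or Nx.

```elixir
defmodule ShapeSketch do
  # Hypothetical helper, not part of Nasty: checks a concrete shape against
  # a template in which nil matches any size along that dimension.
  def compatible?(template, shape)
      when tuple_size(template) == tuple_size(shape) do
    template
    |> Tuple.to_list()
    |> Enum.zip(Tuple.to_list(shape))
    |> Enum.all?(fn {expected, actual} ->
      is_nil(expected) or expected == actual
    end)
  end

  # Different ranks (or non-tuple arguments) never match.
  def compatible?(_template, _shape), do: false
end

ShapeSketch.compatible?({nil, nil}, {32, 12})
ShapeSketch.compatible?({nil, nil, 50}, {8, 20, 40})
```

Under this reading, the first call accepts any batch of any sequence length, while the second rejects the shape because the fixed last dimension (50) does not match.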
model_architecture(opts)
Returns the Axon model architecture.
Parameters
- opts: Architecture options (vocab_size, num_tags, hidden_size, etc.)
Returns
An %Axon{} struct defining the model architecture.
Examples
iex> model_architecture(vocab_size: 10000, num_tags: 17)
%Axon{...}
output_shape(model)
Returns the expected output shape for the model.
Examples
iex> output_shape(model)
{nil, nil, 17} # {batch_size, seq_length, num_tags}
@callback postprocess_output(model :: struct(), output :: Nx.Tensor.t(), input :: term(), opts :: keyword()) :: {:ok, term()} | {:error, term()}
Post-processes model output into predictions.
Converts raw model output tensors (logits, probabilities) into structured predictions (tags, labels, etc.).
Parameters
- model: The trained model
- output: Raw model output (tensor)
- input: Original input (for alignment)
- opts: Post-processing options
Returns
- {:ok, predictions}: Structured predictions
- {:error, reason}: Post-processing error
Examples
iex> postprocess_output(model, logits_tensor, ["The", "cat"], [])
{:ok, [:det, :noun]}
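The decoding step can be sketched in plain Elixir, with nested lists standing in for the logits tensor so that the example runs without dependencies. PostprocessSketch, the tag inventory, and the score values are all hypothetical; a real implementation would work on the Nx tensor and wrap the result in {:ok, predictions}.

```elixir
defmodule PostprocessSketch do
  # Returns the index of the largest score in a non-empty list.
  def argmax(scores) do
    scores
    |> Enum.with_index()
    |> Enum.max_by(fn {score, _index} -> score end)
    |> elem(1)
  end

  # Picks one tag per token. Zipping with the original tokens drops any
  # trailing rows that only exist because of padding.
  def decode(logits, tokens, tags) do
    logits
    |> Enum.zip(tokens)
    |> Enum.map(fn {scores, _token} -> Enum.at(tags, argmax(scores)) end)
  end
end

# Hypothetical tag inventory and scores; the last row represents padding.
tags = [:det, :noun, :verb]
logits = [[2.1, 0.3, -1.0], [0.2, 1.7, 0.4], [0.0, 0.0, 0.0]]

predictions = PostprocessSketch.decode(logits, ["The", "cat"], tags)
# predictions is [:det, :noun]
```

This is why the callback receives the original input: its length tells the decoder how many positions are real tokens. With Nx, the same idea would likely use Nx.argmax/2 along the last axis followed by Nx.to_list/1.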
@callback prepare_input(model :: struct(), input :: term(), opts :: keyword()) :: {:ok, map()} | {:error, term()}
Prepares input data for model inference.
Converts raw input (tokens, text, etc.) into tensors suitable for the neural network. Handles padding, vocabulary mapping, and batching.
Parameters
- model: The trained model
- input: Raw input data (list of words, tokens, etc.)
- opts: Preprocessing options
Returns
- {:ok, tensors}: Map of input tensors keyed by input name
- {:error, reason}: Preprocessing error
Examples
iex> prepare_input(model, ["The", "cat", "sat"], [])
{:ok, %{"tokens" => #Nx.Tensor<s64[1][3]>}}
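The vocabulary mapping and padding steps described above can be sketched in plain Elixir. PrepareSketch, the padding id 0, the unknown-word id 1, and the toy vocabulary are assumptions made for illustration; the ids Nasty actually uses may differ.

```elixir
defmodule PrepareSketch do
  # Hypothetical ids: 0 pads short sequences, 1 stands for unknown words.
  @pad_id 0
  @unk_id 1

  # Maps each word to its vocabulary id, falling back to the unknown id.
  def encode(tokens, vocab) do
    Enum.map(tokens, &Map.get(vocab, &1, @unk_id))
  end

  # Right-pads every id sequence to the length of the longest one in the batch.
  def pad_batch(sequences) do
    max_len =
      Enum.reduce(sequences, 0, fn ids, acc -> max(acc, length(ids)) end)

    Enum.map(sequences, fn ids ->
      ids ++ List.duplicate(@pad_id, max_len - length(ids))
    end)
  end
end

# Toy vocabulary for illustration only.
vocab = %{"The" => 2, "cat" => 3, "sat" => 4}

batch =
  [["The", "cat", "sat"], ["The", "dog"]]
  |> Enum.map(&PrepareSketch.encode(&1, vocab))
  |> PrepareSketch.pad_batch()

# batch is [[2, 3, 4], [2, 1, 0]]
```

The rectangular list of ids could then be turned into a tensor, for example with Nx.tensor(batch, type: :s64), and returned as {:ok, %{"tokens" => tensor}} so that the key matches the name given to Axon.input/2.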