# Axon -> PyTorch
This cheatsheet is designed to assist PyTorch developers in transitioning to Elixir and Axon, providing equivalent commands and code examples for common neural network tasks.
## Core Paradigm: Functional vs. Object-Oriented
A key difference between Axon and PyTorch lies in their core design paradigms:
#### Axon (Functional)
Axon follows a functional approach, inspired by libraries like JAX. Models are defined as compositions of functions that transform input data and parameters into output data. State (like model parameters) is managed explicitly and passed into functions. This promotes purity, explicit data flow, and composability, often leveraging Just-In-Time (JIT) compilation via Nx backends and compilers (like EXLA) for performance.
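To make the functional flow concrete, here is a minimal sketch (the input shape and layer size are illustrative, not tied to any later example) of the define / build / initialize / predict cycle that the rest of this cheatsheet expands on:

```elixir
# Define the model as a pure description of the computation
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)

# Compile the description into an init function and a predict function
{init_fn, predict_fn} = Axon.build(model)

# Parameters are explicit state: created by the caller and passed into every call
params = init_fn.(Nx.template({1, 784}, :f32), %{})
output = predict_fn.(params, Nx.broadcast(0.0, {1, 784}))
```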
#### PyTorch (Object-Oriented)
PyTorch uses an object-oriented approach. Models are typically defined as classes inheriting from `torch.nn.Module`, which encapsulate layers and parameters as internal state. The forward pass is defined as a method (`forward`) operating on this internal state. This provides a familiar structure for many developers but can sometimes obscure data flow and state management compared to the functional style.
This cheatsheet will highlight how common tasks are achieved in both paradigms.
## Model Definition
### Sequential Models
#### Axon
model =
Axon.input("input", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)

#### PyTorch
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10),
nn.Softmax(dim=1)
)

## Common Layer Types
### Dense / Linear
Applies a linear transformation to the incoming data: y = xW^T + b.
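For intuition, here is the same computation written directly with Nx (a framework-neutral sketch with made-up shapes; the kernel here is stored as {in, out}, so the product is x · W rather than x · W^T):

```elixir
x = Nx.iota({1, 3}, type: :f32) # one example with 3 input features
w = Nx.iota({3, 2}, type: :f32) # kernel mapping 3 inputs to 2 outputs
b = Nx.tensor([0.1, 0.2])       # bias, one entry per output unit

y = x |> Nx.dot(w) |> Nx.add(b) # shape {1, 2}
```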
#### Axon
input = Axon.input("features")
General form (out_features is the number of output units):
dense_layer = Axon.dense(input, out_features, activation: :relu, name: "my_dense_layer")
Concrete example with 128 output units:
dense_layer = Axon.dense(input, 128)
#### PyTorch

dense_layer = nn.Linear(in_features=784, out_features=128)
relu = nn.ReLU()
output = relu(dense_layer(x))
### Convolutional (Conv2D)
Applies a 2D convolution over an input signal composed of several input planes.
#### Axon

Example: 32 filters, 3x3 kernel, ReLU activation
x = Axon.input("features")
Axon.conv(x, 32, kernel_size: 3, activation: :relu, padding: :same, name: "conv1")
Stride, padding, etc., are options:
Axon.conv(x, 64, kernel_size: {3, 3}, strides: 2, padding: :valid)
*Note: Axon typically uses NHWC (Batch, Height, Width, Channels) format by default, common in TensorFlow/Keras.*
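For example, a brief sketch of a small NHWC convolutional stack (the 28x28x1 input shape is an assumption for illustration):

```elixir
model =
  Axon.input("images", shape: {nil, 28, 28, 1})
  |> Axon.conv(32, kernel_size: 3, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: 2, strides: 2)
```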
#### PyTorch
Example: 32 filters, 3x3 kernel, ReLU activation
conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding='same')
relu = nn.ReLU()
output = relu(conv1(x))
Stride, padding, etc., are arguments:
conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=0) # padding=0 is 'valid'
*Note: PyTorch uses NCHW (Batch, Channels, Height, Width) format by default.*
### Pooling (MaxPool2D)
Applies 2D max pooling over an input signal.
#### Axon

Example: 2x2 pool size, stride 2
Axon.max_pool(previous_layer, kernel_size: 2, strides: 2, name: "pool1")
Padding can also be specified (default is :valid)
Axon.max_pool(previous_layer, kernel_size: {3, 3}, strides: 1, padding: :same)
*Note: Operates on NHWC format by default.*
#### PyTorch
import torch.nn as nn
Example: 2x2 pool size, stride 2
pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
Padding can also be specified (default is 0 / 'valid')
To get 'same'-style output sizes, set the padding manually or use ceil_mode=True
pool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) # padding=1 with a 3x3 kernel and stride 1 preserves the spatial size
Example usage:
Assuming input tensor x with shape (N, C, H, W)
output = pool1(x)
*Note: Operates on NCHW format by default.*
### Dropout
Randomly zeroes some elements of the input tensor with probability `p` during training. This is a regularization technique.
#### Axon
Applies dropout during training (`mode: :train`). It's a no-op during inference (`mode: :inference`).
Rate is the probability of an element being zeroed.
Axon.dropout(previous_layer, rate: 0.5, name: "dropout1")
Usage is implicit within the model's structure.
The mode (:train or :inference) is passed to Axon.build/2 when compiling the model.
{init_fn, predict_fn} = Axon.build(model, mode: :train)
predict_fn.(params, inputs)
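A minimal sketch contrasting the two modes (the layer sizes and the constant input are illustrative assumptions):

```elixir
model =
  Axon.input("input", shape: {nil, 8})
  |> Axon.dense(8, activation: :relu)
  |> Axon.dropout(rate: 0.5)

{init_fn, train_fn} = Axon.build(model, mode: :train)
{_init_fn, infer_fn} = Axon.build(model, mode: :inference)

params = init_fn.(Nx.template({1, 8}, :f32), %{})
input = Nx.broadcast(1.0, {1, 8})

train_fn.(params, input) # dropout active: a random subset of activations is zeroed
infer_fn.(params, input) # dropout is a no-op: deterministic output
```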
#### PyTorch
Applies dropout during training (`model.train()` mode). It's a no-op during evaluation (`model.eval()` mode).
p is the probability of an element being zeroed.
dropout1 = nn.Dropout(p=0.5)
Example usage:
dropout1.train() # Set training mode (or model.train() when dropout1 is a submodule of model)
output = dropout1(x)
dropout1.eval() # Set evaluation mode (dropout becomes the identity)
output_eval = dropout1(x) # dropout1 has no effect here
### Normalization (LayerNorm)
Applies Layer Normalization over a mini-batch of inputs.
Normalizes the activations of the previous layer for each given example independently.
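Concretely, for each example both frameworks compute y = (x - E[x]) / sqrt(Var[x] + epsilon) * gamma + beta over the normalized dimensions, where gamma and beta are learned per-feature scale and shift parameters.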
#### Axon

Typically applied to the feature dimension(s).
Axon.layer_norm(previous_layer, name: "layernorm1")
Can specify the axis/axes for normalization (default is usually the last axis)
Axon.layer_norm(previous_layer, axis: -1, epsilon: 1.0e-5)
#### PyTorch

Provide the shape of the features to normalize over.
This typically means the last dimension(s) of the tensor.
Example 1: Input (N, features_dim), normalize over features_dim
features_dim = 128
layernorm1 = nn.LayerNorm(features_dim)
Example 2: Input (N, C, H, W), normalize over C, H, W:
normalized_shape = [C, H, W] # Needs actual channel, height, width values
layernorm2 = nn.LayerNorm(normalized_shape)
Example 3: Common case in Transformers (Input: N, SeqLen, EmbedDim):
embed_dim = 512
layernorm_transformers = nn.LayerNorm(embed_dim)
Example usage (assuming input x with shape (N, SeqLen, embed_dim)):
output = layernorm_transformers(x)
## Activation Functions
#### Axon
Activations are typically specified as options within layers (like `Axon.dense`) or applied as separate layers in the model definition pipeline.
Option 1: As layer option
model = Axon.input("input", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)

Option 2: As separate layer
model = Axon.input("input", shape: {nil, 10})
|> Axon.dense(128)
|> Axon.softmax()

Common activation atoms: :relu, :softmax, :sigmoid, :tanh, :identity, etc.
Custom functions can also be used.
`Axon.activation(layer, :relu)` can also be used to apply an activation to an existing layer given its atom name.
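For example, a brief sketch using `Axon.activation/2` with activation atoms (layer sizes are illustrative):

```elixir
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128)
  |> Axon.activation(:relu)
  |> Axon.dense(10)
  |> Axon.activation(:softmax)
```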
#### PyTorch
PyTorch also supports a variety of activation functions, including built-in ones and custom implementations.
relu = nn.ReLU()
softmax = nn.Softmax(dim=1)
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

output = relu(x)
output = softmax(x)
output = sigmoid(x)
output = tanh(x)
## Defining Custom Layers/Models
#### Axon
Axon allows for the definition of custom layers and models.
`Axon.block/1`, shown below, lets you reuse the same parameters for an arbitrary Axon subgraph.
The difference between the two usage examples below is that the first creates separate weights for the first and second dense layers, while the second applies the same block twice and therefore shares the same weights for both.
Example:
defmodule MyCustomLayers do
  def dense(x) do
    Axon.dense(x, 128, activation: :relu)
  end

  def block do
    Axon.block(&dense/1)
  end
end
Usage:
input = Axon.input("input", shape: {nil, 784})
Separate parameters for each dense layer:
model = input |> MyCustomLayers.dense() |> MyCustomLayers.dense()
Shared parameters across both applications of the block:
dense_block = MyCustomLayers.block()
model = input |> then(dense_block) |> then(dense_block)
#### PyTorch
PyTorch allows for the definition of custom layers and models.
Example:
class MyCustomLayer(nn.Module):
    def __init__(self):
        super(MyCustomLayer, self).__init__()
        self.dense = nn.Linear(784, 128)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.dense(x))

Usage:
model = MyCustomLayer()
## Model Initialization
Initialization refers to creating the initial set of parameters (weights, biases) for the model.
#### Axon
Model definition is separate from initialization. `Axon.build/2` compiles the model definition and returns an initialization function (`init_fn`) and a prediction function (`predict_fn`).
1. Define the model
model = Axon.input("input", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)

2. Build the model to get init_fn
{init_fn, _predict_fn} = Axon.build(model)
3. Initialize parameters using an input shape template and a map with optional values for the parameters.
The second argument is useful when loading saved parameters.
input_template = Nx.template({1, 784}, :f32)
params = init_fn.(input_template, %{})
params now holds the initialized parameters (e.g., a nested map)
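For example, assuming `saved_params` holds parameters loaded from a previous run (how they were persisted is outside the scope of this sketch), passing them as the second argument reuses those values instead of re-initializing them:

```elixir
# saved_params is assumed to come from a previous training run
params = init_fn.(input_template, saved_params)
```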
#### PyTorch
In PyTorch, basic parameter initialization happens when the model class (an `nn.Module`) is instantiated.
Layers like `nn.Linear` have default initialization schemes (often Kaiming uniform for weights).
import torch
import torch.nn as nn
1. Define the model class (or use nn.Sequential)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.layer1(x))

2. Instantiate the model - parameters are initialized here
model = SimpleModel()
model.parameters() now holds tensors with initial values
Explicit initialization can be done after instantiation if needed
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)
## Forward Pass / Prediction
#### Axon
Axon's forward pass is defined by the composition of functions.
Example:
model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(128, activation: :relu)
The actual forward pass happens during the execution of the model.
mode: :inference (the default) is used when not training the model.
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
predict_fn.(params, inputs)
#### PyTorch
PyTorch's forward pass is defined by the `forward` method of the model class.
Example:
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10),
nn.Softmax(dim=1))
The actual forward pass happens during the execution of the model.
output = model(x)
## Loss Functions
#### Axon
Axon provides loss functions in `Axon.Losses` that take targets and predictions. When using `Axon.Loop`, losses are often specified by atoms.
Manual Calculation (e.g., in evaluation):
targets = # ...
predictions = # ...
loss_value = Axon.Losses.mean_squared_error(targets, predictions)
loss_value = Axon.Losses.categorical_cross_entropy(targets, predictions)
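A small worked sketch with made-up tensors; the `reduction: :mean` option is assumed here to collapse the per-example losses into a single scalar:

```elixir
targets = Nx.tensor([[0.0, 1.0], [1.0, 0.0]])
predictions = Nx.tensor([[0.1, 0.9], [0.8, 0.2]])

Axon.Losses.categorical_cross_entropy(targets, predictions, reduction: :mean)
```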
Using Axon.Loop (specify loss by atom):
loop = Axon.Loop.trainer(model, :mean_squared_error, optimizer)
loop = Axon.Loop.trainer(model, :categorical_cross_entropy, optimizer)
#### PyTorch
PyTorch provides various loss functions.
Example:
import torch
import torch.nn as nn

criterion = nn.MSELoss()
output = criterion(y_pred, y_true)

criterion = nn.CrossEntropyLoss()
output = criterion(y_pred, y_true)
## Optimizers
#### Axon
Optimizers in the Axon ecosystem typically come from the `Polaris` library. They are passed to `Axon.Loop.trainer` or used manually with an update function.
Import the optimizers
import Polaris.Optimizers
Define the optimizer
optimizer = Polaris.Optimizers.sgd(learning_rate: 0.01)
optimizer = Polaris.Optimizers.adam(learning_rate: 0.001)
Use with Axon.Loop:
loop = Axon.Loop.trainer(model, loss_fn, optimizer)
model_state = Axon.Loop.run(loop, train_data, %{}, epochs: 10)
Manual update step (simplified); a Polaris optimizer is an {init_fn, update_fn} tuple:
{init_optimizer_fn, update_fn} = optimizer
optimizer_state = init_optimizer_fn.(params)
{loss, grads} = grad_fn.(params, inputs, targets)
{updates, optimizer_state} = update_fn.(grads, optimizer_state, params)
params = Polaris.Updates.apply_updates(params, updates)
#### PyTorch
PyTorch provides various optimizers.
Example:
import torch
import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer = optim.Adam(model.parameters(), lr=0.001)
## Basic Training Loop
#### Axon
Axon supports manual training loops but provides the `Axon.Loop` module for convenient, high-level training.
High-level approach using Axon.Loop:
model_state =
  Axon.Loop.trainer(model, :categorical_cross_entropy, optimizer)
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)

Manual loop structure (simplified):
model = # ...
input_template = # ... Nx.template matching the model's input
train_data = # ... enumerable of {inputs, targets} batches
loss_fn = fn targets, preds ->
  Axon.Losses.categorical_cross_entropy(targets, preds, reduction: :mean)
end

{init_optimizer_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 0.001)
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train)

params = init_fn.(input_template, %{})
opt_state = init_optimizer_fn.(params)

for _epoch <- 1..10, reduce: {params, opt_state} do
  {params, opt_state} ->
    Enum.reduce(train_data, {params, opt_state}, fn {inputs, targets}, {params, opt_state} ->
      # Compute the loss and its gradient with respect to the parameters
      {{_preds, _loss}, grads} =
        Nx.Defn.value_and_grad(
          params,
          fn params ->
            # In :train mode the model returns predictions alongside updated layer state
            %{prediction: preds} = predict_fn.(params, inputs)
            {preds, loss_fn.(targets, preds)}
          end,
          fn {_preds, loss} -> loss end
        )

      # Scale the gradients with the optimizer and apply the resulting updates
      {updates, opt_state} = update_fn.(grads, opt_state, params)
      params = Polaris.Updates.apply_updates(params, updates)
      {params, opt_state}
    end)
end
#### PyTorch
A standard PyTorch training loop involves iterating through data, zeroing gradients, performing a forward pass, calculating loss, performing a backward pass, and stepping the optimizer.
import torch
import torch.nn as nn
import torch.optim as optim
Assume: model, train_dataloader, loss_fn, optimizer are defined
model = YourModel()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_dataloader = ...
num_epochs = 10
model.train() # Set model to training mode
for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        # inputs, targets = inputs.to(device), targets.to(device) # Optional: move to GPU

        # 1. Zero gradients
        optimizer.zero_grad()

        # 2. Forward pass
        outputs = model(inputs)

        # 3. Calculate loss
        loss = loss_fn(outputs, targets)

        # 4. Backward pass (compute gradients)
        loss.backward()

        # 5. Optimizer step (update weights)
        optimizer.step()

        if batch_idx % 100 == 0: # Print progress
            print(f"Epoch {epoch}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}")
## Model Inspection / Summary
#### Axon
Axon provides helpers in `Axon.Display` to show model summaries, including layer output shapes and parameter counts. It requires an input template.
Define model and input template
model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(10)
input_template = Nx.template({1, 784}, :f32)
Print summary table
Axon.Display.as_table(model, input_template) |> IO.puts()
#### PyTorch
Printing a PyTorch model shows its layers. For more detailed summaries including output shapes and parameter counts (similar to Keras' `model.summary()`), use external libraries like `torchinfo`.
import torch
import torch.nn as nn
Assume model is an instantiated nn.Module
1. Basic layer structure
print(model)
2. Using torchinfo (requires pip install torchinfo)
from torchinfo import summary

batch_size = 64
summary(model, input_size=(batch_size, 784)) # Provide input size