# Axon -> PyTorch

This cheatsheet is designed to assist PyTorch developers in transitioning to Elixir and Axon, providing equivalent commands and code examples for common neural network tasks.

## Core Paradigm: Functional vs. Object-Oriented

A key difference between Axon and PyTorch lies in their core design paradigms:

### Axon (Functional)

Axon follows a functional approach, inspired by libraries like JAX. Models are defined as compositions of functions that transform input data and parameters into output data. State (like model parameters) is managed explicitly and passed into functions. This promotes purity, explicit data flow, and composability, often leveraging Just-In-Time (JIT) compilation via Nx backends and compilers (like EXLA) for performance.

### PyTorch (Object-Oriented)

PyTorch uses an object-oriented approach. Models are typically defined as classes inheriting from torch.nn.Module. These classes encapsulate layers and parameters as internal state. The forward pass is defined as a method (forward) operating on this internal state. This provides a familiar structure for many developers but can sometimes obscure data flow and state management compared to the functional style.

This cheatsheet will highlight how common tasks are achieved in both paradigms.
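To make the contrast concrete, here is a framework-free Python sketch. It uses neither Axon nor PyTorch, and the names `linear_apply` and `StatefulLinear` are hypothetical. The functional version receives its parameters as an argument and never mutates them. The object-oriented version stores them on the instance.

```python
from typing import Dict, List

Params = Dict[str, List[float]]

def linear_apply(params: Params, x: List[float]) -> float:
    """Functional style: parameters come in explicitly; nothing is stored."""
    return sum(w * xi for w, xi in zip(params["w"], x)) + params["b"][0]

class StatefulLinear:
    """Object-oriented style: parameters live inside the object."""

    def __init__(self, w: List[float], b: float) -> None:
        self.w = w
        self.b = b

    def __call__(self, x: List[float]) -> float:
        return sum(w * xi for w, xi in zip(self.w, x)) + self.b

params = {"w": [0.5, -1.0], "b": [2.0]}
x = [4.0, 1.0]
print(linear_apply(params, x))              # 3.0
print(StatefulLinear([0.5, -1.0], 2.0)(x))  # 3.0

# A functional "update" builds new parameters instead of mutating the old ones
new_params = {**params, "b": [params["b"][0] + 1.0]}
print(linear_apply(new_params, x), linear_apply(params, x))  # 4.0 3.0
```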

## Model Definition

### Sequential Models

#### Axon

```elixir
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
```

#### PyTorch

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Softmax(dim=1)
)
```

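## Common Layer Types

### Dense / Linear

Applies a linear transformation to the incoming data: y = xW^T + b.

#### Axon

```elixir
input = Axon.input("features")

# 128 output units with ReLU activation and an explicit layer name
dense_layer = Axon.dense(input, 128, activation: :relu, name: "my_dense_layer")

# Without an activation (linear output)
dense_layer = Axon.dense(input, 128)
```

`Axon.dense` takes only the number of output units and infers the input size from the incoming layer. `nn.Linear` needs both sizes up front.

#### PyTorch

```python
import torch.nn as nn

dense_layer = nn.Linear(in_features=784, out_features=128)
relu = nn.ReLU()
output = relu(dense_layer(x))  # x: tensor of shape (N, 784)
```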
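The formula y = xW^T + b can be computed by hand for a small batch. PyTorch stores `nn.Linear.weight` with shape `(out_features, in_features)`, which is why the transpose appears. The plain-Python sketch below uses hypothetical helper names and follows that layout:

```python
from typing import List

Matrix = List[List[float]]

def linear(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Compute y = x @ weight.T + bias for x of shape (N, in_features).

    weight has shape (out_features, in_features), like nn.Linear.weight.
    """
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]

def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]

x = [[1.0, 2.0, 3.0]]                         # batch of 1, in_features = 3
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # out_features = 2
bias = [0.0, -1.0]
y = linear(x, weight, bias)
print(y)        # [[-2.0, 2.0]]
print(relu(y))  # [[0.0, 2.0]]
```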


### Convolutional (Conv2D)

Applies a 2D convolution over an input signal composed of several input planes.

#### Axon

```elixir
# Example: 32 filters, 3x3 kernel, ReLU activation
x = Axon.input("features")
Axon.conv(x, 32, kernel_size: 3, activation: :relu, padding: :same, name: "conv1")

# Stride, padding, etc., are options:
Axon.conv(x, 64, kernel_size: {3, 3}, strides: 2, padding: :valid)
```

*Note: Axon uses the channels-last NHWC (Batch, Height, Width, Channels) layout by default, as TensorFlow/Keras does. Pass `channels: :first` to use NCHW instead.*

#### PyTorch

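```python
import torch.nn as nn

# Example: 32 filters, 3x3 kernel, ReLU activation (input with 3 channels)
conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding='same')
relu = nn.ReLU()
output = relu(conv1(x))

# Stride, padding, etc., are arguments. padding='same' is only supported with stride=1,
# so strided convolutions need explicit padding (padding=0 is 'valid'):
conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=0)
```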


*Note: PyTorch uses NCHW (Batch, Channels, Height, Width) format by default.*
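Because of this layout difference, images moved between the two libraries need their axes permuted. Going from NHWC to NCHW is the permutation (0, 3, 1, 2). In PyTorch that is `x.permute(0, 3, 1, 2)`, and in Nx it is `Nx.transpose(x, axes: [0, 3, 1, 2])`. The dependency-free sketch below applies that permutation to nested lists. The helper names are hypothetical.

```python
from typing import List

def nhwc_to_nchw(x: List) -> List:
    """Permute a nested-list tensor from (N, H, W, C) to (N, C, H, W)."""
    n, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    return [
        [[[x[i][y][z][ch] for z in range(w)] for y in range(h)] for ch in range(c)]
        for i in range(n)
    ]

def shape(x) -> tuple:
    """Shape of a regular nested list, read along the first element of each level."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)

# One 2x2 image with 3 channels in NHWC; each value encodes 100*row + 10*col + channel
img = [[[[100 * r + 10 * col + ch for ch in range(3)] for col in range(2)] for r in range(2)]]
out = nhwc_to_nchw(img)
print(shape(img), shape(out))  # (1, 2, 2, 3) (1, 3, 2, 2)
print(out[0][2])               # channel 2 as a 2x2 plane: [[2, 12], [102, 112]]
```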

### Pooling (MaxPool2D)

Applies 2D max pooling over an input signal.

#### Axon

```elixir
# Example: 2x2 pool size, stride 2
Axon.max_pool(previous_layer, kernel_size: 2, strides: 2, name: "pool1")

# Padding can also be specified (default is :valid)
Axon.max_pool(previous_layer, kernel_size: {3, 3}, strides: 1, padding: :same)
```


*Note: Operates on NHWC format by default.*

#### PyTorch

```python
import torch.nn as nn

# Example: 2x2 pool size, stride 2
pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

# Padding can also be specified (default is 0, i.e. 'valid').
# MaxPool2d has no padding='same' option. With stride 1 and an odd kernel_size,
# padding=kernel_size // 2 keeps the height and width unchanged.
# (ceil_mode=True only rounds the output size up instead of down; it is not 'same' padding.)
pool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

# Example usage, assuming an input tensor x with shape (N, C, H, W):
output = pool1(x)
```


*Note: Operates on NCHW format by default.*
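Convolution and pooling share the same output-size arithmetic. For PyTorch with dilation 1, each spatial axis gives `floor((H + 2*padding - kernel_size) / stride) + 1`. The small sketch below uses a hypothetical helper name and confirms that `padding=1` with a 3x3 kernel and stride 1 preserves the size:

```python
def pool_output_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Output length along one spatial axis for nn.MaxPool2d / nn.Conv2d
    (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel_size) // stride + 1

print(pool_output_size(28, kernel_size=2, stride=2))             # 14: 2x2, stride 2 halves the size
print(pool_output_size(28, kernel_size=3, stride=1, padding=1))  # 28: size preserved ("same")
print(pool_output_size(28, kernel_size=3, stride=2))             # 13: "valid" 3x3, stride 2
```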

### Dropout

Randomly zeroes some elements of the input tensor with probability `p` during training. This is a regularization technique.

#### Axon

Applies dropout during training (`mode: :train`). It's a no-op during inference (`mode: :inference`, the default).

```elixir
# rate is the probability of an element being zeroed
Axon.dropout(previous_layer, rate: 0.5, name: "dropout1")

# Usage is implicit within the model's structure.
# The mode (:train or :inference) is chosen when the model is built:
{init_fn, predict_fn} = Axon.build(model, mode: :train)

# In :train mode the result is a map holding the prediction and the updated layer state
%{prediction: preds, state: state} = predict_fn.(params, inputs)
```


#### PyTorch

Applies dropout during training (`model.train()` mode). It's a no-op during evaluation (`model.eval()` mode).

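```python
import torch.nn as nn

# p is the probability of an element being zeroed
dropout1 = nn.Dropout(p=0.5)

# Example usage:
dropout1.train()           # training mode: elements are zeroed with probability p
output = dropout1(x)

dropout1.eval()            # evaluation mode: dropout becomes the identity
output_eval = dropout1(x)  # dropout1 has no effect here

# Inside a full model, model.train() / model.eval() switch every submodule at once.
```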
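PyTorch documents that during training the surviving elements are scaled by `1 / (1 - p)`, a scheme known as "inverted dropout". Activations therefore keep the same expected value, and evaluation needs no rescaling. The stdlib sketch below illustrates that rule. The function name and seed are illustrative only.

```python
import random
from typing import List

def dropout(x: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(x)  # identity at evaluation time
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in x]

rng = random.Random(0)
x = [1.0] * 10_000
train_out = dropout(x, p=0.5, training=True, rng=rng)
print(sorted(set(train_out)))           # [0.0, 2.0]: dropped or scaled by 1/(1-0.5)
print(sum(train_out) / len(train_out))  # close to 1.0: expected value preserved
print(dropout(x, p=0.5, training=False, rng=rng) == x)  # True
```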


### Normalization (LayerNorm)

Applies Layer Normalization over a mini-batch of inputs.
Normalizes the activations of the previous layer for each given example independently.

#### Axon

```elixir
# Typically applied to the feature dimension
Axon.layer_norm(previous_layer, name: "layernorm1")

# The normalized axis is set with :channel_index (default -1, the last axis)
Axon.layer_norm(previous_layer, channel_index: -1, epsilon: 1.0e-5)
```


#### PyTorch

Provide the shape of the features to normalize over. This typically means the last dimension(s) of the tensor.

```python
import torch.nn as nn

# Example 1: input (N, features_dim), normalize over features_dim
features_dim = 128
layernorm1 = nn.LayerNorm(features_dim)

# Example 2: input (N, C, H, W), normalize over C, H, W
normalized_shape = [C, H, W]  # needs actual channel, height, width values
layernorm2 = nn.LayerNorm(normalized_shape)

# Example 3: common case in Transformers (input: N, SeqLen, EmbedDim)
embed_dim = 512
layernorm_transformers = nn.LayerNorm(embed_dim)

# Example usage, assuming input x with shape (N, SeqLen, embed_dim):
output = layernorm_transformers(x)
```
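For each example, `nn.LayerNorm(embed_dim)` follows three steps:

1. Subtract the mean over the normalized dimension.
2. Divide by the square root of the biased variance plus `eps`, which defaults to `1e-5`.
3. Apply a learnable scale and shift, initialized to 1 and 0.

The per-vector plain-Python sketch below follows those steps:

```python
import math
from typing import List, Optional

def layer_norm(x: List[float], gamma: Optional[List[float]] = None,
               beta: Optional[List[float]] = None, eps: float = 1e-5) -> List[float]:
    """Normalize one feature vector over its (only) axis."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as in nn.LayerNorm
    gamma = gamma or [1.0] * n                 # learnable scale, initialized to ones
    beta = beta or [0.0] * n                   # learnable shift, initialized to zeros
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

out = layer_norm([1.0, 2.0, 3.0, 4.0])
print([round(v, 4) for v in out])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```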


## Activation Functions

#### Axon

Activations are typically specified as options within layers (like `Axon.dense`) or applied as separate layers in the model definition pipeline.

```elixir
# Option 1: as a layer option
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)

# Option 2: as a separate layer
model =
  Axon.input("input", shape: {nil, 10})
  |> Axon.dense(128)
  |> Axon.softmax()
```

Common activation atoms: `:relu`, `:softmax`, `:sigmoid`, `:tanh`, `:identity`, etc. Custom functions can also be used. `Axon.activation(layer, activation)` applies any of these atoms as a standalone layer.


#### PyTorch

PyTorch also supports a variety of activation functions, including built-in ones and custom implementations.

```python
import torch.nn as nn

relu = nn.ReLU()
softmax = nn.Softmax(dim=1)
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

output = relu(x)
output = softmax(x)
output = sigmoid(x)
output = tanh(x)
```
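For reference, below are the definitions behind `:relu`, `:sigmoid`, `:tanh`, and `:softmax` and the matching `nn` modules. This is a plain-Python sketch. The softmax subtracts the maximum before exponentiating, a standard trick for numerical stability that does not change the result. `nn.Softmax(dim=1)` applies it to each row of a batch.

```python
import math
from typing import List

def relu(x: float) -> float:
    return max(0.0, x)

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits: List[float]) -> List[float]:
    """Softmax over one vector."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

print(relu(-3.0), relu(2.5))               # 0.0 2.5
print(sigmoid(0.0), math.tanh(0.0))        # 0.5 0.0
probs = softmax([1000.0, 1000.0, 999.0])   # a naive exp(1000) would overflow
print([round(p, 4) for p in probs])        # [0.4223, 0.4223, 0.1554]
```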


## Defining Custom Layers/Models

#### Axon

Custom layers and models in Axon are ordinary functions that take and return Axon graphs. `Axon.block/1` wraps such a function so that every application of the returned block reuses the same parameters for that subgraph.

In the two examples below, the first creates separate weights for the two dense layers, while the second uses the same weights for both. The shared layer is applied to its own output, so its input and output sizes must match. For that reason, these examples use a 128-feature input.

```elixir
defmodule MyCustomLayers do
  # No explicit name, so each call outside a block gets its own default layer name
  def dense(x) do
    Axon.dense(x, 128, activation: :relu)
  end

  def block do
    Axon.block(&dense/1)
  end
end

# Usage:
input = Axon.input("input", shape: {nil, 128})

# Separate weights for each dense layer
model = input |> MyCustomLayers.dense() |> MyCustomLayers.dense()

# Shared weights: both applications use the block's single set of parameters
dense_block = MyCustomLayers.block()
model = input |> then(dense_block) |> then(dense_block)
```


#### PyTorch

PyTorch allows for the definition of custom layers and models.

Example:

```python
import torch.nn as nn

class MyCustomLayer(nn.Module):
    def __init__(self):
        super(MyCustomLayer, self).__init__()
        self.dense = nn.Linear(784, 128)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.dense(x))

# Usage:
model = MyCustomLayer()
```
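Counting parameters makes weight sharing easy to check. A dense layer from `n_in` to `n_out` features has `n_in * n_out` weights plus `n_out` biases. The arithmetic below uses a hypothetical helper and assumes 128-feature layers, so one layer can be applied twice. It shows the parameters that sharing saves. In PyTorch, the equivalent of sharing is applying the same module instance more than once in `forward`.

```python
def dense_param_count(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus biases (n_out) of one dense/linear layer."""
    return n_in * n_out + n_out

# Two independent 128 -> 128 layers (separate weights)
separate = 2 * dense_param_count(128, 128)
# One 128 -> 128 layer applied twice (shared weights)
shared = dense_param_count(128, 128)
print(separate, shared)  # 33024 16512
```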


## Model Initialization

Initialization refers to creating the initial set of parameters (weights, biases) for the model.

#### Axon

Model definition is separate from initialization. `Axon.build/2` compiles the model definition and returns an initialization function (`init_fn`) and a prediction function (`predict_fn`).

```elixir
# 1. Define the model
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)

# 2. Build the model to get init_fn
{init_fn, _predict_fn} = Axon.build(model)

# 3. Initialize parameters using an input shape template and a map with optional
#    values for the parameters. The second argument is useful when loading saved
#    parameters; an empty map initializes everything from scratch.
input_template = Nx.template({1, 784}, :f32)
params = init_fn.(input_template, %{})

# params now holds the initialized parameters (e.g., a nested map)
```


#### PyTorch

In PyTorch, basic parameter initialization happens when the model class (an `nn.Module`) is instantiated.
Layers like `nn.Linear` have default initialization schemes (often Kaiming uniform for weights).

```python
import torch
import torch.nn as nn

# 1. Define the model class (or use nn.Sequential)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.layer1(x))

# 2. Instantiate the model: parameters are initialized here
model = SimpleModel()
# model.parameters() now yields tensors with initial values

# Explicit initialization can be done after instantiation if needed
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)  # calls init_weights on every submodule
```
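The two schemes mentioned above differ mainly in their sampling range. Consider a weight matrix with `fan_in` inputs and `fan_out` outputs. Xavier (Glorot) uniform with gain 1 samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). `nn.Linear`'s default Kaiming-uniform setup, which uses `a=sqrt(5)`, works out to a = 1 / sqrt(fan_in). The stdlib sketch below computes both bounds for the 784 -> 128 layer and samples a few Xavier weights. It illustrates the math only and does not call PyTorch.

```python
import math
import random

fan_in, fan_out = 784, 128

xavier_bound = math.sqrt(6.0 / (fan_in + fan_out))  # xavier_uniform_ with gain=1
linear_default_bound = 1.0 / math.sqrt(fan_in)      # nn.Linear's default weight range
print(round(xavier_bound, 4), round(linear_default_bound, 4))  # 0.0811 0.0357

rng = random.Random(42)
weights = [rng.uniform(-xavier_bound, xavier_bound) for _ in range(5)]
print(all(-xavier_bound <= w <= xavier_bound for w in weights))  # True
```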


## Forward Pass / Prediction

#### Axon

Axon's forward pass is defined by the composition of functions.

Example:

```elixir
model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(128, activation: :relu)

# The actual forward pass happens when the built prediction function runs.
# Axon.build/2 defaults to mode: :inference; pass mode: :train only when training
# (e.g., so that dropout is applied).
{init_fn, predict_fn} = Axon.build(model)
predict_fn.(params, inputs)

# One-off alternative that builds and runs in a single call:
Axon.predict(model, params, inputs)
```


#### PyTorch

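PyTorch's forward pass is defined by the `forward` method of the model class. `nn.Sequential` defines it as calling its layers in order.

Example:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Softmax(dim=1)
)

# The actual forward pass happens when the model is called (this invokes forward()).
output = model(x)

# For inference, switch to eval mode and disable gradient tracking:
model.eval()
with torch.no_grad():
    output = model(x)
```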
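Both libraries end up evaluating the same chain of functions. The plain-Python sketch below uses hypothetical names. It runs a tiny dense -> ReLU -> dense -> softmax stack whose parameters live in a dictionary keyed by layer name, outside the model code. This loosely mirrors the functional style, whereas PyTorch keeps the same values inside the module. The layer names and the `{in, out}` kernel layout are illustrative choices.

```python
import math
from typing import Dict, List

Vector = List[float]

def dense(p: Dict[str, list], x: Vector) -> Vector:
    # p["kernel"][i][j] connects input i to output j; p["bias"][j] is added to output j
    return [sum(x[i] * p["kernel"][i][j] for i in range(len(x))) + p["bias"][j]
            for j in range(len(p["bias"]))]

def softmax(z: Vector) -> Vector:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def forward(params: Dict[str, Dict[str, list]], x: Vector) -> Vector:
    h = [max(0.0, v) for v in dense(params["dense_0"], x)]  # dense + ReLU
    return softmax(dense(params["dense_1"], h))              # dense + softmax

params = {
    "dense_0": {"kernel": [[1.0, -1.0], [0.5, 2.0]], "bias": [0.0, 0.0]},
    "dense_1": {"kernel": [[1.0, 0.0], [0.0, 1.0]], "bias": [0.0, 0.0]},
}
probs = forward(params, [2.0, 1.0])
print([round(p, 4) for p in probs])  # [0.9241, 0.0759]
```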


## Loss Functions

#### Axon

Axon provides loss functions in `Axon.Losses`. They take targets first, then predictions (`y_true, y_pred`), and return per-example losses unless a `:reduction` is given. When using `Axon.Loop`, losses are often specified by atoms.

```elixir
# Manual calculation (e.g., in evaluation):
targets = # ...
predictions = # ...
loss_value = Axon.Losses.mean_squared_error(targets, predictions, reduction: :mean)
loss_value = Axon.Losses.categorical_cross_entropy(targets, predictions, reduction: :mean)

# Using Axon.Loop (specify the loss by atom); trainer/3 returns a loop to run later:
loop = Axon.Loop.trainer(model, :mean_squared_error, optimizer)
loop = Axon.Loop.trainer(model, :categorical_cross_entropy, optimizer)
```


#### PyTorch

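PyTorch provides various loss functions as modules. Unlike `Axon.Losses`, they take predictions first, then targets. `nn.CrossEntropyLoss` expects raw logits because it applies log-softmax itself. A model trained with it should therefore not end in `nn.Softmax`.

Example:

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()
output = criterion(y_pred, y_true)

criterion = nn.CrossEntropyLoss()
output = criterion(logits, y_true)  # logits: (N, C) raw scores; y_true: (N,) class indices
```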
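`nn.CrossEntropyLoss` works on raw logits, while a categorical cross-entropy on probabilities expects a softmax output, as in the Axon examples that end in `activation: :softmax`. The sketch below is plain Python with hypothetical helper names. It checks that both routes give the same number when the probabilities are the softmax of the logits. Feeding already-softmaxed outputs to the logits version would apply softmax twice.

```python
import math
from typing import List

def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def cross_entropy_from_logits(logits: List[float], target: int) -> float:
    """Cross-entropy for one example with a class-index target, computed from logits."""
    return -log_softmax(logits)[target]

def cross_entropy_from_probs(probs: List[float], one_hot: List[float]) -> float:
    """Categorical cross-entropy on probabilities and a one-hot target."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))

logits = [2.0, 1.0, 0.1]
probs = [math.exp(v) for v in log_softmax(logits)]
a = cross_entropy_from_logits(logits, target=0)
b = cross_entropy_from_probs(probs, [1.0, 0.0, 0.0])
print(round(a, 4), round(b, 4))  # 0.417 0.417
```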


## Optimizers

#### Axon

Optimizers in the Axon ecosystem typically come from the `Polaris` library. Each optimizer is an `{init_fn, update_fn}` tuple of functions. It can be passed to `Axon.Loop.trainer` or used manually.

```elixir
# Define the optimizer
optimizer = Polaris.Optimizers.sgd(learning_rate: 0.01)
optimizer = Polaris.Optimizers.adam(learning_rate: 0.001)

# Use with Axon.Loop:
loop = Axon.Loop.trainer(model, loss_fn, optimizer)

# Manual update step (simplified; grad_fn stands for your gradient computation):
{optimizer_init_fn, optimizer_update_fn} = optimizer
optimizer_state = optimizer_init_fn.(params)

{grads, loss} = grad_fn.(params, inputs, targets)
{updates, optimizer_state} = optimizer_update_fn.(grads, optimizer_state, params)
params = Polaris.Updates.apply_updates(params, updates)
```


#### PyTorch

PyTorch provides various optimizers.

Example:

```python
import torch
import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
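The update rules themselves do not depend on the library. Below is a scalar sketch of plain SGD (`p <- p - lr * g`) and of the standard Adam rule with bias correction. It uses the commonly cited defaults (beta1 = 0.9, beta2 = 0.999, eps = 1e-8), which are also PyTorch's defaults. It illustrates the math only, not either library's internals. The function names are hypothetical.

```python
import math

def sgd_step(p: float, g: float, lr: float = 0.01) -> float:
    return p - lr * g

def adam_step(p, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * g      # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * g * g  # second-moment estimate
    m_hat = m / (1 - b1 ** t)      # bias corrections
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

print(sgd_step(1.0, g=2.0))  # 0.98
p, m, v = adam_step(1.0, g=2.0, m=0.0, v=0.0, t=1)
print(round(p, 6))  # 0.999: the first Adam step moves by about lr, whatever the gradient scale
```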


## Basic Training Loop

#### Axon

Axon supports manual training loops but provides the `Axon.Loop` module for convenient, high-level training.

```elixir
# High-level approach using Axon.Loop
# (%{} is the initial model state: start from freshly initialized parameters):
model_state =
  Axon.Loop.trainer(model, :categorical_cross_entropy, optimizer)
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)

# Manual loop structure (simplified):
model = # ...
loss_fn = &Axon.Losses.categorical_cross_entropy(&1, &2, reduction: :mean)
{opt_init_fn, opt_update_fn} = Polaris.Optimizers.adam(learning_rate: 0.001)

{init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train)
params = init_fn.(input_template, %{})
opt_state = opt_init_fn.(params)

# Carry {params, opt_state} across batches and epochs
{params, _opt_state} =
  Enum.reduce(1..10, {params, opt_state}, fn _epoch, acc ->
    Enum.reduce(train_data, acc, fn {inputs, targets}, {params, opt_state} ->
      # Differentiate the loss with respect to the parameters (not the inputs)
      {_loss, grads} =
        Nx.Defn.value_and_grad(params, fn params ->
          # In :train mode the model returns %{prediction: ..., state: ...}
          %{prediction: preds} = predict_fn.(params, inputs)
          loss_fn.(targets, preds)
        end)

      {updates, opt_state} = opt_update_fn.(grads, opt_state, params)
      params = Polaris.Updates.apply_updates(params, updates)
      {params, opt_state}
    end)
  end)
```


#### PyTorch

A standard PyTorch training loop involves iterating through data, zeroing gradients, performing a forward pass, calculating loss, performing a backward pass, and stepping the optimizer.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assume model, train_dataloader, loss_fn, and optimizer are defined, e.g.:
# model = YourModel()
# loss_fn = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# train_dataloader = ...

num_epochs = 10
model.train()  # set model to training mode

for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        # inputs, targets = inputs.to(device), targets.to(device)  # optional: move to GPU

        # 1. Zero gradients
        optimizer.zero_grad()

        # 2. Forward pass
        outputs = model(inputs)

        # 3. Calculate loss
        loss = loss_fn(outputs, targets)

        # 4. Backward pass (compute gradients)
        loss.backward()

        # 5. Optimizer step (update weights)
        optimizer.step()

        if batch_idx % 100 == 0:  # print progress
            print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}")
```
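Stripped of framework machinery, both loops repeat the same steps: forward pass, loss, gradient, parameter update. The self-contained sketch below fits y = 2x + 1 with full-batch gradient descent on mean squared error. Hand-derived gradients stand in for `loss.backward()` or `Nx.Defn.value_and_grad`. The data, learning rate, and epoch count are made up for illustration.

```python
from typing import List, Tuple

def train(data: List[Tuple[float, float]], epochs: int = 1000, lr: float = 0.05) -> Tuple[float, float]:
    w, b = 0.0, 0.0  # parameters, held explicitly
    n = len(data)
    for _ in range(epochs):
        # 1. Forward pass over the whole (tiny) dataset
        preds = [w * x + b for x, _ in data]
        # 2. Gradients of the MSE loss, derived by hand:
        #    dL/dw = 2/n * sum((pred - y) * x), dL/db = 2/n * sum(pred - y)
        grad_w = 2.0 / n * sum((p - y) * x for p, (x, y) in zip(preds, data))
        grad_b = 2.0 / n * sum(p - y for p, (_, y) in zip(preds, data))
        # 3. SGD update
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

data = [(x, 2.0 * x + 1.0) for x in [0.0, 1.0, 2.0, 3.0]]
w, b = train(data)
print(round(w, 2), round(b, 2))  # 2.0 1.0
```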

## Model Inspection / Summary

#### Axon

Axon provides helpers in `Axon.Display` to show model summaries, including layer output shapes and parameter counts. They require an input template.

```elixir
# Define model and input template
model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(10)
input_template = Nx.template({1, 784}, :f32)

# Print summary table
Axon.Display.as_table(model, input_template) |> IO.puts()
```


#### PyTorch

Printing a PyTorch model shows its layers. For more detailed summaries including output shapes and parameter counts (similar to Keras' `model.summary()`), use external libraries like `torchinfo`.

```python
import torch
import torch.nn as nn
from torchinfo import summary  # requires: pip install torchinfo

# Assume model is an instantiated nn.Module

# 1. Basic layer structure
print(model)

# 2. Detailed summary with torchinfo (provide the input size)
batch_size = 64
summary(model, input_size=(batch_size, 784))
```
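The parameter counts these summaries report can be checked by hand, because each dense layer contributes `in * out + out`. The sketch below uses a hypothetical helper. It applies the formula to the 784 -> 10 model in the Axon example and to the 784 -> 128 -> 10 model from the start of the cheatsheet:

```python
def dense_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight/kernel plus bias

print(dense_params(784, 10))                           # 7850
print(dense_params(784, 128) + dense_params(128, 10))  # 101770
```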