ACT: Action Chunking with Transformers for robot imitation learning.
ACT predicts entire action chunks (sequences of future actions) rather than single timestep actions. This drastically reduces compounding error in imitation learning: small mistakes don't propagate across every timestep. A Conditional VAE (CVAE) encoder captures task variation (e.g. "pick from left" vs "pick from right"), and a Transformer decoder autoregressively generates the action chunk conditioned on image observations and the latent style variable z.
Motivation
Standard behavior cloning predicts a_t from o_t at each timestep, so errors compound over the trajectory. ACT instead predicts [a_t, a_{t+1}, ..., a_{t+chunk_size-1}] in one shot and executes only the first few actions before re-planning. The CVAE latent z captures multimodal action distributions (different valid ways to perform a task).
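A minimal sketch of that predict-a-chunk, execute-a-prefix, re-plan loop at inference time, using the decoder built by this module. Here get_observation/0, execute!/1, decoder_params, the 10-step horizon, and the 50 re-plans are hypothetical placeholders:

{_encoder, decoder} = ACT.build(obs_dim: 512, action_dim: 7, chunk_size: 100, latent_dim: 32)
{_decoder_init, decoder_predict} = Axon.build(decoder)

horizon = 10
key = Nx.Random.key(0)

Enum.reduce(1..50, key, fn _replan, key ->
  obs = get_observation()                              # hypothetical env hook, [1, 512]
  {z, key} = Nx.Random.normal(key, shape: {1, 32})     # z ~ N(0, I) at inference
  chunk = decoder_predict.(decoder_params, %{"obs" => obs, "z" => z})

  chunk
  |> Nx.slice_along_axis(0, horizon, axis: 1)          # keep only the first `horizon` actions
  |> execute!()                                        # hypothetical env hook

  key
end)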
Architecture
Training:

    obs [batch, obs_dim]     actions [batch, chunk_size, action_dim]
           |                              |
           +--------------+---------------+
                          |
                 CVAE Encoder (MLP)
                          |
                     mu, log_var
                          |
          reparameterize -> z [batch, latent_dim]
                          |
           +--------------+---------------+
           |                              |
    obs [batch, obs_dim]                  z
           |                              |
           +-----> Transformer Decoder <--+
             (cross-attn on obs, autoregressive)
                          |
        pred_actions [batch, chunk_size, action_dim]

Inference:

    z ~ N(0, I)   (no encoder needed)
    obs -> Transformer Decoder -> action chunk

Usage
{encoder, decoder} = ACT.build(
  obs_dim: 512,
  action_dim: 7,
  chunk_size: 100,
  latent_dim: 32
)

# Compile each graph with Axon to get predict functions
# (encoder_params/decoder_params come from the corresponding init functions)
{encoder_init, encoder_predict} = Axon.build(encoder)
{decoder_init, decoder_predict} = Axon.build(decoder)

# Training: encode actions to get latent, decode to reconstruct
%{mu: mu, log_var: log_var} = encoder_predict.(encoder_params, %{"obs" => obs, "actions" => actions})
{z, key} = ACT.reparameterize(mu, log_var, key)
pred_actions = decoder_predict.(decoder_params, %{"obs" => obs, "z" => z})

# Loss (MSE reconstruction + KL)
loss = ACT.act_loss(pred_actions, target_actions, mu, log_var)

# Inference: sample z from the prior (Nx.Random.normal returns {tensor, new_key})
{z, key} = Nx.Random.normal(key, shape: {batch, latent_dim})
pred_actions = decoder_predict.(decoder_params, %{"obs" => obs, "z" => z})

References
- Zhao et al., "Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware" (2023) — https://arxiv.org/abs/2304.13705
- ALOHA project: https://tonyzhaozh.github.io/aloha/
Summary
Functions
- ACT loss: MSE reconstruction + beta * KL divergence.
- Build the ACT model (CVAE encoder + Transformer decoder).
- Build the Transformer decoder.
- Build the CVAE encoder.
- Transformer decoder forward: (obs, z) -> action_chunk.
- CVAE encoder forward pass: observation + actions -> (mu, log_var).
- Get output size (action_dim * chunk_size flattened, or action_dim per step).
- Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).
Types
@type build_opt() ::
        {:obs_dim, pos_integer()}
        | {:action_dim, pos_integer()}
        | {:chunk_size, pos_integer()}
        | {:hidden_dim, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:latent_dim, pos_integer()}
        | {:dropout, float()}
Options for build/1.
Functions
@spec act_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
ACT loss: MSE reconstruction + beta * KL divergence.
Parameters
- pred_actions - Predicted action chunk [batch, chunk_size, action_dim]
- target_actions - Ground truth actions [batch, chunk_size, action_dim]
- mu - Encoder mean [batch, latent_dim]
- log_var - Encoder log variance [batch, latent_dim]
Options
- :beta - KL weight (default: 1.0)
Returns
Scalar loss tensor.
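For orientation, a minimal sketch of what a loss of this shape computes in plain Nx. The dummy shapes and beta value are illustrative, and the exact reduction and defaults inside act_loss may differ:

# Illustrative shapes: batch 2, chunk_size 100, action_dim 7, latent_dim 32
pred_actions = Nx.broadcast(0.1, {2, 100, 7})
target_actions = Nx.broadcast(0.0, {2, 100, 7})
mu = Nx.broadcast(0.0, {2, 32})
log_var = Nx.broadcast(0.0, {2, 32})
beta = 1.0

# MSE reconstruction term
recon = Nx.mean(Nx.pow(Nx.subtract(pred_actions, target_actions), 2))

# KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over the batch
kl =
  log_var
  |> Nx.add(1.0)
  |> Nx.subtract(Nx.pow(mu, 2))
  |> Nx.subtract(Nx.exp(log_var))
  |> Nx.sum(axes: [-1])
  |> Nx.multiply(-0.5)
  |> Nx.mean()

loss = Nx.add(recon, Nx.multiply(beta, kl))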
Build the ACT model (CVAE encoder + Transformer decoder).
Options
- :obs_dim - Observation feature dimension (required)
- :action_dim - Action dimension per timestep (required)
- :chunk_size - Number of future actions to predict (default: 100)
- :hidden_dim - Transformer hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 8)
- :num_layers - Number of decoder layers (default: 6)
- :latent_dim - CVAE latent dimension (default: 32)
- :dropout - Dropout rate (default: 0.1)
Returns
{encoder, decoder} tuple:
- Encoder: inputs "obs" and "actions" -> %{mu: ..., log_var: ...}
- Decoder: inputs "obs" and "z" -> [batch, chunk_size, action_dim]
Build the Transformer decoder.
Takes observation features and latent z, outputs action chunk.
Options
Same as build/1.
Returns
Axon model with inputs "obs" [batch, obs_dim] and "z"
[batch, latent_dim], outputting [batch, chunk_size, action_dim].
Build the CVAE encoder.
Maps (observation, action_sequence) to a latent distribution (mu, log_var).
Options
Same as build/1.
Returns
Axon model with inputs "obs" [batch, obs_dim] and "actions"
[batch, chunk_size, action_dim], outputting %{mu: ..., log_var: ...}.
Transformer decoder forward: (obs, z) -> action_chunk.
Convenience wrapper that creates an Axon subgraph for decoding.
Parameters
- obs - Axon node for observation [batch, obs_dim]
- z - Axon node for latent [batch, latent_dim]
- opts - Same options as build/1
Returns
Axon node outputting [batch, chunk_size, action_dim].
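A composition sketch for this wrapper. The name ACT.decode/3 below is a hypothetical placeholder, since the wrapper's actual function name is not shown in this rendering:

# Wire the decoder subgraph into a larger Axon graph.
# `ACT.decode/3` is a hypothetical name for the wrapper described above.
obs = Axon.input("obs", shape: {nil, 512})
z = Axon.input("z", shape: {nil, 32})
action_chunk = ACT.decode(obs, z, obs_dim: 512, action_dim: 7, chunk_size: 100, latent_dim: 32)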
CVAE encoder forward pass: observation + actions -> (mu, log_var).
Convenience wrapper that builds and runs the encoder.
Parameters
- obs - Observation tensor [batch, obs_dim]
- actions - Action sequence tensor [batch, chunk_size, action_dim]
- opts - Same options as build/1
Returns
{mu, log_var} tuple of tensors, each [batch, latent_dim].
@spec output_size(keyword()) :: pos_integer()
Get output size (action_dim * chunk_size flattened, or action_dim per step).
@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).
Computes z = mu + eps * exp(0.5 * log_var) where eps ~ N(0, I).
Parameters
- mu - Mean [batch, latent_dim]
- log_var - Log variance [batch, latent_dim]
- key - PRNG key
Returns
{z, new_key} tuple.
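For reference, the same computation written in plain Nx (a minimal sketch, not the library source):

# mu and log_var are [batch, latent_dim] tensors from the encoder
key = Nx.Random.key(42)
{eps, new_key} = Nx.Random.normal(key, shape: Nx.shape(mu))          # eps ~ N(0, I)
z = Nx.add(mu, Nx.multiply(eps, Nx.exp(Nx.multiply(0.5, log_var))))  # z = mu + eps * sigma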