World Model — learns a latent dynamics model of an environment.
Encodes observations into a latent space, predicts next-state transitions given actions, and optionally decodes back to observation space. This is the core component for model-based RL and planning.
Components
- Encoder: `obs → z` - Maps raw observations to latent state
- Dynamics: `(z, action) → next_z` - Predicts next latent state
- Reward head: `z → scalar` - Predicts reward from latent state
- Decoder (optional): `z → obs` - Reconstructs observations
Dynamics Variants
- `:mlp` - Standard two-layer MLP transition
- `:neural_ode` - Shared-weight Euler integration (continuous dynamics)
- `:gru` - Gated recurrent update (good for partially observable envs)
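To make the `:neural_ode` variant concrete, here is a minimal plain-Elixir sketch of shared-weight fixed-step Euler integration: the same function `f` (a stand-in for the dense transition network, not the actual implementation) is applied at every step to integrate dz/dt = f(z, a).

```elixir
# Sketch of the :neural_ode transition. `f` is a hypothetical stand-in
# for the dense network; the real variant integrates in latent space
# with shared weights across steps.
defmodule EulerSketch do
  @steps 4

  def integrate(z, a, f) do
    dt = 1.0 / @steps

    Enum.reduce(1..@steps, z, fn _step, zk ->
      dz = f.(zk, a)
      # z_{k+1} = z_k + dt * f(z_k, a)
      Enum.zip_with(zk, dz, fn zi, di -> zi + dt * di end)
    end)
  end
end

# Toy linear vector field f(z, a) = -z + a
f = fn z, a -> Enum.zip_with(z, a, fn zi, ai -> -zi + ai end) end
EulerSketch.integrate([1.0, 0.0], [0.0, 1.0], f)
# → [0.31640625, 0.68359375]
```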
Architecture
obs [batch, obs_size]
|
+==============+
| Encoder | dense → GELU → dense
+==============+
|
z [batch, latent_size]
|
+-----|-----+
| | |
v v v
Dynamics Reward Decoder (optional)
(z,a)→z'   z→r    z→obs

Returns
{encoder, dynamics, reward_head} or {encoder, dynamics, reward_head, decoder}
when use_decoder: true.
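The three required components compose into a latent "imagination" rollout, the typical model-based-RL use of a world model: encode one observation, then step the dynamics purely in latent space while summing predicted rewards. The sketch below uses plain functions as stand-ins for the Axon networks returned by `build/1` (the names `encode`, `step`, and `reward` are illustrative, not part of the API).

```elixir
# Latent rollout sketch: encode once, then iterate the dynamics model
# without ever decoding back to observation space.
defmodule RolloutSketch do
  def imagine(obs, actions, encode, step, reward) do
    z0 = encode.(obs)

    {_z_final, total_reward} =
      Enum.reduce(actions, {z0, 0.0}, fn a, {z, acc} ->
        z_next = step.(z, a)
        {z_next, acc + reward.(z_next)}
      end)

    total_reward
  end
end

# Toy scalar stand-ins for the encoder, dynamics, and reward head
encode = fn obs -> obs * 2.0 end
step = fn z, a -> z + a end
reward = fn z -> z end
RolloutSketch.imagine(1.0, [1.0, 1.0, 1.0], encode, step, reward)
# → 12.0
```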
Usage
{encoder, dynamics, reward_head} = WorldModel.build(
obs_size: 64,
action_size: 4,
latent_size: 128,
dynamics: :mlp
)
# With decoder for reconstruction loss
{encoder, dynamics, reward_head, decoder} = WorldModel.build(
obs_size: 64,
action_size: 4,
dynamics: :gru,
use_decoder: true
)

References
- Ha & Schmidhuber, "World Models" (2018)
- Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination" (Dreamer, 2020)
- Hafner et al., "Mastering Diverse Domains through World Models" (DreamerV3, 2023)
Summary
Functions
Build all world model components.
Build the observation decoder.
Build the dynamics model.
Build the observation encoder.
Build the reward prediction head.
Get the latent size of the world model.
Types
@type build_opt() ::
        {:obs_size, pos_integer()}
        | {:action_size, pos_integer()}
        | {:latent_size, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:dynamics, :mlp | :neural_ode | :gru}
        | {:use_decoder, boolean()}
Options for build/1.
Functions
@spec build([build_opt()]) :: {Axon.t(), Axon.t(), Axon.t()} | {Axon.t(), Axon.t(), Axon.t(), Axon.t()}
Build all world model components.
Options
- `:obs_size` - Observation dimension (required)
- `:action_size` - Action dimension (required)
- `:latent_size` - Latent state dimension (default: 128)
- `:hidden_size` - Hidden layer size (default: 256)
- `:dynamics` - Dynamics model type: `:mlp`, `:neural_ode`, or `:gru` (default: `:mlp`)
- `:use_decoder` - Include observation decoder (default: false)
Returns
{encoder, dynamics, reward_head} or
{encoder, dynamics, reward_head, decoder} if use_decoder: true.
Build the observation decoder.
Reconstructs observations from latent state: [batch, latent_size] → [batch, obs_size]
Build the dynamics model.
Predicts next latent state from current state and action:
[batch, latent_size + action_size] → [batch, latent_size]
Dynamics Variants
- `:mlp` - Two dense layers with GELU
- `:neural_ode` - Shared-weight Euler integration (4 steps)
- `:gru` - Gated recurrent update
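The gating arithmetic behind the `:gru` variant can be sketched for a single latent unit: an update gate `u` blends the previous latent state with a candidate state. This is a simplified scalar illustration (the real layer is vectorized over `[batch, latent_size]` and computes the candidate from learned weights, which are omitted here).

```elixir
# Scalar sketch of a gated recurrent update. `candidate` stands in for
# the network's proposed next state; only the gating step is shown.
sigmoid = fn x -> 1.0 / (1.0 + :math.exp(-x)) end

gru_step = fn z, candidate, gate_logit ->
  u = sigmoid.(gate_logit)
  # z' = (1 - u) * z + u * candidate
  (1.0 - u) * z + u * candidate
end

gru_step.(0.0, 1.0, 0.0)
# → 0.5 (gate is exactly half-open at logit 0)
```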
Build the observation encoder.
Maps raw observations to latent state: [batch, obs_size] → [batch, latent_size]
Build the reward prediction head.
Predicts scalar reward from latent state: [batch, latent_size] → [batch]
@spec output_size(keyword()) :: pos_integer()
Get the latent size of the world model.