Multimodal Fusion Layers for Vision-Language Models.
Provides two fusion approaches for combining visual and text features:
1. MLP Projection (LLaVA-style)
The dominant approach in 2025-2026 VLMs. A pre-trained vision encoder's output tokens are projected through an MLP into the language model's embedding space and concatenated with text tokens.
Visual tokens [batch, num_patches, vision_dim]
|
[MLP: Linear -> GELU -> Linear]
|
Projected [batch, num_patches, llm_dim]
|
Concatenate with text tokens
|
[batch, num_patches + text_len, llm_dim]

Used by: LLaVA, InternVL, Qwen-VL, PaliGemma, DeepSeek-VL
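A minimal Axon sketch of this path, using illustrative dimensions (576 visual patches, a 1024-dim ViT, a 4096-dim LLM); the input names, layer names, and sizes are assumptions for the example, not the library's actual layer graph:

# Sketch only: project visual tokens with a 2-layer MLP, then concatenate
# them with the text embeddings along the sequence axis.
visual = Axon.input("visual_tokens", shape: {nil, 576, 1024})
text = Axon.input("text_embeddings", shape: {nil, 32, 4096})

projected =
  visual
  |> Axon.dense(4096, name: "proj_1")
  |> Axon.activation(:gelu)
  |> Axon.dense(4096, name: "proj_2")

# Result shape: [batch, 576 + 32, 4096]
fused = Axon.concatenate(projected, text, axis: 1)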
2. Cross-Attention (Flamingo-style)
Gated cross-attention blocks inserted into a language model. Visual features serve as keys/values, text features serve as queries.
Text hidden [batch, text_len, llm_dim]
|
[LayerNorm]
|
[Cross-Attention] <-- Visual features as K, V
|
[tanh(alpha) * output] <-- Learnable gate, init=0
|
[Residual]

Used by: Flamingo, LLaMA 3.2 Vision
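The gate is what makes the insertion safe: alpha starts at zero, so tanh(alpha) = 0 and each block initially passes text through unchanged. A conceptual single-head sketch of the math in Nx (the real blocks are multi-headed and use learned query/key/value projections; names here are illustrative):

# Conceptual single-head version of the gated cross-attention math.
defmodule GatedCrossAttentionSketch do
  import Nx.Defn

  # text_q:    [batch, text_len, dim]   (queries)
  # visual_kv: [batch, num_visual, dim] (keys and values)
  # alpha: learnable gate scalar, initialized to 0.0 so tanh(alpha) = 0
  # and the block starts out as an identity over the text stream.
  defn attend(text_q, visual_kv, alpha) do
    dim = Nx.axis_size(text_q, 2)

    scores =
      Nx.dot(text_q, [2], [0], visual_kv, [2], [0])
      |> Nx.divide(Nx.sqrt(dim))

    # weights: [batch, text_len, num_visual]
    weights = softmax(scores)
    attended = Nx.dot(weights, [2], [0], visual_kv, [1], [0])

    # Gated residual: equals text_q exactly when alpha == 0.
    text_q + Nx.tanh(alpha) * attended
  end

  defnp softmax(t) do
    max = Nx.reduce_max(t, axes: [-1], keep_axes: true)
    e = Nx.exp(t - max)
    e / Nx.sum(e, axes: [-1], keep_axes: true)
  end
end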
Usage
# MLP Projection
fused = Fusion.mlp_projection(visual_tokens, text_tokens,
  vision_dim: 1024,
  llm_dim: 4096
)
# Cross-Attention block
attended = Fusion.cross_attention_block(text_hidden, visual_tokens,
  hidden_size: 4096,
  num_heads: 32
)

References
- LLaVA: "Visual Instruction Tuning" (Liu et al., NeurIPS 2023)
- Flamingo: "Flamingo: a Visual Language Model for Few-Shot Learning" (Alayrac et al., NeurIPS 2022)
- LLaMA 3.2 Vision: Meta technical report
Summary
Functions
Build a multimodal fusion model (MLP projection).
Build a cross-attention fusion model.
Build an MLP projection fusion model.
Build a Perceiver Resampler that compresses visual tokens to a fixed count.
Get the output size of a fusion model.
Functions
Build a multimodal fusion model (MLP projection).
This is the registry-compatible entry point that delegates to build_mlp_projection/1.
Options
See build_mlp_projection/1 for available options.
Build a cross-attention fusion model.
Inserts gated cross-attention blocks that let text tokens attend to visual features. The gate is a tanh of a learnable scalar initialized to zero, so each block is initially a no-op (preserving the original LLM's behavior) and gradually learns to incorporate visual information as training updates the gate.
Options
- :hidden_size - LLM hidden dimension (default: 256)
- :vision_dim - Visual feature dimension (default: 1024)
- :num_heads - Number of cross-attention heads (default: 4)
- :num_visual_tokens - Number of visual tokens (default: 196)
- :text_seq_len - Text sequence length (default: nil)
- :num_layers - Number of cross-attention layers (default: 4)
Returns
An Axon model: (text_hidden, visual_features) -> attended_text
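A hedged usage sketch; the function name Fusion.build_cross_attention/1 is assumed here (the doc above does not show it) and the option values are examples:

# Hypothetical entry point name; check the module's actual exports.
model =
  Fusion.build_cross_attention(
    hidden_size: 4096,
    vision_dim: 1024,
    num_heads: 32,
    num_visual_tokens: 576
  )

# Standard Axon workflow from here: build, init with templates, predict.
{init_fn, predict_fn} = Axon.build(model)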
Build an MLP projection fusion model.
Takes visual tokens from a vision encoder and text token embeddings, projects the visual tokens to the LLM embedding space via a 2-layer MLP, and concatenates them.
Options
- :vision_dim - Dimension of visual tokens from ViT (default: 1024)
- :llm_dim - Dimension of LLM embedding space (default: 256)
- :num_visual_tokens - Number of visual tokens (default: 196)
- :text_seq_len - Maximum text sequence length (default: nil)
- :compress_ratio - Group N visual tokens into 1 (Qwen-style, default: 1)
Returns
An Axon model: (visual_tokens, text_embeddings) -> fused_sequence
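For example, a usage sketch (option values are illustrative; the Axon.build step is standard Axon usage rather than anything specific to this module):

model =
  Fusion.build_mlp_projection(
    vision_dim: 1024,
    llm_dim: 4096,
    num_visual_tokens: 576,
    # Qwen-style grouping: 2 visual tokens merge into 1, halving the
    # visual sequence length before concatenation.
    compress_ratio: 2
  )

{init_fn, predict_fn} = Axon.build(model)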
Build a Perceiver Resampler that compresses visual tokens to a fixed count.
Uses learned query embeddings that cross-attend to variable-length visual features, producing a fixed number of output tokens.
Options
- :vision_dim - Input visual feature dimension (default: 1024)
- :output_dim - Output dimension (default: 256)
- :num_queries - Number of learned queries/output tokens (default: 64)
- :num_layers - Number of resampler layers (default: 6)
- :num_heads - Number of attention heads (default: 4)
- :num_visual_tokens - Number of input visual tokens (default: 196)
Returns
An Axon model: visual_features -> resampled [batch, num_queries, output_dim]
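A hedged shape-level example; the function name Fusion.build_perceiver_resampler/1 is assumed from the summary above, and the option values are illustrative:

# Hypothetical function name; options mirror the list above.
resampler =
  Fusion.build_perceiver_resampler(
    vision_dim: 1024,
    output_dim: 4096,
    num_queries: 64
  )

# However many visual tokens come in, the 64 learned queries cross-attend
# to them and emit a fixed [batch, 64, 4096] output sequence.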
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a fusion model.