BiMamba: Bidirectional Mamba for non-causal sequence modeling.
Extends Mamba with a backward pass for tasks where future context is available (e.g., classification, fill-in-the-blank, offline sequence analysis). Processes the sequence in both directions and combines the outputs.
Key Innovation: Bidirectional SSM
Standard Mamba is causal (left-to-right only). BiMamba runs two parallel SSMs:
Forward: h_f[t] = A_f * h_f[t-1] + B_f * x[t] (t = 1..L)
Backward: h_b[t] = A_b * h_b[t+1] + B_b * x[t] (t = L..1)
Output: y[t] = project(combine(h_f[t], h_b[t])), where combine is element-wise addition (:add, the default) or concatenation (:concat)
Architecture
Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
|   Input Projection    |
+-----------------------+
      |
      v
+-----------------------+
|   BiMamba Block x N   |
|   LayerNorm           |
|  +-- Forward SSM --+  |
|  |                 |  |
|  +-- Backward SSM -+  |
|  |                 |  |
|  +--- combine -----+  |
|  Projection + Residual|
|   FFN                 |
+-----------------------+
      |
      v
[batch, hidden_size]  (last timestep)
Use Cases
BiMamba is suited for offline tasks where the full sequence is available:
- Replay analysis (post-game)
- Sequence classification
- Bidirectional feature extraction
For real-time (causal) inference, use Edifice.SSM.Mamba instead.
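The recurrences in the Key Innovation section show why BiMamba cannot stream. The backward state at position t depends on every input after t. The sketch below runs both scans over a scalar sequence in plain Elixir. It assumes fixed scalar coefficients `a` and `b`, whereas real Mamba uses learned, input-dependent parameters over a state vector of size `state_size`. The `ScanSketch` module is hypothetical and is not part of Edifice.

```elixir
defmodule ScanSketch do
  # Hypothetical illustration: scalar linear recurrences with fixed
  # coefficients a and b (not Edifice's selective SSM).

  # Forward: h_f[t] = a * h_f[t-1] + b * x[t], scanning t = 1..L,
  # starting from a zero state.
  def forward_scan(xs, a, b) do
    {hs, _last_state} =
      Enum.map_reduce(xs, 0.0, fn x, h_prev ->
        h = a * h_prev + b * x
        {h, h}
      end)

    hs
  end

  # Backward: h_b[t] = a * h_b[t+1] + b * x[t], scanning t = L..1.
  # Implemented as a forward scan over the reversed sequence, then
  # reversed back so that position t lines up with x[t] again.
  def backward_scan(xs, a, b) do
    xs
    |> Enum.reverse()
    |> forward_scan(a, b)
    |> Enum.reverse()
  end
end

# A single impulse at the final position (L = 4).
xs = [0.0, 0.0, 0.0, 1.0]

ScanSketch.forward_scan(xs, 0.5, 1.0)
# => [0.0, 0.0, 0.0, 1.0]
ScanSketch.backward_scan(xs, 0.5, 1.0)
# => [0.125, 0.25, 0.5, 1.0]
```

The forward states before t = 4 never see the impulse. The backward state at t = 1 already reflects it, so a streaming model would have to wait for the whole sequence before emitting its first output.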
Usage
model = BiMamba.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  num_layers: 4
)
Reference
- Concept based on bidirectional extensions to Mamba (multiple concurrent works)
- Original Mamba: https://arxiv.org/abs/2312.00752
Types
@type build_opt() ::
        {:combine, :add | :concat}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:state_size, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a BiMamba model for bidirectional sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :state_size - SSM state dimension N (default: 16)
- :num_layers - Number of BiMamba blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length (default: 60)
- :combine - How to merge directions: :add or :concat (default: :add)
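The `:combine` option decides how the forward and backward outputs meet at each position. Suppose both directions produce vectors of width `hidden_size`. Then `:add` keeps the width at `hidden_size`, and `:concat` doubles it to `2 * hidden_size`. The sketch below shows the two modes on a single position, using plain lists in place of Nx tensors. The `CombineSketch` module is hypothetical and does not reproduce Edifice's implementation.

```elixir
defmodule CombineSketch do
  # Hypothetical illustration of the two :combine modes on one position.
  # fwd and bwd are the per-direction feature vectors, as plain lists.

  # :add sums the directions element-wise, so the width is unchanged.
  def combine(fwd, bwd, :add) when length(fwd) == length(bwd) do
    Enum.zip_with(fwd, bwd, &+/2)
  end

  # :concat places the backward features after the forward ones,
  # so the width doubles.
  def combine(fwd, bwd, :concat), do: fwd ++ bwd
end

fwd = [1.0, 2.0, 3.0]
bwd = [0.5, 0.5, 0.5]

CombineSketch.combine(fwd, bwd, :add)
# => [1.5, 2.5, 3.5]
CombineSketch.combine(fwd, bwd, :concat)
# => [1.0, 2.0, 3.0, 0.5, 0.5, 0.5]
```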
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
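Taking the last position reduces a `[batch, seq_len, hidden_size]` sequence of features to one `[batch, hidden_size]` vector per example. The sketch below shows that selection on nested lists. The `LastStep` module is hypothetical, and the real model presumably performs this step with a tensor slicing operation rather than on lists.

```elixir
defmodule LastStep do
  # Hypothetical illustration: keep only the final timestep of each example.
  # Input (nested lists):  [batch][seq_len][hidden_size]
  # Output (nested lists): [batch][hidden_size]
  def take_last(batch), do: Enum.map(batch, &List.last/1)
end

# batch = 2, seq_len = 3, hidden_size = 2
features = [
  [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
  [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5]]
]

LastStep.take_last(features)
# => [[0.5, 0.6], [1.4, 1.5]]
```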
Build a single BiMamba block with forward and backward SSMs.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a BiMamba model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a BiMamba model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.