GSS: Gated State Space Model.
Implements the Gated State Space model from "Long Range Language Modeling via Gated State Spaces" (Mehta et al., 2023). GSS simplifies S4 by using fixed (learned but not input-dependent) A, B, C matrices combined with multiplicative gating for non-linearity.
Key Innovation: Fixed SSM + Multiplicative Gating
Unlike Mamba, where B, C, and dt are input-dependent, GSS uses:
- Fixed diagonal A, B, C matrices (learned via Axon.param)
- Multiplicative gating for input-dependent non-linearity: gate = sigmoid(W_g * x)
- Result: simpler than Mamba, while the gate adds the input-dependent non-linearity that vanilla S4 lacks
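To make the fixed-parameter recurrence plus gate concrete, here is a minimal single-channel sketch in plain Elixir, without Nx or Axon. It assumes a diagonal state stored as lists a, b, c of length state_size, plus a scalar gate weight wg and bias bg for the channel. The module and function names (GatedSSMSketch.run/6) are hypothetical; this is an illustration of the equations, not the module's actual implementation, which works on batched tensors.

```elixir
defmodule GatedSSMSketch do
  # Logistic sigmoid used for the gate.
  def sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Runs the gated recurrence over a list of scalar inputs xs for one channel.
  # a, b, c: lists of length state_size holding the diagonals of A, B, C.
  # wg, bg: scalar gate weight and bias for this channel (hypothetical names).
  def run(xs, a, b, c, wg, bg) do
    h0 = List.duplicate(0.0, length(a))

    {outputs, _final_state} =
      Enum.map_reduce(xs, h0, fn x, h ->
        # h_t = A * h_{t-1} + B * x_t (elementwise, since A is diagonal)
        h_next = Enum.zip_with([a, h, b], fn [ai, hi, bi] -> ai * hi + bi * x end)

        # y_t = C * h_t, summed over the state dimension to give one value
        y = c |> Enum.zip_with(h_next, fn ci, hi -> ci * hi end) |> Enum.sum()

        # gate_t = sigmoid(W_g * x_t + b_g); output_t = gate_t * y_t
        gate = sigmoid(wg * x + bg)
        {gate * y, h_next}
      end)

    outputs
  end
end
```

With A = 0.5, B = C = 1 and a zero gate weight and bias (so the gate is sigmoid(0) = 0.5), the impulse [1.0, 0.0, 0.0] yields [0.5, 0.25, 0.125]: the state decays by A each step, and the gate halves every output.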
Equations

```
# SSM with fixed parameters (A is diagonal, so A * h_{t-1} is elementwise):
h_t = A * h_{t-1} + B * x_t    # A, B are learned parameters (not input-dependent)
y_t = C * h_t                  # C is a learned parameter

# Gating for non-linearity:
gate_t = sigmoid(W_g * x_t + b_g)
output_t = gate_t * y_t        # elementwise product
```

Architecture
```
Input [batch, seq_len, embed_dim]
                   |
                   v
+-------------------------------------+
| GSS Block                           |
| LayerNorm -> [SSM path, Gate path]  |
|   SSM: linear -> scan(A,B) -> C*h   |
|   Gate: linear -> sigmoid           |
|   output = SSM * Gate               |
|   -> project -> residual            |
| LayerNorm -> FFN -> residual        |
+-------------------------------------+
                   | (repeat for num_layers)
                   v
Output [batch, hidden_size]
```

Compared to Other SSMs
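| Model | A, B, C | Gating | Sequence complexity |
|---|---|---|---|
| S4 | Input-independent (HiPPO-initialized) | None | O(L log L) (FFT convolution) |
| GSS | Input-independent (learned) | Multiplicative (sigmoid) | O(L) (sequential scan) |
| Mamba | Input-dependent | SiLU | O(L) (selective scan) |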
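The complexity column rests on a property of input-independent parameters: when A, B, C do not depend on the input, the recurrence unrolls into a causal convolution with kernel K_k = sum_i C_i * A_i^k * B_i. That is what lets S4 compute outputs with an FFT, whereas input-dependent parameters (as in Mamba) change the kernel at every step and call for a scan. The plain-Elixir sketch below, with hypothetical module and function names, checks numerically that a sequential scan and a direct convolution agree for a fixed diagonal SSM. The direct convolution here costs O(L^2); an FFT-based version would bring it to O(L log L).

```elixir
defmodule LTISketch do
  # Sequential scan for a fixed diagonal SSM: one pass, O(L * N) work.
  def scan(xs, a, b, c) do
    h0 = List.duplicate(0.0, length(a))

    {ys, _final_state} =
      Enum.map_reduce(xs, h0, fn x, h ->
        h_next = Enum.zip_with([a, h, b], fn [ai, hi, bi] -> ai * hi + bi * x end)
        {dot(c, h_next), h_next}
      end)

    ys
  end

  # Kernel K_k = sum_i C_i * A_i^k * B_i for k = 0..len-1.
  def kernel(_a, _b, _c, 0), do: []

  def kernel(a, b, c, len) do
    for k <- 0..(len - 1) do
      [a, b, c]
      |> Enum.zip_with(fn [ai, bi, ci] -> ci * :math.pow(ai, k) * bi end)
      |> Enum.sum()
    end
  end

  # Direct causal convolution y_t = sum_{k=0..t} K_k * x_{t-k}; O(L^2) work.
  def convolve([], _a, _b, _c), do: []

  def convolve(xs, a, b, c) do
    len = length(xs)
    k = a |> kernel(b, c, len) |> List.to_tuple()
    x = List.to_tuple(xs)

    for t <- 0..(len - 1) do
      Enum.reduce(0..t, 0.0, fn j, acc -> acc + elem(k, j) * elem(x, t - j) end)
    end
  end

  defp dot(u, v), do: u |> Enum.zip_with(v, &*/2) |> Enum.sum()
end
```

Because the kernel depends only on A, B, C and not on x, it can be computed once per sequence length; with input-dependent parameters no single kernel exists, which is why Mamba-style models rely on a scan instead.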
Usage
```elixir
model =
  GSS.build(
    embed_dim: 287,
    hidden_size: 256,
    state_size: 16,
    num_layers: 4
  )
```

References
- Mehta et al., "Long Range Language Modeling via Gated State Spaces" (2023). https://arxiv.org/abs/2206.13947
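Summary

Functions
- build/1: Build a GSS model for sequence processing.
- default_dropout/0: Default dropout rate.
- Default hidden dimension.
- default_num_layers/0: Default number of layers.
- default_state_size/0: Default SSM state dimension.
- output_size/1: Get the output size of a GSS model.
- recommended_defaults/0: Get recommended defaults.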
Types
@type build_opt() :: {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:state_size, pos_integer()} | {:num_layers, pos_integer()} | {:dropout, float()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a GSS model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :state_size - SSM state dimension (default: 16)
- :num_layers - Number of GSS blocks (default: 4)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length (default: 60)
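How these options might be resolved can be sketched in plain Elixir: merge the caller's options over the documented defaults and require :embed_dim, which has no default. The module name GSSOptsSketch is hypothetical, and so is its output_size/1 body, which simply assumes the output width equals :hidden_size because the model returns the last hidden state of shape [batch, hidden_size]. The real build/1 and output_size/1 may resolve options differently.

```elixir
defmodule GSSOptsSketch do
  # Defaults taken from the documented build/1 options; :embed_dim has none.
  @defaults [hidden_size: 256, state_size: 16, num_layers: 4, dropout: 0.0, window_size: 60]

  # Returns the full option list, with caller options overriding the defaults.
  # Raises KeyError when the required :embed_dim is missing.
  def resolve(opts) do
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end

  # Hypothetical: assumes the output width is :hidden_size (default 256).
  def output_size(opts), do: Keyword.get(opts, :hidden_size, 256)
end
```

For example, resolving embed_dim: 287, num_layers: 2 keeps hidden_size at 256 and window_size at 60 while overriding num_layers.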
Returns
An Axon model that processes sequences of shape [batch, seq_len, embed_dim] and outputs the last hidden state, of shape [batch, hidden_size].
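The block structure from the architecture diagram (pre-norm sub-layers with residual connections, repeated num_layers times, then only the last timestep kept as output) can be sketched on a single sequence of feature vectors. This is plain Elixir with hypothetical names and simplifying assumptions: the LayerNorm has no learned scale or shift, dropout is omitted, and the gated SSM path and the FFN are passed in as functions over the whole sequence rather than built from Axon layers.

```elixir
defmodule BlockSketch do
  # LayerNorm over one feature vector, without learned scale or shift.
  def layer_norm(v, eps \\ 1.0e-5) do
    n = length(v)
    mean = Enum.sum(v) / n
    var = Enum.sum(Enum.map(v, fn x -> (x - mean) * (x - mean) end)) / n
    Enum.map(v, fn x -> (x - mean) / :math.sqrt(var + eps) end)
  end

  # Pre-norm residual: seq + f(LayerNorm(seq)).
  # f maps a whole sequence (list of vectors) to a sequence of the same shape,
  # so it can stand in for a sequence mixer such as the gated SSM path.
  def residual(seq, f) do
    out = f.(Enum.map(seq, &layer_norm/1))
    Enum.zip_with(seq, out, fn x, y -> Enum.zip_with(x, y, &+/2) end)
  end

  # Each layer is {mixer, ffn}; after all layers, keep only the last timestep.
  def forward(seq, layers) do
    layers
    |> Enum.reduce(seq, fn {mixer, ffn}, acc ->
      acc |> residual(mixer) |> residual(ffn)
    end)
    |> List.last()
  end
end
```

Taking List.last/1 at the end is what turns a [seq_len, hidden_size] sequence into a single [hidden_size] vector per example, matching the "last hidden state" output described above.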
@spec default_dropout() :: float()
Default dropout rate
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec default_state_size() :: pos_integer()
Default SSM state dimension
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a GSS model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.