Edifice.SSM.MambaHillisSteele (Edifice v0.2.0)

Mamba variant using Hillis-Steele parallel scan.

Hillis-Steele vs Blelloch

Blelloch (standard Mamba): O(L) work, O(log L) depth

  • Work-efficient: the number of active elements halves at each up-sweep level (and doubles again in the down-sweep)
  • Fewer total operations

Hillis-Steele: O(L log L) work, O(log L) depth

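  • ALL elements are processed at every level (positions below the current stride simply keep their value)
  • More parallel work per level, which can mean better GPU occupancy
  • May be faster in wall-clock time despite more total work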
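
To make the work/depth trade-off concrete, the following sketch counts combine operations for both scans. The module and function names are hypothetical and not part of Edifice; the Blelloch count assumes the textbook two-sweep formulation (L - 1 combines in each sweep) on a power-of-two length.

```elixir
defmodule ScanWorkSketch do
  # Hypothetical helper module for illustration only; not part of Edifice.

  # Hillis-Steele: at each stride s = 1, 2, 4, ... (s < len),
  # positions s..len-1 each perform one combine, i.e. len - s combines.
  def hillis_steele_ops(len) when len >= 1 do
    len |> strides() |> Enum.map(&(len - &1)) |> Enum.sum()
  end

  # Number of Hillis-Steele levels, i.e. ceil(log2(len)).
  def hillis_steele_levels(len) when len >= 1, do: length(strides(len))

  # Blelloch (textbook form, power-of-two len): len - 1 combines in the
  # up-sweep plus len - 1 combines in the down-sweep.
  def blelloch_ops(len) when len >= 1, do: 2 * (len - 1)

  defp strides(len) do
    1 |> Stream.iterate(&(&1 * 2)) |> Enum.take_while(&(&1 < len))
  end
end

for len <- [8, 1024] do
  IO.puts(
    "L=#{len}: Hillis-Steele #{ScanWorkSketch.hillis_steele_ops(len)} combines " <>
      "in #{ScanWorkSketch.hillis_steele_levels(len)} levels, " <>
      "Blelloch #{ScanWorkSketch.blelloch_ops(len)} combines"
  )
end
```

For L = 8 this gives 17 combines against 14, and for L = 1024 it gives 9217 against 2046, so the gap grows roughly as log L. Whether the extra work pays off depends on how well the hardware hides it behind the added parallelism.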

Algorithm

Level 0: [1] [2]    [3]    [4]    [5]    [6]    [7]    [8]
Level 1: [1] [1..2] [2..3] [3..4] [4..5] [5..6] [6..7] [7..8]  (stride 1, ALL elements)
Level 2: [1] [1..2] [1..3] [1..4] [2..5] [3..6] [4..7] [5..8]  (stride 2, ALL elements)
Level 3: [1] [1..2] [1..3] [1..4] [1..5] [1..6] [1..7] [1..8]  (stride 4, ALL elements)

Here [a..b] denotes the combination (e.g. the sum) of elements a through b.
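
The levels above can be sketched in plain Elixir as follows. This is a minimal sequential illustration, not the Edifice implementation (which builds an Axon model and presumably operates on tensors); `HillisSteeleSketch` and `scan/2` are hypothetical names. The operator only needs to be associative, so the same function handles prefix sums and the linear recurrence h_t = a_t * h_{t-1} + b_t that underlies state space models.

```elixir
defmodule HillisSteeleSketch do
  # Hypothetical module for illustration only; not part of Edifice.

  @doc "Inclusive scan of `list` under the associative function `op`."
  def scan(list, op) when is_list(list) and is_function(op, 2) do
    list |> List.to_tuple() |> do_scan(1, op) |> Tuple.to_list()
  end

  # Stop once the stride reaches the sequence length.
  defp do_scan(xs, stride, _op) when stride >= tuple_size(xs), do: xs

  defp do_scan(xs, stride, op) do
    n = tuple_size(xs)

    # Each position reads only from the previous level, so all updates in a
    # level are independent. A GPU would run them in parallel; here they
    # run one after another.
    next =
      for i <- 0..(n - 1) do
        if i >= stride do
          # The earlier element goes on the left: op may be non-commutative.
          op.(elem(xs, i - stride), elem(xs, i))
        else
          elem(xs, i)
        end
      end

    do_scan(List.to_tuple(next), stride * 2, op)
  end
end

# Prefix sums of the eight-element example.
IO.inspect(HillisSteeleSketch.scan([1, 2, 3, 4, 5, 6, 7, 8], &Kernel.+/2))

# Linear recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0, encoded as
# {a_t, b_t} pairs. The second component of each result is h_t.
combine = fn {a1, b1}, {a2, b2} -> {a1 * a2, a2 * b1 + b2} end
IO.inspect(HillisSteeleSketch.scan([{2, 1}, {2, 1}, {2, 1}], combine))
```

The first call prints `[1, 3, 6, 10, 15, 21, 28, 36]`. The second prints `[{2, 1}, {4, 3}, {8, 7}]`, matching h_1 = 1, h_2 = 3, h_3 = 7 computed step by step. That Edifice uses this exact pairwise operator is an assumption; it is the standard way to express such a recurrence as a scan.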

Usage

alias Edifice.SSM.MambaHillisSteele
model = MambaHillisSteele.build(embed_dim: 287, hidden_size: 256)

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.
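
As an illustration of the type above, a full option list might look like the sketch below. Only the keys and their types come from build_opt(); every value other than embed_dim and hidden_size is a hypothetical choice, not a documented default.

```elixir
# Hypothetical values; only the keys and their types come from build_opt().
opts = [
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  expand_factor: 2,
  conv_size: 4,
  num_layers: 2,
  dropout: 0.1,
  window_size: 60
]

# Check each value against its documented type: dropout is a float,
# every other option is a positive integer.
valid? =
  Enum.all?(opts, fn
    {:dropout, value} -> is_float(value)
    {_key, value} -> is_integer(value) and value > 0
  end)

IO.inspect(valid?, label: "options match build_opt()")

# With the :edifice dependency available, the list is passed straight to build/1:
# model = Edifice.SSM.MambaHillisSteele.build(opts)
```

Because the options are an ordinary keyword list, any subset can be passed; build/1 defaults to an empty list.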

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MambaHillisSteele model for sequence processing.

Same API as Mamba.build/1; the accepted options are listed in build_opt().

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.output_size/1.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.param_count/1.