Edifice.Meta.MoE (Edifice v0.2.0)


Mixture of Experts (MoE) for adaptive expert selection.

Overview

MoE routes each input to a small subset of specialized "expert" networks chosen by a learned routing function. This gives the model much larger parameter capacity while keeping inference fast, since only K of the experts are active for any given input.

Input x
    |
    v
+-----------------+
|     Router      | -> Selects top-K experts
|  (softmax gate) |
+--------+--------+
         |
  +------+------+------+------+
  v      v      v      v      v
+---+  +---+  +---+  +---+  +---+
|E1 |  |E2 |  |E3 |  |E4 |  |E5 |  (Experts)
+-+-+  +-+-+  +---+  +---+  +---+
  |      |         (inactive)
  v      v
 weighted sum
    |
    v
 Output
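
The router is a learned linear scoring function followed by a softmax and a top-K cut, as in the diagram above. A minimal sketch of that gating step in plain Nx (illustrative only, not Edifice's internals; all tensor names here are made up):

key = Nx.Random.key(0)
{x, key} = Nx.Random.normal(key, shape: {4, 256})    # batch of 4 inputs
{w, _key} = Nx.Random.normal(key, shape: {256, 8})   # router weights for 8 experts

# Softmax over expert scores (max-subtracted for numerical stability)
logits = Nx.dot(x, w)
logits = Nx.subtract(logits, Nx.reduce_max(logits, axes: [1], keep_axes: true))
exp = Nx.exp(logits)
probs = Nx.divide(exp, Nx.sum(exp, axes: [1], keep_axes: true))

# Keep the top-2 experts per input and renormalize their gate weights
{top_vals, top_idx} = Nx.top_k(probs, k: 2)
gates = Nx.divide(top_vals, Nx.sum(top_vals, axes: [1], keep_axes: true))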

Expert Specialization

Different experts can specialize in different input patterns:

  • Expert 1: Common patterns (frequent states)
  • Expert 2: Transition states (changes between modes)
  • Expert 3: Edge cases (rare but important situations)
  • Expert 4: Fine-grained distinctions (subtle differences)

Routing Strategies

Strategy   Description                          Load Balance
:top_k     Select K highest-scoring experts     Requires aux loss
:switch    Route to single best expert          Best balance
:soft      Weighted sum of all experts          Most expensive
:hash      Deterministic based on input hash    Perfect balance

Usage

alias Edifice.Meta.MoE

# Create MoE layer with 8 experts, top-2 routing
moe = MoE.build(
  input_size: 256,
  hidden_size: 512,
  num_experts: 8,
  top_k: 2,
  routing: :top_k
)

# With load balancing loss
{output, aux_loss} = MoE.forward_with_aux(moe, input, params)
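
During training, aux_loss is added to the task objective so the router is penalized for uneven expert usage. A sketch, where task_loss is a placeholder for whatever loss you already compute:

total_loss = Nx.add(task_loss, aux_loss)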

Summary

Functions

build(opts \\ [])
Build a Mixture of Experts layer.

build_block(input, opts)
Build a complete MoE block with pre-norm and residual.

build_moe_backbone(opts)
Build an MoE-enhanced backbone by replacing FFN layers with MoE.

compute_aux_loss(router_probs, expert_mask, opts \\ [])
Compute load balancing auxiliary loss.

estimate_speedup(num_experts, top_k, expert_fraction \\ 0.5)
Calculate theoretical speedup from MoE.

Get recommended MoE configuration.

Types

build_opt()

@type build_opt() ::
  {:capacity_factor, float()}
  | {:dropout, float()}
  | {:expert_type, :ffn | :glu | :mamba}
  | {:hidden_size, pos_integer()}
  | {:input_size, pos_integer()}
  | {:load_balance_weight, float()}
  | {:num_experts, pos_integer()}
  | {:output_size, pos_integer()}
  | {:routing, routing_strategy()}
  | {:top_k, pos_integer()}

Options for build/1.

routing_strategy()

@type routing_strategy() :: :top_k | :switch | :soft | :hash

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Mixture of Experts layer.

Options

Architecture:

  • :input_size - Input dimension (required)
  • :hidden_size - Expert hidden dimension (default: input_size * 4)
  • :output_size - Output dimension (default: input_size)
  • :num_experts - Number of expert networks (default: 8)
  • :top_k - Number of experts per input (default: 2)
  • :routing - Routing strategy (default: :top_k)

Regularization:

  • :dropout - Dropout rate (default: 0.1)
  • :capacity_factor - Max tokens per expert multiplier (default: 1.25)
  • :load_balance_weight - Auxiliary loss weight (default: 0.01)

Expert architecture:

  • :expert_type - :ffn, :glu, or :mamba (default: :ffn)

Returns

An Axon model for the MoE layer.
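
Since build/1 returns an ordinary Axon model, it compiles like any other. A quick sketch (the option values here are arbitrary):

model = MoE.build(input_size: 128, num_experts: 4, routing: :switch)
{init_fn, predict_fn} = Axon.build(model)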

build_block(input, opts)

@spec build_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a complete MoE block with pre-norm and residual.

This wraps the MoE layer in the standard pre-norm transformer block pattern: layer norm on the input, the MoE layer, then a residual connection back to the input.
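
The same wiring written out in plain Axon, with a dense layer standing in for the MoE sublayer (an illustrative sketch, not Edifice's exact internals):

input = Axon.input("x", shape: {nil, 256})
sublayer = fn x -> Axon.dense(x, 256) end   # stand-in for the MoE layer

output =
  input
  |> Axon.layer_norm()          # pre-norm
  |> sublayer.()
  |> Axon.add(input)            # residual connection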

build_moe_backbone(opts)

@spec build_moe_backbone(keyword()) :: Axon.t()

Build an MoE-enhanced backbone by replacing FFN layers with MoE.

Takes an existing backbone configuration and converts FFN sublayers to MoE.

Options

  • :backbone - Base backbone (:mamba, :attention, etc.)
  • :moe_every - Apply MoE every N layers (default: 2)
  • :num_experts - Experts per MoE layer (default: 8)
  • :top_k - Active experts per input (default: 2)
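
For example, a hypothetical Mamba-backed configuration that swaps in MoE on every second layer:

backbone = MoE.build_moe_backbone(
  backbone: :mamba,
  moe_every: 2,
  num_experts: 8,
  top_k: 2
)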

compute_aux_loss(router_probs, expert_mask, opts \\ [])

@spec compute_aux_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute load balancing auxiliary loss.

This loss encourages uniform expert utilization, preventing "expert collapse" where only a few experts are used.

Formula

aux_loss = alpha * num_experts * sum(f_i * P_i)

Where:

  • f_i = fraction of tokens routed to expert i
  • P_i = average router probability for expert i
  • alpha = load_balance_weight

For a perfectly balanced router (f_i = P_i = 1/num_experts), the unscaled term num_experts * sum(f_i * P_i) is exactly 1.0, so the full loss settles near alpha.
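
A toy computation of the same formula in plain Nx (illustrative; shapes and values are made up): four tokens, two experts, with routing collapsed toward expert 1.

# router_probs: {tokens, experts} softmax outputs; expert_mask: one-hot choices
router_probs = Nx.tensor([[0.7, 0.3], [0.6, 0.4], [0.8, 0.2], [0.5, 0.5]])
expert_mask  = Nx.tensor([[1, 0], [1, 0], [1, 0], [1, 0]])

f = Nx.mean(expert_mask, axes: [0])   # fraction routed per expert: [1.0, 0.0]
p = Nx.mean(router_probs, axes: [0])  # mean router prob per expert: [0.65, 0.35]

alpha = 0.01
aux_loss = alpha * 2 * Nx.to_number(Nx.sum(Nx.multiply(f, p)))
# => 0.013, above the balanced value of alpha (0.01), so imbalance is penalized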

estimate_speedup(num_experts, top_k, expert_fraction \\ 0.5)

@spec estimate_speedup(pos_integer(), pos_integer(), float()) :: float()

Calculate theoretical speedup from MoE.

Arguments

  • num_experts - Total number of experts
  • top_k - Active experts per input
  • expert_fraction - Fraction of model that is expert layers

Returns

Approximate FLOPs reduction ratio.
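
Assuming the usual Amdahl-style estimate (an assumption about this function's formula, not confirmed by the docs above): only the expert fraction of the model gets cheaper, by a factor of top_k / num_experts. Worked through for the defaults:

# 8 experts, top-2, half the model in expert layers:
#   dense cost = 1.0
#   moe cost   = 0.5 + 0.5 * (2 / 8) = 0.625
#   speedup    = 1.0 / 0.625 = 1.6x fewer FLOPs
speedup = 1.0 / ((1.0 - 0.5) + 0.5 * 2 / 8)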