Edifice.Attention.FNet (Edifice v0.2.0)


FNet: Replacing Attention with Fourier Transform.

FNet replaces the self-attention sublayer in Transformers with an unparameterized Fourier transform, mixing tokens with no learnable attention parameters. The original paper attains O(N log N) mixing via the FFT; this implementation substitutes an O(N^2) real-valued DFT matrix multiply (see Complexity).

Key Innovation: FFT Mixing

Instead of computing attention weights, FNet applies FFT along the sequence axis to mix token information. This is parameter-free and achieves surprisingly competitive performance:

Standard Transformer:  LayerNorm -> Self-Attention -> Residual
FNet:                  LayerNorm -> FFT Mixing     -> Residual

The Fourier transform provides global token mixing: each output coefficient is a weighted sum over every input position, so every token's information reaches every other token in a single layer.
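To make the global-mixing claim concrete, here is a small standalone sketch in plain Elixir (no Nx; the matrix and values are illustrative, not taken from the library): perturbing a single input token changes every cosine-DFT output coefficient.

```elixir
# Plain-Elixir sketch of 1-D cosine-DFT mixing (illustrative only).
# Each output coefficient is a weighted sum over ALL input positions.
n = 5

# DFT_real[k][j] = cos(2*pi*k*j/n)
dft =
  for k <- 0..(n - 1) do
    for j <- 0..(n - 1), do: :math.cos(2 * :math.pi() * k * j / n)
  end

mix = fn x ->
  Enum.map(dft, fn row ->
    Enum.zip_with(row, x, &(&1 * &2)) |> Enum.sum()
  end)
end

a = mix.([1.0, 2.0, 3.0, 4.0, 5.0])
b = mix.([1.0, 2.0, 3.0, 4.0, 9.0])  # perturb only the last token

# Every output coefficient moves: one token's change reaches all positions.
all_changed? = Enum.zip(a, b) |> Enum.all?(fn {x, y} -> abs(x - y) > 1.0e-9 end)
IO.inspect(all_changed?)
```

Contrast this with a convolution of small kernel size, where a perturbation only reaches a fixed local neighborhood per layer.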

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|       FNet Block                     |
|                                      |
|  LayerNorm                           |
|    -> FFT along seq axis             |
|    -> Take real part                 |
|  -> Residual                         |
|                                      |
|  LayerNorm                           |
|    -> Dense(hidden * 4)              |
|    -> GeLU                           |
|    -> Dense(hidden)                  |
|  -> Residual                         |
+--------------------------------------+
      | (repeat for num_layers)
      v
Last timestep -> [batch, hidden_size]

Complexity

Component        Transformer        FNet
Token mixing     O(N^2)             O(N^2)*
Parameters       Q, K, V weights    None (DFT)
Training speed   Baseline           ~7x faster
Quality          Baseline           92-97% of BERT

*Note: we use a real-valued DFT matrix multiply instead of Nx.fft because EXLA's autodiff through complex FFT outputs triggers Nx.less/2 errors in LayerNorm's backward pass. For typical seq_len (30-128) and hidden_size (256-512), the O(N^2) matrix multiply is negligible compared with the FFN layers.
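The equivalence behind this substitution can be checked independently: for real input, the real part of a radix-2 FFT matches the cosine-matrix product. A self-contained plain-Elixir sketch (illustrative; no Nx/EXLA involved):

```elixir
# Check: Re(FFT(x)) equals the cosine-DFT sum, for real x.
# Radix-2 Cooley-Tukey FFT on complex pairs {re, im}, plain Elixir.
fft = fn fft, xs ->
  n = length(xs)

  if n == 1 do
    xs
  else
    {evens, odds} =
      xs
      |> Enum.with_index()
      |> Enum.split_with(fn {_x, i} -> rem(i, 2) == 0 end)

    e = fft.(fft, Enum.map(evens, fn {x, _i} -> x end))
    o = fft.(fft, Enum.map(odds, fn {x, _i} -> x end))

    # Multiply odd-half outputs by twiddle factors e^(-2*pi*i*k/n)
    tw =
      o
      |> Enum.with_index()
      |> Enum.map(fn {{re, im}, k} ->
        ang = -2.0 * :math.pi() * k / n
        {c, s} = {:math.cos(ang), :math.sin(ang)}
        {re * c - im * s, re * s + im * c}
      end)

    top = Enum.zip_with(e, tw, fn {a, b}, {c, d} -> {a + c, b + d} end)
    bot = Enum.zip_with(e, tw, fn {a, b}, {c, d} -> {a - c, b - d} end)
    top ++ bot
  end
end

x = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 3.0, 1.0]
n = length(x)

fft_re = fft.(fft, Enum.map(x, &{&1, 0.0})) |> Enum.map(&elem(&1, 0))

# Cosine-matrix version: Re(X_k) = sum_j x_j * cos(2*pi*k*j/n)
dft_re =
  for k <- 0..(n - 1) do
    x
    |> Enum.with_index()
    |> Enum.map(fn {v, j} -> v * :math.cos(2 * :math.pi() * k * j / n) end)
    |> Enum.sum()
  end

max_err = Enum.zip(fft_re, dft_re) |> Enum.map(fn {a, b} -> abs(a - b) end) |> Enum.max()
IO.inspect(max_err < 1.0e-9)
```

The cosine form drops the O(N log N) advantage but keeps the same real output, which is the trade-off the footnote describes.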

Usage

model = FNet.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  dropout: 0.1
)

References

  • Paper: "FNet: Mixing Tokens with Fourier Transforms" (Lee-Thorp et al., Google 2021)

Summary

Types

Options for build/1.

Functions

Build an FNet model for sequence processing.

Build a real-valued DFT matrix: DFT[k, n] = cos(2π k n / N).

Apply Fourier mixing using real-valued DFT matrix multiply.

Get the output size of an FNet model.

Calculate approximate parameter count for an FNet model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an FNet model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of FNet blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

dft_real_matrix(n)

@spec dft_real_matrix(pos_integer()) :: Nx.Tensor.t()

Build a real-valued DFT matrix: DFT[k, n] = cos(2π k n / N).

For real inputs, Real(FFT(x)) = x @ DFT_real, avoiding complex arithmetic.
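One detail worth noting: since cos(2π k n / N) is symmetric in k and n, DFT_real equals its own transpose, which is what lets the product be written as x @ DFT_real rather than DFT_real @ x. A quick plain-Elixir check (illustrative only, not the library's Nx code):

```elixir
# The cosine DFT matrix is symmetric: cos(2*pi*k*j/n) == cos(2*pi*j*k/n).
n = 6

dft =
  for k <- 0..(n - 1) do
    for j <- 0..(n - 1), do: :math.cos(2 * :math.pi() * k * j / n)
  end

# Enum.zip_with/2 over a list of rows transposes the matrix.
transpose = Enum.zip_with(dft, & &1)

IO.inspect(dft == transpose)
```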

fourier_mixing_real(tensor, dft_seq, dft_hidden)

@spec fourier_mixing_real(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  Nx.Tensor.t()

Apply Fourier mixing using real-valued DFT matrix multiply.

Applies precomputed cosine DFT matrices along both the sequence and feature axes, taking the real part of the Fourier transform at each axis in turn. This avoids Nx.fft entirely, preventing complex-number issues in EXLA's backward pass.

Parameters

  • tensor - Input tensor [batch, seq_len, hidden_dim]
  • dft_seq - Precomputed DFT matrix [seq_len, seq_len]
  • dft_hidden - Precomputed DFT matrix [hidden_dim, hidden_dim]

Returns

DFT-mixed tensor [batch, seq_len, hidden_dim]
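The two-sided mixing can be sketched in plain Elixir for a single batch element (the real implementation operates on batched Nx tensors; the helper names here are illustrative). Left-multiplying by the seq_len matrix mixes along the sequence axis, right-multiplying by the hidden_dim matrix mixes along features, and the shape is preserved:

```elixir
# Two-sided cosine mixing for one [seq_len, hidden] slice (plain-Elixir sketch).
matmul = fn a, b ->
  bt = Enum.zip_with(b, & &1)  # transpose b so its columns become rows

  Enum.map(a, fn row ->
    Enum.map(bt, fn col -> Enum.zip_with(row, col, &(&1 * &2)) |> Enum.sum() end)
  end)
end

cos_dft = fn n ->
  for k <- 0..(n - 1) do
    for j <- 0..(n - 1), do: :math.cos(2 * :math.pi() * k * j / n)
  end
end

seq_len = 3
hidden = 4

x = for i <- 0..(seq_len - 1), do: for(j <- 0..(hidden - 1), do: 1.0 * (i * hidden + j))

# Mix along the sequence axis (left multiply), then the feature axis
# (right multiply; no transpose needed since the cosine matrix is symmetric).
mixed = matmul.(matmul.(cos_dft.(seq_len), x), cos_dft.(hidden))

IO.inspect({length(mixed), length(hd(mixed))})
```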

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an FNet model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an FNet model.