FNet: Replacing Attention with the Fourier Transform
FNet replaces the self-attention sublayer in Transformers with an unparameterized Fourier transform, mixing tokens with no learnable attention parameters. The paper achieves O(N log N) mixing via the FFT; this implementation substitutes an O(N^2) real-valued DFT matrix multiply (see the note under Complexity).
Key Innovation: FFT Mixing
Instead of computing attention weights, FNet applies FFT along the sequence axis to mix token information. This is parameter-free and achieves surprisingly competitive performance:
Standard Transformer: LayerNorm -> Self-Attention -> Residual
FNet: LayerNorm -> FFT Mixing -> Residual

The Fourier transform provides global token mixing: the DFT is a dense linear transform along the sequence axis, so every output position depends on every input token.
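For intuition, one mixing sublayer could be wired in Axon like this (a sketch, not this module's actual code; mixing_block is a hypothetical helper, and fourier_mixing_real/3 and the DFT matrices are documented below):

    # Sketch: LayerNorm -> Fourier mixing -> residual
    defp mixing_block(input, dft_seq, dft_hidden) do
      input
      |> Axon.layer_norm()
      |> Axon.nx(&FNet.fourier_mixing_real(&1, dft_seq, dft_hidden))
      |> Axon.add(input)
    end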
Architecture
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| FNet Block |
| |
| LayerNorm |
| -> FFT along seq axis |
| -> Take real part |
| -> Residual |
| |
| LayerNorm |
| -> Dense(hidden * 4) |
| -> GeLU |
| -> Dense(hidden) |
| -> Residual |
+-------------------------------------+
| (repeat for num_layers)
v
Last timestep -> [batch, hidden_size]

Complexity
| Component | Transformer | FNet |
|---|---|---|
| Token mixing | O(N^2) | O(N^2)* |
| Parameters | Q,K,V weights | None (DFT) |
| Training speed | Baseline | ~7x faster |
| Quality | Baseline | 92-97% of BERT |
*Note: This implementation uses a real-valued DFT matrix multiply instead of Nx.fft because EXLA's autodiff through complex FFT outputs triggers Nx.less/2 errors in LayerNorm's backward pass. For typical seq_len (30-128) and hidden_size (256-512), the O(N^2) matrix multiply is small compared to the FFN layers.
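As a rough check with illustrative values N = 60 (seq_len) and h = 256 (hidden): the sequence-axis matmul costs N^2 * h ≈ 0.92M multiply-adds per example and the hidden-axis matmul N * h^2 ≈ 3.9M, while each block's FFN costs N * 8h^2 ≈ 31.5M, so the DFT mixing stays well below the FFN cost.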
Usage
    model = FNet.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      dropout: 0.1
    )
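To sanity-check shapes, a minimal sketch with Axon's build API (batch size and sequence length are illustrative; assumes a single-input model):

    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
    out = predict_fn.(params, Nx.broadcast(0.0, {8, 60, 287}))
    # out shape: {8, 256} = [batch, hidden_size]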
References

- Paper: "FNet: Mixing Tokens with Fourier Transforms" (Lee-Thorp et al., Google Research, 2021). arXiv:2105.03824
Summary
Functions
Build an FNet model for sequence processing.
Build a real-valued DFT matrix: DFT[k, n] = cos(2π k n / N).
Apply Fourier mixing using real-valued DFT matrix multiply.
Get the output size of an FNet model.
Calculate approximate parameter count for an FNet model.
Recommended default configuration for sequence processing.
Types
@type build_opt() :: {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:dropout, float()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build an FNet model for sequence processing.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of FNet blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
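For example, a build call exercising every documented option (values illustrative):

    model =
      FNet.build(
        embed_dim: 287,
        hidden_size: 256,
        num_layers: 4,
        dropout: 0.1,
        window_size: 60
      )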
@spec dft_real_matrix(pos_integer()) :: Nx.Tensor.t()
Build a real-valued DFT matrix: DFT[k, n] = cos(2π k n / N).
For real inputs, Real(FFT(x)) = x @ DFT_real, avoiding complex arithmetic.
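A minimal sketch of this construction with Nx (the broadcasting body is an assumption; only the formula above is from the docs):

    # DFT_real[k, n] = cos(2 * pi * k * n / N), built by broadcasting
    # a column of k indices against a row of n indices
    def dft_real_matrix(n) do
      k = Nx.iota({n, 1})
      j = Nx.iota({1, n})
      Nx.cos(Nx.multiply(2.0 * :math.pi() / n, Nx.multiply(k, j)))
    end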
@spec fourier_mixing_real(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Apply Fourier mixing using real-valued DFT matrix multiply.
Mixes along both the sequence and feature axes using precomputed cosine DFT matrices, a real-valued stand-in for Real(FFT2(x)). This avoids Nx.fft entirely, preventing complex-number issues in EXLA's backward pass.
Parameters
- tensor - Input tensor [batch, seq_len, hidden_dim]
- dft_seq - Precomputed DFT matrix [seq_len, seq_len]
- dft_hidden - Precomputed DFT matrix [hidden_dim, hidden_dim]
Returns
DFT-mixed tensor [batch, seq_len, hidden_dim]
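A sketch of the two-axis contraction with Nx.dot/4 (the axis pairings are assumptions; since the cosine matrices are symmetric, which side is contracted does not matter):

    import Nx.Defn

    defn fourier_mixing_real(tensor, dft_seq, dft_hidden) do
      tensor
      # contract the seq axis with dft_seq: [batch, seq, hidden] -> [batch, hidden, seq]
      |> Nx.dot([1], dft_seq, [0])
      # contract the hidden axis with dft_hidden: -> [batch, seq, hidden]
      |> Nx.dot([1], dft_hidden, [0])
    end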
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an FNet model.
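Presumably this mirrors the configured :hidden_size, since the model outputs [batch, hidden_size] (an assumption, not confirmed by the docs):

    FNet.output_size(hidden_size: 256)
    #=> 256 (assumed)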
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for an FNet model.
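As a rough cross-check, the FFN dominates each block's count; a sketch assuming the block layout from the Architecture section (LayerNorm and any input-projection terms omitted):

    hidden = 256
    # Dense(hidden -> 4*hidden) + Dense(4*hidden -> hidden), weights and biases
    ffn_per_block = hidden * 4 * hidden + 4 * hidden + 4 * hidden * hidden + hidden
    #=> 525_568 per block, roughly 2.1M across num_layers = 4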
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
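Since :embed_dim is required and data-dependent, the defaults presumably get merged in at the call site; a hedged sketch:

    opts = Keyword.merge(FNet.recommended_defaults(), embed_dim: 287)
    model = FNet.build(opts)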