Spiking Neural Network with surrogate gradients.
Spiking Neural Networks (SNNs) process information using discrete spikes rather than continuous activations. Neurons integrate input over time and fire when their membrane potential exceeds a threshold, then reset. This is more biologically plausible than conventional ANNs and can be highly energy-efficient on neuromorphic hardware (Intel Loihi, IBM TrueNorth).
Leaky Integrate-and-Fire (LIF) Neuron
The core compute unit:
V[t]     = beta * V[t-1] + W * x[t]     (leak + integrate)
spike[t] = V[t] > threshold             (fire)
V[t]     = V[t] - spike[t] * threshold  (soft reset after spike)

where:
- beta = exp(-dt/tau) is the membrane decay factor
- tau is the membrane time constant
- threshold is the firing threshold
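As a concrete illustration, here is a minimal NumPy sketch of these dynamics for a single neuron over ten timesteps. The input drive, tau, and threshold values are arbitrary assumptions chosen for the example, not values taken from this module:

```python
import numpy as np

# One-neuron LIF simulation: V[t] = beta*V[t-1] + I[t], fire when V > threshold,
# then soft-reset by subtracting the threshold. All constants are illustrative.
dt, tau, threshold = 1.0, 2.0, 1.0
beta = np.exp(-dt / tau)                 # membrane decay factor

V = 0.0
spikes, voltages = [], []
I = [0.6] * 10                           # constant input current for 10 timesteps
for I_t in I:
    V = beta * V + I_t                   # leak + integrate
    s = 1.0 if V > threshold else 0.0    # fire
    V = V - s * threshold                # soft reset after spike
    spikes.append(s)
    voltages.append(V)

print(sum(spikes), "spikes in", len(I), "timesteps")
```

With this constant sub-threshold drive the membrane charges over a few steps, fires, partially resets, and repeats, producing a regular spike train.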
Surrogate Gradients
The spike function (Heaviside step) is non-differentiable. We use a surrogate gradient for backpropagation: the derivative of a smooth approximation (sigmoid or fast sigmoid) replaces the true derivative.
Architecture
Input [batch, input_size]
| (presented for num_timesteps)
v
+----------------------------+
| LIF Layer 1 |
| V = beta*V + W*x |
| spike if V > threshold |
+----------------------------+
| (spike train)
v
+----------------------------+
| LIF Layer 2 |
+----------------------------+
|
v
+----------------------------+
| Rate Decoding |
| output = mean(spikes) |
+----------------------------+
|
v
Output [batch, output_size]

Usage
model = SNN.build(
input_size: 256,
hidden_sizes: [128, 64],
output_size: 10,
num_timesteps: 25,
tau: 2.0,
threshold: 1.0
)

References
- Neftci et al., "Surrogate Gradient Learning in Spiking Neural Networks" (2019), https://arxiv.org/abs/1901.09948
Summary
Functions
Build a Spiking Neural Network with LIF neurons and surrogate gradients.
Leaky Integrate-and-Fire neuron step.
Rate decoding: convert spike trains to firing rates.
Surrogate gradient for the non-differentiable spike function.
Types
@type build_opt() :: {:hidden_sizes, [pos_integer()]} | {:input_size, pos_integer()} | {:num_timesteps, pos_integer()} | {:output_size, pos_integer()} | {:tau, float()} | {:threshold, float()}
Options for build/1.
Functions
Build a Spiking Neural Network with LIF neurons and surrogate gradients.
The network processes input through multiple LIF neuron layers over several timesteps, then rate-decodes the output spike train into a continuous output.
Options
- :input_size - Input feature dimension (required)
- :hidden_sizes - List of hidden layer sizes (default: [256, 128])
- :output_size - Output dimension (required)
- :num_timesteps - Number of simulation timesteps (default: 25)
- :tau - Membrane time constant (default: 2.0)
- :threshold - Firing threshold (default: 1.0)
Returns
An Axon model: [batch, input_size] -> [batch, output_size]
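To make the overall computation concrete, the following NumPy sketch mirrors what the built model does: present a static input for num_timesteps, push it through two LIF layers, and rate-decode the output spike train. Shapes, weights, and constants are arbitrary assumptions for illustration; this is not the Axon implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def lif_step(V, I, beta, threshold):
    # One LIF step: leak + integrate, hard spike, reset to zero after a spike.
    V = beta * V + I
    s = (V > threshold).astype(np.float64)
    return V * (1.0 - s), s

# Illustrative shapes and constants (assumptions, not the module's defaults)
batch, input_size, hidden, output_size, T = 4, 8, 16, 3, 25
beta, threshold = np.exp(-1.0 / 2.0), 1.0
W1 = rng.normal(0.0, 0.5, (input_size, hidden))
W2 = rng.normal(0.0, 0.5, (hidden, output_size))

x = rng.random((batch, input_size))      # static input, presented every timestep
V1 = np.zeros((batch, hidden))
V2 = np.zeros((batch, output_size))
out_spikes = []
for _ in range(T):
    V1, s1 = lif_step(V1, x @ W1, beta, threshold)   # LIF layer 1
    V2, s2 = lif_step(V2, s1 @ W2, beta, threshold)  # LIF layer 2
    out_spikes.append(s2)

# Rate decoding: mean spike count over the time axis
rates = np.mean(np.stack(out_spikes, axis=1), axis=1)
print(rates.shape)  # (4, 3)
```

The final rates are in [0, 1] and play the role of logits/probabilities for a downstream classification loss.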
@spec lif_neuron(Nx.Tensor.t(), Nx.Tensor.t(), float(), float()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Leaky Integrate-and-Fire neuron step.
Computes one timestep of LIF dynamics:
V[t]     = beta * V[t-1] + I[t]
spike[t] = surrogate_gradient(V[t] - threshold)
V[t]     = V[t] * (1 - spike[t])   (reset to zero after spike)

Parameters
- membrane - Membrane potential from previous step, [batch, hidden_size]
- input_current - Weighted input current, [batch, hidden_size]
- beta - Membrane decay factor (= exp(-1/tau))
- threshold - Firing threshold
Returns
Tuple {new_membrane, spikes}:
- new_membrane - Updated membrane potential, [batch, hidden_size]
- spikes - Spike output, [batch, hidden_size] (0 or ~1)
@spec rate_decode(Nx.Tensor.t()) :: Nx.Tensor.t()
Rate decoding: convert spike trains to firing rates.
Computes the mean spike count over timesteps for each neuron. This is the simplest decoding scheme and works well for classification.
Parameters
- spike_train - Spike tensor, [batch, num_timesteps, hidden_size]
Returns
Firing rates [batch, hidden_size] in [0, 1]
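The decoding itself is just a mean over the time axis. A tiny NumPy sketch with a hand-written spike train (values are made up for the example):

```python
import numpy as np

# [batch=1, num_timesteps=4, hidden=2]: each neuron spikes in 2 of 4 timesteps
spike_train = np.array([[[1, 0], [0, 0], [1, 1], [0, 1]]], dtype=np.float64)

# Rate decoding: mean spike count over the time axis -> [batch, hidden]
rates = spike_train.mean(axis=1)
print(rates)  # [[0.5 0.5]]
```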
@spec surrogate_gradient(Nx.Tensor.t(), float()) :: Nx.Tensor.t()
Surrogate gradient for the non-differentiable spike function.
The Heaviside step function has zero gradient almost everywhere. We use a fast sigmoid as a surrogate: during the forward pass we still get hard spikes, but gradients flow through the sigmoid approximation during backpropagation.
forward:  spike = (x > 0) ? 1 : 0
backward: d_spike/dx = slope / (1 + slope * |x|)^2

In practice with Nx, we use the sigmoid directly as a smooth approximation that is differentiable everywhere.
Parameters
- x - Input tensor (membrane - threshold)
- slope - Steepness of the surrogate (default: 25.0)
Returns
Approximate spike values in [0, 1]
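The forward/backward pair above can be sketched directly in NumPy; the function names here are illustrative, not part of this module's API:

```python
import numpy as np

def spike_forward(x):
    # Forward pass: hard threshold (Heaviside step)
    return (x > 0).astype(np.float64)

def surrogate_grad(x, slope=25.0):
    # Backward pass: fast-sigmoid surrogate derivative slope / (1 + slope*|x|)^2
    return slope / (1.0 + slope * np.abs(x)) ** 2

x = np.array([-0.5, 0.0, 0.5])
print(spike_forward(x))   # [0. 0. 1.]
print(surrogate_grad(x))  # peaks at x = 0, decays away from the threshold
```

The surrogate derivative is largest exactly at the threshold crossing (x = 0) and falls off quickly, so gradients flow mainly through neurons whose membrane potential is near threshold.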