SchNet - Continuous-Filter Convolutional Neural Network.
SchNet processes molecular and atomic graphs with continuous-filter convolutions, in which each edge's filter weights are generated from the interatomic distance after expanding it in radial basis functions. Unlike discrete graph convolutions, whose filters are tied to discrete positions or edge types, SchNet's filters are a learned function of continuous distance, which makes it suitable for molecular property prediction and atomistic simulation.
Architecture
    Node Features  [batch, num_nodes, input_dim]
    Adjacency      [batch, num_nodes, num_nodes]  (interpreted as distances)
                      |
                      v
    +--------------------------------------+
    | Input Embedding                      |
    +--------------------------------------+
                      |
                      v
    +--------------------------------------+
    | SchNet Interaction Block 1:          |
    |   1. RBF expansion of distances      |
    |   2. Filter-generating network       |
    |   3. Continuous convolution          |
    |   4. Update node features            |
    +--------------------------------------+
                      |  (repeated num_interactions times)
                      v
    Node Embeddings [batch, num_nodes, hidden_size]

Continuous Convolution
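For each pair of atoms (i, j) whose distance d_ij is within the cutoff:
- Expand d_ij into num_rbf radial basis functions: e_k(d_ij) = exp(-gamma * (d_ij - mu_k)^2), with centres mu_k and width parameter gamma
- Generate a filter of num_filters weights: W(d_ij) = Dense(RBF(d_ij))
- Convolve: x_i += SUM_j W(d_ij) * x_j, where * is the element-wise (per-filter) product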
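The three steps above can be traced for a single atom with a small plain-Elixir sketch. It is illustrative only and not the SchNet module's implementation: the filter-generating network is assumed to be a single linear map with no activation (the paper uses a deeper network), neighbours are selected with a hard cutoff, and the self-pair j == i is excluded. The module and function names are hypothetical.

```elixir
defmodule CFConvSketch do
  # Hypothetical sketch of one continuous-filter convolution for atom i.
  # Not the SchNet module's implementation.

  # RBF expansion: e_k(d) = exp(-gamma * (d - mu_k)^2) for each centre mu_k.
  def rbf(d, centres, gamma) do
    Enum.map(centres, fn mu -> :math.exp(-gamma * (d - mu) * (d - mu)) end)
  end

  # Filter-generating network, simplified (assumption) to one linear map
  # with no activation: w_f = sum_k e_k * dense[k][f].
  # `dense` has num_rbf rows, each holding num_filters weights.
  def filter(rbf_vec, dense) do
    dense
    |> Enum.zip(rbf_vec)
    |> Enum.map(fn {row, e} -> Enum.map(row, &(&1 * e)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end

  # x_i + sum over j != i with d_ij < cutoff of W(d_ij) * x_j (element-wise).
  # Each feature vector must have num_filters channels.
  # Skipping j == i is an assumption of this sketch.
  def convolve(i, features, distances, centres, gamma, dense, cutoff) do
    x_i = Enum.at(features, i)
    row = Enum.at(distances, i)
    zero = List.duplicate(0.0, length(x_i))

    update =
      features
      |> Enum.with_index()
      |> Enum.reduce(zero, fn {x_j, j}, acc ->
        d = Enum.at(row, j)

        if j != i and d < cutoff do
          w = d |> rbf(centres, gamma) |> filter(dense)
          # Accumulate the element-wise product W(d_ij) * x_j per channel.
          Enum.zip_with([acc, w, x_j], fn [a, wf, xf] -> a + wf * xf end)
        else
          acc
        end
      end)

    Enum.zip_with(x_i, update, &(&1 + &2))
  end
end
```

With an identity `dense` matrix, the filter equals the RBF vector itself, which makes the arithmetic easy to check by hand. A real model learns `dense` and applies further atom-wise layers around the convolution.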
Usage
    model = SchNet.build(
      input_dim: 16,
      hidden_size: 64,
      num_interactions: 3,
      num_filters: 64,
      cutoff: 5.0,
      num_classes: 1
    )

References
- Schütt et al., "SchNet: A continuous-filter convolutional neural network for modeling quantum interactions" (NeurIPS 2017)
- https://arxiv.org/abs/1706.08566
Summary
Functions
build/1 - Build a SchNet model.
interaction_block/4 - Single SchNet interaction block.
output_size/1 - Get the output size of a SchNet model.
Types
@type build_opt() ::
        {:cutoff, float()}
        | {:hidden_size, pos_integer()}
        | {:input_dim, pos_integer()}
        | {:num_classes, pos_integer() | nil}
        | {:num_filters, pos_integer()}
        | {:num_interactions, pos_integer()}
        | {:num_rbf, pos_integer()}
        | {:pool, atom()}
Options for build/1.
Functions
Build a SchNet model.
Options
- :input_dim - Input feature dimension per atom (required)
- :hidden_size - Hidden dimension of the atom embeddings (default: 64)
- :num_interactions - Number of interaction blocks (default: 3)
- :num_filters - Number of continuous filters (default: 64)
- :cutoff - Distance cutoff for interactions, in the same units as the adjacency distances (default: 5.0)
- :num_rbf - Number of radial basis functions (default: 20)
- :num_classes - If provided, adds an output projection to num_classes outputs (default: nil)
- :pool - Global pooling mode for molecular properties (default: nil)
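The option list above can be read as a keyword list with defaults. The following is a minimal sketch, assuming Elixir 1.13 or later for `Keyword.validate!/2`, of how those documented defaults could be resolved; `SchNetOptsSketch` is a hypothetical helper, and `SchNet.build/1` may validate its options differently.

```elixir
defmodule SchNetOptsSketch do
  # Hypothetical helper (not part of SchNet) applying the documented
  # build/1 defaults. Requires Elixir >= 1.13 for Keyword.validate!/2.
  @defaults [
    hidden_size: 64,
    num_interactions: 3,
    num_filters: 64,
    cutoff: 5.0,
    num_rbf: 20,
    num_classes: nil,
    pool: nil
  ]

  def resolve(opts) do
    # Unknown keys raise ArgumentError; missing keys receive their defaults.
    # :input_dim is listed as a bare atom: allowed, but with no default.
    opts = Keyword.validate!(opts, [:input_dim | @defaults])

    # :input_dim is required, so raise KeyError when it is absent.
    _input_dim = Keyword.fetch!(opts, :input_dim)
    opts
  end
end
```

For example, `SchNetOptsSketch.resolve(input_dim: 16, cutoff: 4.0)` keeps the overridden cutoff and fills in `hidden_size: 64`, `num_rbf: 20`, and the other defaults, while omitting `:input_dim` raises.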
Returns
An Axon model with two inputs: "nodes" for atom features [batch, num_nodes, input_dim] and "adjacency" for pairwise distances [batch, num_nodes, num_nodes].
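Because the "adjacency" input is interpreted as pairwise distances rather than a 0/1 connectivity matrix, it would typically be derived from atomic coordinates. The sketch below builds that matrix for one molecule in plain Elixir; `DistanceMatrixSketch` is a hypothetical helper, and batching several molecules into one [batch, num_nodes, num_nodes] tensor (including how padded atoms are handled) is left out because the documentation does not specify it.

```elixir
defmodule DistanceMatrixSketch do
  # Hypothetical helper: builds the [num_nodes, num_nodes] matrix of
  # Euclidean distances for one molecule from per-atom 3D positions.
  def from_positions(positions) do
    for p <- positions do
      for q <- positions do
        p
        |> Enum.zip_with(q, &(&1 - &2))
        |> Enum.map(&(&1 * &1))
        |> Enum.sum()
        |> :math.sqrt()
      end
    end
  end
end
```

The result is symmetric with zeros on the diagonal; for positions `[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]` it gives a distance of 5.0 between the two atoms.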
@spec interaction_block(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()
Single SchNet interaction block.
Options
- :num_filters - Number of continuous filters (default: 64)
- :cutoff - Distance cutoff (default: 5.0)
- :num_rbf - Number of radial basis functions (default: 20)
- :name - Layer name prefix
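The :num_rbf and :cutoff options jointly determine the RBF centres mu_k and the width gamma, but the documentation does not say how. One common convention, assumed here and not confirmed for this module, spaces the centres evenly over [0, cutoff] and ties gamma to that spacing as gamma = 0.5 / spacing^2. `RBFCentresSketch` is a hypothetical name.

```elixir
defmodule RBFCentresSketch do
  # Hypothetical convention: num_rbf evenly spaced centres on [0, cutoff]
  # and a Gaussian width equal to the centre spacing. The SchNet module
  # may choose its centres and gamma differently.
  def centres(num_rbf, cutoff) when num_rbf >= 2 do
    step = cutoff / (num_rbf - 1)
    for k <- 0..(num_rbf - 1), do: k * step
  end

  def gamma(num_rbf, cutoff) when num_rbf >= 2 do
    step = cutoff / (num_rbf - 1)
    0.5 / (step * step)
  end
end
```

Under this convention the defaults (num_rbf: 20, cutoff: 5.0) give a spacing of 5.0 / 19 ≈ 0.263 and gamma = 7.22, so more basis functions yield narrower, more finely resolved Gaussians over the same distance range.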
@spec output_size(keyword()) :: pos_integer()
Get the output size of a SchNet model.
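The documentation does not state what output_size/1 returns. Reading only the build/1 options, one plausible interpretation is sketched below: when :num_classes is set, the output projection determines the width; otherwise, the embeddings keep :hidden_size (default 64). This is a hypothetical reimplementation for illustration and should be checked against the module's source before relying on it.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical, inferred from the option docs only: :num_classes (if set)
  # determines the output width; otherwise :hidden_size (default 64) does.
  def output_size(opts) do
    case Keyword.get(opts, :num_classes) do
      nil -> Keyword.get(opts, :hidden_size, 64)
      n -> n
    end
  end
end
```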