Edifice.Graph.SchNet (Edifice v0.2.0)


SchNet - Continuous-Filter Convolutional Neural Network.

SchNet processes molecular/atomic graphs using continuous-filter convolutions: the convolution filters are generated from interatomic distances expanded in radial basis functions. Unlike graph convolutions that work with a discrete set of edge types or fixed edge weights, SchNet's filters are continuous functions of distance, so atoms can sit at arbitrary positions. This makes it suitable for molecular property prediction and atomistic simulations.

Architecture

Node Features [batch, num_nodes, input_dim]
Adjacency     [batch, num_nodes, num_nodes]  (interpreted as distances)
      |
      v
+--------------------------------------+
| Input Embedding                      |
+--------------------------------------+
      |
      v
+--------------------------------------+
| SchNet Interaction Block 1:          |
|   1. RBF expansion of distances      |
|   2. Filter-generating network       |
|   3. Continuous convolution          |
|   4. Update node features            |
+--------------------------------------+
      |  (repeat num_interactions times)
      v
Node Embeddings [batch, num_nodes, hidden_size]
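The "adjacency" input is read as a matrix of interatomic distances, so atom coordinates have to be turned into a pairwise distance matrix first. The sketch below does this for a single molecule in plain Elixir. The module name DistanceMatrix is hypothetical (not part of Edifice), and it assumes coordinates arrive as {x, y, z} tuples.

```elixir
defmodule DistanceMatrix do
  # Hypothetical helper, not part of Edifice.
  # positions: list of {x, y, z} tuples, one per atom.
  # Returns a num_nodes x num_nodes nested list of Euclidean distances
  # (symmetric, with zeros on the diagonal).
  def from_positions(positions) do
    for {xi, yi, zi} <- positions do
      for {xj, yj, zj} <- positions do
        {dx, dy, dz} = {xi - xj, yi - yj, zi - zj}
        :math.sqrt(dx * dx + dy * dy + dz * dz)
      end
    end
  end
end

# Three atoms in a toy geometry
positions = [{0.0, 0.0, 0.0}, {0.96, 0.0, 0.0}, {0.0, 0.96, 0.0}]
distances = DistanceMatrix.from_positions(positions)
IO.inspect(distances)
```

Wrapping the result in an outer list, e.g. Nx.tensor([distances]), gives the [batch, num_nodes, num_nodes] shape with batch = 1. How the model expects padded atoms, or pairs beyond the cutoff, to be encoded is not stated on this page.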

Continuous Convolution

For each pair of atoms (i, j) with distance d_ij:

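  1. Expand d_ij into radial basis functions: e_k(d) = exp(-gamma * (d - mu_k)^2) for k = 1..num_rbf, where mu_k are the RBF centers and gamma sets their width
  2. Generate filter: W(d_ij) = Dense(RBF(d_ij)), one value per filter (num_filters in total)
  3. Convolve: x_i += SUM_j W(d_ij) * x_j, where * is the element-wise product over filter channels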
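The three steps above can be traced on plain lists. This is a minimal sketch, not the Edifice implementation: the module name CFConvSketch is hypothetical, the filter-generating network is reduced to a single dense layer with explicit weights and bias (the original SchNet paper uses a deeper filter network with shifted-softplus activations), RBF centers are assumed to be evenly spaced on [0, cutoff], and the sum runs over every j, including j = i.

```elixir
defmodule CFConvSketch do
  # Step 1: Gaussian RBF expansion, e_k(d) = exp(-gamma * (d - mu_k)^2).
  def rbf(d, centers, gamma) do
    Enum.map(centers, fn mu -> :math.exp(-gamma * (d - mu) * (d - mu)) end)
  end

  # Assumption: num_rbf centers spaced evenly from 0 to cutoff.
  def centers(num_rbf, cutoff) when num_rbf >= 2 do
    step = cutoff / (num_rbf - 1)
    for k <- 0..(num_rbf - 1), do: k * step
  end

  # Step 2: filter-generating network reduced to one dense layer,
  # W(d) = rbf(d) . weights + bias, with weights shaped num_rbf x num_filters.
  def filter(rbf_vec, weights, bias) do
    weights
    |> transpose()
    |> Enum.zip_with(bias, fn column, b -> dot(column, rbf_vec) + b end)
  end

  # Step 3: x_i_new = SUM_j W(d_ij) * x_j with an element-wise product.
  # Each x_j must already have num_filters entries (in SchNet an atom-wise
  # dense layer normally projects the features to that width first).
  def convolve(features, distances, centers, gamma, weights, bias) do
    for row <- distances do
      row
      |> Enum.zip_with(features, fn d_ij, x_j ->
        w = filter(rbf(d_ij, centers, gamma), weights, bias)
        Enum.zip_with(w, x_j, &(&1 * &2))
      end)
      |> Enum.reduce(fn msg, acc -> Enum.zip_with(msg, acc, &(&1 + &2)) end)
    end
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
  defp transpose(rows), do: Enum.zip_with(rows, & &1)
end

features = [[1.0, 2.0], [3.0, 4.0]]
distances = [[0.0, 1.0], [1.0, 0.0]]
centers = [0.0, 1.0]
# num_rbf x num_filters; identity weights make W(d) equal the RBF vector
weights = [[1.0, 0.0], [0.0, 1.0]]
bias = [0.0, 0.0]

CFConvSketch.convolve(features, distances, centers, 1.0, weights, bias)
|> IO.inspect()
```

With identity weights the generated filter is just the RBF vector, which makes it easy to check each atom's output by hand.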

Usage

alias Edifice.Graph.SchNet

model = SchNet.build(
  input_dim: 16,
  hidden_size: 64,
  num_interactions: 3,
  num_filters: 64,
  cutoff: 5.0,
  num_classes: 1
)

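Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a SchNet model.

interaction_block(nodes, adjacency, hidden_size, opts \\ []) - Single SchNet interaction block.

output_size(opts \\ []) - Get the output size of a SchNet model.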

Types

build_opt()

@type build_opt() ::
  {:cutoff, float()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:num_classes, pos_integer() | nil}
  | {:num_filters, pos_integer()}
  | {:num_interactions, pos_integer()}
  | {:num_rbf, pos_integer()}
  | {:pool, atom()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a SchNet model.

Options

  • :input_dim - Input feature dimension per atom (required)
  • :hidden_size - Hidden dimension (default: 64)
  • :num_interactions - Number of interaction blocks (default: 3)
  • :num_filters - Number of continuous filters (default: 64)
  • :cutoff - Distance cutoff for interactions (default: 5.0)
  • :num_rbf - Number of radial basis functions (default: 20)
  • :num_classes - If provided, adds an output projection to num_classes dimensions (default: nil)
  • :pool - Global pooling mode over atoms, for molecule-level properties (default: nil, no pooling)
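To make the documented defaults explicit, here is a hypothetical helper (SchNetOptsSketch, not part of Edifice) that fills in the defaults listed above and enforces that :input_dim is present. It only mirrors this option list; build/1 itself may validate options differently.

```elixir
defmodule SchNetOptsSketch do
  # Hypothetical helper, not part of Edifice: mirrors the defaults
  # documented for build/1.
  @defaults [
    hidden_size: 64,
    num_interactions: 3,
    num_filters: 64,
    cutoff: 5.0,
    num_rbf: 20,
    num_classes: nil,
    pool: nil
  ]

  # Returns the option list with defaults filled in; caller-supplied
  # values take precedence. Raises if :input_dim is missing.
  def resolve(opts) do
    if not Keyword.has_key?(opts, :input_dim) do
      raise ArgumentError, "the :input_dim option is required"
    end

    Keyword.merge(@defaults, opts)
  end
end

opts = SchNetOptsSketch.resolve(input_dim: 16, cutoff: 6.0)
IO.inspect(Keyword.take(opts, [:input_dim, :cutoff, :num_rbf]))
```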

Returns

An Axon model with two inputs: "nodes" (atom features, shape [batch, num_nodes, input_dim]) and "adjacency" (pairwise distances, shape [batch, num_nodes, num_nodes]).

interaction_block(nodes, adjacency, hidden_size, opts \\ [])

@spec interaction_block(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()

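Single SchNet interaction block: RBF expansion of the distances in adjacency, filter generation, continuous convolution, and node update (steps 1-4 in the architecture diagram).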

Options

  • :num_filters - Number of continuous filters (default: 64)
  • :cutoff - Distance cutoff (default: 5.0)
  • :num_rbf - Number of radial basis functions (default: 20)
  • :name - Layer name prefix
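The options above do not say how :cutoff is applied inside the block. Two common ways to limit interactions to a radius are a hard mask and a smooth cosine envelope, 0.5 * (cos(pi * d / cutoff) + 1), which falls continuously to zero at the cutoff. The sketch below shows both on plain floats; which one (if either) Edifice uses is an assumption, and the module name CutoffSketch is hypothetical.

```elixir
defmodule CutoffSketch do
  # Hard cutoff: weight 1.0 strictly inside the radius, 0.0 at or beyond it.
  def hard(d, cutoff) when d < cutoff, do: 1.0
  def hard(_d, _cutoff), do: 0.0

  # Smooth cosine cutoff: 0.5 * (cos(pi * d / cutoff) + 1) inside the radius,
  # falling from 1.0 at d = 0 to 0.0 at d = cutoff.
  def cosine(d, cutoff) when d < cutoff do
    0.5 * (:math.cos(:math.pi() * d / cutoff) + 1.0)
  end

  def cosine(_d, _cutoff), do: 0.0

  # Scales a filter vector W(d_ij) by the chosen cutoff weight.
  def scale_filter(filter, d, cutoff, cutoff_fun) do
    c = cutoff_fun.(d, cutoff)
    Enum.map(filter, &(&1 * c))
  end
end

for d <- [0.0, 2.5, 4.9, 6.0] do
  IO.puts("d=#{d} hard=#{CutoffSketch.hard(d, 5.0)} cosine=#{CutoffSketch.cosine(d, 5.0)}")
end
```

The cosine form avoids the jump that a hard mask introduces at d = cutoff, which keeps the filter (and anything differentiated through it) smooth as atoms cross the radius.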

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a SchNet model.