Permutation-invariant set processing (Zaheer et al., 2017).
DeepSets processes sets of elements where the output is invariant to the ordering of inputs. This is achieved by processing each element independently through a shared network (phi), aggregating with a permutation-invariant operation (sum), and post-processing the aggregate (rho).
Architecture
Input Set [batch, set_size, element_dim]
      |
      v
+---------------------------+
| phi (per-element MLP):    |
|   For each x_i in set:    |
|     z_i = phi(x_i)        |
+---------------------------+
      |
      v
+---------------------------+
| Aggregate (sum):          |
|   z = SUM_i phi(x_i)      |
+---------------------------+
      |
      v
+---------------------------+
| rho (post-aggregation):   |
|   output = rho(z)         |
+---------------------------+
      |
      v
Output [batch, output_dim]
Key Property
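With sum aggregation, the form output = rho(SUM(phi(x_i))) can represent any permutation-invariant function on sets drawn from a countable universe, given suitable phi and rho (Zaheer et al., 2017). When phi and rho are themselves universal approximators such as MLPs, the architecture is a universal approximator for continuous permutation-invariant functions on fixed-size sets.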
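The phi, sum, rho composition can be illustrated without any tensor library. The following is a minimal sketch in plain Elixir, not the module's implementation: `DeepSetsSketch`, its `phi/1`, and its `rho/1` are hypothetical hand-written stand-ins for the learned networks, and a set is modelled as a non-empty list of equal-length number lists.

```elixir
defmodule DeepSetsSketch do
  # Hypothetical stand-in for the learned per-element network phi:
  # maps one element (a list of numbers) to a 2-dimensional feature vector.
  def phi(x) do
    [Enum.sum(x), Enum.sum(Enum.map(x, &(&1 * &1)))]
  end

  # Hypothetical stand-in for the learned post-aggregation network rho.
  def rho([s, q]), do: [s + q, s - q]

  # output = rho(SUM_i phi(x_i)); the sum runs feature by feature
  # across all elements of the (non-empty) set.
  def forward(set) do
    set
    |> Enum.map(&phi/1)
    |> Enum.zip_with(&Enum.sum/1)
    |> rho()
  end
end

set = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
shuffled = Enum.reverse(set)

# Reordering the elements leaves the output unchanged.
IO.inspect(DeepSetsSketch.forward(set) == DeepSetsSketch.forward(shuffled))
```

Because addition is commutative, any reordering of the elements yields the same aggregate and therefore the same output. The ordering only ever reaches rho through the sum.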
Usage
# Build a DeepSets model for set classification
model = DeepSets.build(
  input_dim: 3,
  hidden_size: 64,
  output_dim: 10,
  phi_sizes: [64, 64],
  rho_sizes: [64, 32]
)
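# Process a batch of sets
# Input: {batch=4, set_size=20, element_dim=3}
# Placeholder data for illustration; substitute real sets of this shape
set_data = Nx.iota({4, 20, 3}, type: :f32)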
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({4, 20, 3}, :f32), Axon.ModelState.empty())
output = predict_fn.(params, %{"input" => set_data})
References
- "Deep Sets" (Zaheer et al., NeurIPS 2017)
Types
@type build_opt() ::
  {:activation, atom()}
  | {:aggregation, :sum | :mean | :max}
  | {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:output_dim, pos_integer()}
  | {:phi_sizes, [pos_integer()]}
  | {:rho_sizes, [pos_integer()]}
Options for build/1.
Functions
Build a DeepSets model for permutation-invariant set processing.
Options
- :input_dim - Dimension of each set element (required)
- :hidden_size - Intermediate dimension for phi output (default: 64)
- :output_dim - Final output dimension (required)
- :phi_sizes - Hidden layer sizes for per-element network (default: [64, 64])
- :rho_sizes - Hidden layer sizes for post-aggregation network (default: [64])
- :activation - Activation function (default: :relu)
- :dropout - Dropout rate (default: 0.0)
- :aggregation - Set aggregation: :sum, :mean, :max (default: :sum)
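To make the three :aggregation choices concrete, here is a small sketch in plain Elixir of what each reduction computes over per-element features. `AggregationSketch.aggregate/2` is a hypothetical helper written for this illustration and is not part of the module's API; it assumes the features have already been produced by phi and are lists of equal length.

```elixir
defmodule AggregationSketch do
  # Hypothetical helper: reduce a list of per-element feature vectors
  # to a single vector, one feature position at a time.
  def aggregate(features, :sum), do: Enum.zip_with(features, &Enum.sum/1)
  def aggregate(features, :max), do: Enum.zip_with(features, &Enum.max/1)

  def aggregate(features, :mean) do
    n = length(features)
    Enum.zip_with(features, fn column -> Enum.sum(column) / n end)
  end
end

# Three elements, already mapped through phi to 2-dimensional features.
features = [[1.0, 4.0], [2.0, 0.0], [3.0, 2.0]]

IO.inspect(AggregationSketch.aggregate(features, :sum))   # [6.0, 6.0]
IO.inspect(AggregationSketch.aggregate(features, :mean))  # [2.0, 2.0]
IO.inspect(AggregationSketch.aggregate(features, :max))   # [3.0, 4.0]
```

All three reductions are permutation-invariant. Sum grows with the number of elements, mean normalizes by set size, and max keeps only the largest value at each feature position, so the choice depends on whether set size should influence the output.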
Returns
An Axon model. Input shape: {batch, set_size, input_dim}.
Output shape: {batch, output_dim}.