Edifice.Sets.DeepSets (Edifice v0.2.0)


Permutation-invariant set processing (Zaheer et al., 2017).

DeepSets processes sets of elements so that the output is invariant to the order of the inputs. It does this in three steps: each element is processed independently by a shared network (phi), the per-element results are combined with a permutation-invariant operation (sum by default; mean and max are also available through the :aggregation option), and the aggregate is post-processed by a second network (rho).
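The aggregation step can be sketched in plain Elixir without Nx. This is a minimal illustration, not Edifice code: the module name `AggregationSketch` is hypothetical, and each set element is assumed to be a list of floats of equal length. Each clause reduces over the set axis only, so the result cannot depend on element order.

```elixir
defmodule AggregationSketch do
  # Hypothetical helper, not part of Edifice.
  # `vectors` is a set given as a list of equal-length float lists.
  # Enum.zip_with/2 groups the i-th coordinate of every vector, so each
  # reduction runs across the set axis, one feature at a time.
  def aggregate(vectors, :sum), do: Enum.zip_with(vectors, &Enum.sum/1)

  def aggregate(vectors, :mean) do
    n = length(vectors)
    vectors |> aggregate(:sum) |> Enum.map(&(&1 / n))
  end

  def aggregate(vectors, :max), do: Enum.zip_with(vectors, &Enum.max/1)
end

set = [[1.0, 2.0], [3.0, -1.0], [0.5, 4.0]]

# Reordering the set leaves every aggregate unchanged.
for agg <- [:sum, :mean, :max] do
  IO.inspect(AggregationSketch.aggregate(set, agg) == AggregationSketch.aggregate(Enum.reverse(set), agg))
end
```

Sum and mean are order independent in exact arithmetic; with floats, a different element order can change the last bits of the result, so invariance checks on real data should compare with a tolerance. The values above happen to be exactly representable.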

Architecture

Input Set [batch, set_size, element_dim]
      |
      v
+---------------------------+
| phi (per-element MLP):    |
|   For each x_i in set:    |
|     z_i = phi(x_i)        |
+---------------------------+
      |
      v
+---------------------------+
| Aggregate (sum):          |
|   z = SUM_i phi(x_i)      |
+---------------------------+
      |
      v
+---------------------------+
| rho (post-aggregation):   |
|   output = rho(z)         |
+---------------------------+
      |
      v
Output [batch, output_dim]
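The data flow in the diagram can be traced with a toy forward pass on plain lists. This sketch assumes a single dense layer with ReLU for phi and a single dense layer for rho; the real build/1 stacks layers according to :phi_sizes and :rho_sizes through Axon, and the module and function names here are hypothetical.

```elixir
defmodule DeepSetsSketch do
  # Hypothetical toy model, not Edifice's implementation.
  # A layer is {weights, bias}; weights is a list of rows, one per output unit.
  # dense computes y_j = sum_k w[j][k] * x[k] + b[j].
  defp dense(x, {w, b}) do
    Enum.zip_with(w, b, fn row, bias ->
      row |> Enum.zip_with(x, &(&1 * &2)) |> Enum.sum() |> Kernel.+(bias)
    end)
  end

  defp relu(v), do: Enum.map(v, &max(&1, 0.0))

  # phi: applied to every element independently, always with the same weights.
  def phi(x, layer), do: x |> dense(layer) |> relu()

  # rho: applied once, to the aggregated vector.
  def rho(z, layer), do: dense(z, layer)

  # One set: [set_size][element_dim] -> [output_dim]
  def forward(set, phi_layer, rho_layer) do
    set
    |> Enum.map(&phi(&1, phi_layer))
    # Sum over the set axis, coordinate by coordinate.
    |> Enum.zip_with(&Enum.sum/1)
    |> rho(rho_layer)
  end

  # Batch: [batch][set_size][element_dim] -> [batch][output_dim]
  def forward_batch(batch, phi_layer, rho_layer),
    do: Enum.map(batch, &forward(&1, phi_layer, rho_layer))
end

phi_layer = {[[1.0, -1.0], [0.5, 0.5]], [0.0, 1.0]}
rho_layer = {[[1.0, 2.0]], [0.5]}
set = [[1.0, 2.0], [3.0, 0.0]]

IO.inspect(DeepSetsSketch.forward(set, phi_layer, rho_layer))
# => [13.5]
IO.inspect(DeepSetsSketch.forward(Enum.reverse(set), phi_layer, rho_layer))
# => [13.5]
```

Here phi maps the two elements to [0.0, 2.5] and [3.0, 2.5], the sum is [3.0, 5.0], and rho gives 1.0 * 3.0 + 2.0 * 5.0 + 0.5 = 13.5 regardless of element order.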

Key Property

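The architecture output = rho(SUM_i phi(x_i)) is permutation invariant by construction, because the sum does not depend on element order. Zaheer et al. also prove the converse for sets drawn from a countable domain: every permutation-invariant function on such sets can be written in this form for suitable phi and rho (for continuous elements, their result is restricted to sets of a fixed size). This universality result is stated for the sum; the :mean and :max aggregations are still permutation invariant but are not covered by it.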
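The idea behind the countable-domain result can be illustrated with an exact integer example: if phi maps each element to a distinct power of two, the sum over a set is a bitmask that identifies the set uniquely, so rho can recover the set and then compute any function of it. This is a sketch under stated assumptions, not the paper's construction: elements are assumed to be non-negative integers without duplicates, and `phi(x) = 2^x` is an illustrative choice in the spirit of the injective encoding used in the proof. The module name is hypothetical.

```elixir
defmodule SumDecompositionSketch do
  import Bitwise

  # Illustrative phi: each natural number maps to its own power of two.
  def phi(x) when is_integer(x) and x >= 0, do: 1 <<< x

  # SUM_i phi(x_i): for a set (no duplicates) this is a bitmask with one bit per element.
  def encode(set), do: set |> Enum.map(&phi/1) |> Enum.sum()

  # rho recovers the set from the bitmask, then applies an arbitrary function f.
  def rho(code, f), do: f.(decode(code))

  # Read the set bits back out, lowest element first.
  def decode(code), do: decode(code, 0, [])

  defp decode(0, _bit, acc), do: Enum.reverse(acc)

  defp decode(code, bit, acc) do
    acc = if (code &&& 1) == 1, do: [bit | acc], else: acc
    decode(code >>> 1, bit + 1, acc)
  end
end

code = SumDecompositionSketch.encode([5, 1, 3])
# 2^5 + 2^1 + 2^3 = 42, the same for any ordering of the elements
IO.inspect(code)
IO.inspect(SumDecompositionSketch.rho(code, &Enum.max/1))
```

The encoding breaks for multisets (for example [0, 0] and [1] both sum to 2), and a rho that decodes bitmasks is not something a trained network would learn; the point is only that the sum loses no information about a set, which is why rho(SUM_i phi(x_i)) can represent any permutation-invariant function in this setting.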

Usage

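# Build a DeepSets model for set classification
alias Edifice.Sets.DeepSets

model = DeepSets.build(
  input_dim: 3,
  hidden_size: 64,
  output_dim: 10,
  phi_sizes: [64, 64],
  rho_sizes: [64, 32]
)

# Process a batch of sets
# Input: {batch=4, set_size=20, element_dim=3}
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({4, 20, 3}, :f32), Axon.ModelState.empty())

# Random example data with the same shape as the template
{set_data, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {4, 20, 3}, type: :f32)

# Output shape: {4, 10}, i.e. {batch, output_dim}
output = predict_fn.(params, %{"input" => set_data})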

References

  • "Deep Sets" (Zaheer et al., NeurIPS 2017)


Types

build_opt()

@type build_opt() ::
  {:activation, atom()}
  | {:aggregation, :sum | :mean | :max}
  | {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:output_dim, pos_integer()}
  | {:phi_sizes, [pos_integer()]}
  | {:rho_sizes, [pos_integer()]}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DeepSets model for permutation-invariant set processing.

Options

  • :input_dim - Dimension of each set element (required)
  • :hidden_size - Output dimension of phi, i.e. the size of each per-element embedding before aggregation (default: 64)
  • :output_dim - Final output dimension (required)
  • :phi_sizes - Hidden layer sizes for per-element network (default: [64, 64])
  • :rho_sizes - Hidden layer sizes for post-aggregation network (default: [64])
  • :activation - Activation function (default: :relu)
  • :dropout - Dropout rate (default: 0.0)
  • :aggregation - Set aggregation: :sum, :mean, :max (default: :sum)
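The defaults above can be summarized as a keyword list. The following is a hypothetical sketch of how such options might be resolved with `Keyword.validate!/2` (available since Elixir 1.13); it is not Edifice's own validation code, and the module name `DeepSetsOptsSketch` is invented for illustration.

```elixir
defmodule DeepSetsOptsSketch do
  # Hypothetical option resolver, not part of Edifice.
  # Defaults as documented for build/1; :input_dim and :output_dim have none.
  @defaults [
    hidden_size: 64,
    phi_sizes: [64, 64],
    rho_sizes: [64],
    activation: :relu,
    dropout: 0.0,
    aggregation: :sum
  ]

  def resolve(opts) do
    # Rejects unknown keys and fills in defaults, but does not enforce required keys.
    opts = Keyword.validate!(opts, [:input_dim, :output_dim | @defaults])

    Enum.each([:input_dim, :output_dim], fn key ->
      if not Keyword.has_key?(opts, key) do
        raise ArgumentError, "missing required option #{inspect(key)}"
      end
    end)

    if opts[:aggregation] not in [:sum, :mean, :max] do
      raise ArgumentError, "unsupported :aggregation #{inspect(opts[:aggregation])}"
    end

    opts
  end
end

opts = DeepSetsOptsSketch.resolve(input_dim: 3, output_dim: 10)
IO.inspect({opts[:hidden_size], opts[:phi_sizes], opts[:aggregation]})
```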

Returns

An Axon model. Input shape: {batch, set_size, input_dim}. Output shape: {batch, output_dim}.