Hypernetworks that generate weights for a target network.
A hypernetwork is a neural network that produces the weights for another neural network (the target network). This enables:
- Conditional computation: Different inputs produce different target weights
- Weight sharing: One hypernetwork can generate weights for many target layers
- Task adaptation: Condition on task embeddings to generate task-specific networks
- Compression: The hypernetwork can be smaller than the target weight space
Architecture
Conditioning Input [batch, conditioning_size]
|
v
+----------------------------+
| Hypernetwork |
| (generates weight chunks) |
+----------------------------+
|
v
Weight Matrices for Target Network
[W1: in1 x out1, W2: in2 x out2, ...]
|
v
+----------------------------+
| Target Network |
| (uses generated weights) |
+----------------------------+
|
v
Output [batch, final_output_size]

Usage

# Build hypernetwork
model =
  Hypernetwork.build(
    conditioning_size: 64,
    input_size: 128,
    target_layer_sizes: [{128, 64}, {64, 32}],
    hidden_sizes: [256, 256]
  )

References
- Ha et al., "HyperNetworks" (2016)
- https://arxiv.org/abs/1609.09106
Summary
Functions
Apply hypernetwork-generated weights to compute the target network output.
Build a hypernetwork that generates weights for a target network.
Build a weight generator network.
Types
@type build_opt() ::
  {:activation, atom()}
  | {:conditioning_size, pos_integer()}
  | {:hidden_sizes, [pos_integer()]}
  | {:input_size, pos_integer()}
  | {:target_layer_sizes, [{pos_integer(), pos_integer()}]}
Options for build/1.
Functions
@spec apply_generated_weights(Axon.t(), [Axon.t()], [{pos_integer(), pos_integer()}], keyword()) :: Axon.t()
Apply hypernetwork-generated weights to compute the target network output.
Takes data input and generated weight vectors, reshapes them into proper weight matrices and biases, and applies them sequentially.
Parameters
- `data_input` - Axon node with data `[batch, input_dim]`
- `weight_generators` - List of Axon nodes, each producing weight params
- `target_layer_sizes` - List of `{in_dim, out_dim}` tuples
Options
- `:activation` - Activation function between layers (default: `:relu`)
Returns
An Axon node with shape [batch, last_out_dim]
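A sketch of wiring one generator per target layer into `apply_generated_weights/4`, assuming the input names `"conditioning"` and `"data"` (illustrative, not fixed by this module):

```elixir
conditioning = Axon.input("conditioning", shape: {nil, 64})
data = Axon.input("data", shape: {nil, 128})

target_layer_sizes = [{128, 64}, {64, 32}]

# One weight generator per target layer; each emits a flat
# weight-plus-bias vector for its {in_dim, out_dim} pair.
generators =
  Enum.map(target_layer_sizes, fn {t_in, t_out} ->
    Hypernetwork.build_weight_generator(conditioning, t_in, t_out)
  end)

# Reshapes each generated vector into weights and biases and applies
# the layers in order, yielding a [batch, 32] output node.
output =
  Hypernetwork.apply_generated_weights(data, generators, target_layer_sizes,
    activation: :relu
  )
```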
Build a hypernetwork that generates weights for a target network.
The hypernetwork takes a conditioning input and produces weight matrices for each layer of the target network. The target network then processes data input using these generated weights.
Options
- `:conditioning_size` - Dimension of the conditioning input (required)
- `:target_layer_sizes` - List of `{input_dim, output_dim}` tuples for each target layer (required)
- `:hidden_sizes` - Hidden layer sizes for the weight generator (default: `[256, 256]`)
- `:input_size` - Size of the data input to the target network (required)
- `:activation` - Activation for target network layers (default: `:relu`)
Returns
An Axon model taking conditioning [batch, conditioning_size] and
data [batch, input_size], producing [batch, last_output_dim].
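Once built, the model can be compiled and run like any other Axon model. A minimal sketch using Axon's standard `Axon.build/2`, assuming the inputs are named `"conditioning"` and `"data"` (the actual names depend on the implementation):

```elixir
model =
  Hypernetwork.build(
    conditioning_size: 64,
    input_size: 128,
    target_layer_sizes: [{128, 64}, {64, 32}]
  )

{init_fn, predict_fn} = Axon.build(model)

inputs = %{
  "conditioning" => Nx.broadcast(0.0, {8, 64}),
  "data" => Nx.broadcast(0.0, {8, 128})
}

params = init_fn.(inputs, %{})
# Output shape is {8, 32}: batch size by the last target layer's out_dim.
output = predict_fn.(params, inputs)
```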
@spec build_weight_generator(Axon.t(), pos_integer(), pos_integer(), keyword()) :: Axon.t()
Build a weight generator network.
Takes a conditioning input and outputs a flattened weight matrix and bias vector for one target layer.
Parameters
- `conditioning` - Axon node with conditioning input `[batch, conditioning_size]`
- `target_in` - Input dimension of the target layer
- `target_out` - Output dimension of the target layer
Options
- `:hidden_sizes` - Hidden layer sizes (default: `[256, 256]`)
- `:name` - Layer name prefix
Returns
An Axon node producing concatenated weight + bias: [batch, target_in * target_out + target_out]
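The output size follows directly from the target layer shape: for a 128 -> 64 layer the generator emits 128 * 64 + 64 = 8256 values per batch element. A hedged sketch (the `"wg_0"` name is illustrative):

```elixir
conditioning = Axon.input("conditioning", shape: {nil, 64})

# Flat weight + bias for one 128 -> 64 target layer:
# 128 * 64 weight entries + 64 bias entries = 8256 values per example.
gen =
  Hypernetwork.build_weight_generator(conditioning, 128, 64,
    hidden_sizes: [256, 256],
    name: "wg_0"
  )
```

Downstream, `apply_generated_weights/4` is expected to split this vector back into a `[batch, 128, 64]` weight tensor and a `[batch, 64]` bias.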