Residual Network (ResNet) implementation.
Deep residual networks use skip connections to enable training of very deep networks by mitigating the vanishing gradient problem. Each residual block adds its input to its output, allowing gradients to flow directly through identity shortcuts.
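The core idea, sketched in Axon (illustrative layer choices and shapes, not this module's exact code):

```elixir
# out = relu(F(x) + x): the add gives gradients a direct path back through
# the identity branch. `x` and `channels` are placeholders for illustration.
x = Axon.input("input", shape: {nil, 32, 32, 64})
channels = 64

f_x =
  x
  |> Axon.conv(channels, kernel_size: 3, padding: :same)
  |> Axon.batch_norm()
  |> Axon.relu()
  |> Axon.conv(channels, kernel_size: 3, padding: :same)
  |> Axon.batch_norm()

out = Axon.add(f_x, x) |> Axon.relu()
```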
Architecture
```
Input [batch, H, W, C]
       |
+------v-------+
|     Stem     |  7x7 conv stride 2, BN, ReLU, 3x3 max pool
+--------------+
       |
+------v-------+
|   Stage 1    |  N residual blocks at initial_channels
+--------------+
       |
+------v-------+
|   Stage 2    |  N residual blocks at initial_channels * 2 (stride 2)
+--------------+
       |
+------v-------+
|   Stage 3    |  N residual blocks at initial_channels * 4 (stride 2)
+--------------+
       |
+------v-------+
|   Stage 4    |  N residual blocks at initial_channels * 8 (stride 2)
+--------------+
       |
+------v-------+
|Global AvgPool|
+--------------+
       |
+------v-------+
|    Dense     |  num_classes outputs
+--------------+
```
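A hedged sketch of the stem stage above in Axon (the 224x224 input shape and option values are assumptions for illustration, not this module's verbatim code):

```elixir
# Stem: 7x7 conv with stride 2, batch norm, ReLU, then 3x3 max pool with stride 2.
stem =
  Axon.input("input", shape: {nil, 224, 224, 3})
  |> Axon.conv(64, kernel_size: 7, strides: 2, padding: :same)
  |> Axon.batch_norm()
  |> Axon.relu()
  |> Axon.max_pool(kernel_size: 3, strides: 2, padding: :same)
```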
Configurations
| Model | block_sizes | Block Type | Params |
|---|---|---|---|
| ResNet-18 | [2, 2, 2, 2] | residual | ~11M |
| ResNet-34 | [3, 4, 6, 3] | residual | ~21M |
| ResNet-50 | [3, 4, 6, 3] | bottleneck | ~25M |
| ResNet-101 | [3, 4, 23, 3] | bottleneck | ~44M |
| ResNet-152 | [3, 8, 36, 3] | bottleneck | ~60M |
Usage
```elixir
# ResNet-18 for CIFAR-10
model = ResNet.build(
  input_shape: {nil, 32, 32, 3},
  num_classes: 10,
  block_sizes: [2, 2, 2, 2],
  initial_channels: 64
)

# ResNet-50 with bottleneck blocks
model = ResNet.build(
  input_shape: {nil, 224, 224, 3},
  num_classes: 1000,
  block_sizes: [3, 4, 6, 3],
  block_type: :bottleneck,
  initial_channels: 64
)
```
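For context, a minimal training sketch using Axon's standard loop. `train_data` (an Enumerable of `{images, labels}` batches) and the loss/optimizer choices are assumptions, not part of this module:

```elixir
# Assumes labels are one-hot and compatible with :categorical_cross_entropy;
# since the final dense layer emits raw logits, you may prefer a from_logits
# loss, e.g. &Axon.Losses.categorical_cross_entropy(&1, &2, from_logits: true, reduction: :mean).
trained_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, %{}, epochs: 10)
```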
Summary
Functions
bottleneck_block(input, channels, opts) - Build a bottleneck residual block.
build(opts) - Build a ResNet model.
output_size(opts) - Get the output size (num_classes) for a ResNet model.
residual_block(input, channels, opts) - Build a single residual block.
Types
@type build_opt() :: {:block_sizes, [pos_integer()]} | {:block_type, :residual | :bottleneck} | {:dropout, float()} | {:initial_channels, pos_integer()} | {:input_shape, tuple()} | {:num_classes, pos_integer() | nil}
Options for build/1.
Functions
@spec bottleneck_block(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a bottleneck residual block.
Structure: 1x1 conv (reduce) -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv (expand) -> BN + skip -> ReLU
Bottleneck blocks use a 4x expansion factor: the final 1x1 conv outputs
channels * 4 features. This is more parameter-efficient for deep networks.
Parameters
- `input` - Input Axon node `[batch, H, W, C]`
- `channels` - Number of bottleneck channels (output will be `channels * 4`)
- `opts` - Options:
  - `:strides` - Convolution stride for downsampling (default: 1)
  - `:expansion` - Expansion factor for output channels (default: 4)
  - `:name` - Layer name prefix (default: `"bottleneck"`)
Returns
An Axon node with shape [batch, H', W', channels * expansion].
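A usage sketch based on the parameters documented above (the input node and block name are illustrative):

```elixir
input = Axon.input("input", shape: {nil, 56, 56, 256})

# 1x1 reduces to 64 channels, 3x3 operates at 64, 1x1 expands to 64 * 4 = 256;
# strides: 2 halves the spatial dims.
out = ResNet.bottleneck_block(input, 64, strides: 2, name: "stage2_block1")
# out: [batch, 28, 28, 256]
```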
Build a ResNet model.
Options
- `:input_shape` - Input shape as `{nil, height, width, channels}` (required)
- `:num_classes` - Number of output classes (default: 10)
- `:block_sizes` - List of block counts per stage (default: `[2, 2, 2, 2]`, i.e. ResNet-18)
- `:block_type` - `:residual` or `:bottleneck` (default: `:residual`)
- `:initial_channels` - Channels after the stem conv (default: 64)
- `:dropout` - Dropout rate before the final dense layer (default: 0.0)
Returns
An Axon model outputting [batch, num_classes].
@spec output_size(keyword()) :: pos_integer()
Get the output size (num_classes) for a ResNet model.
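A hedged example, assuming it simply reports the configured `num_classes`:

```elixir
ResNet.output_size(num_classes: 10)
#=> 10
```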
@spec residual_block(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a single residual block.
Structure: conv 3x3 -> BN -> ReLU -> conv 3x3 -> BN + skip -> ReLU
When input and output channels differ, a 1x1 projection is applied to the skip connection to match dimensions.
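Sketched in Axon (illustrative, not the module's exact internals; `input`, `channels`, and `strides` refer to the parameters listed below):

```elixir
# When input channels != `channels`, project the skip path with a strided
# 1x1 conv so shapes match for the residual add; identity is used otherwise.
shortcut =
  input
  |> Axon.conv(channels, kernel_size: 1, strides: strides, padding: :same)
  |> Axon.batch_norm()
```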
Parameters
- `input` - Input Axon node `[batch, H, W, C]`
- `channels` - Number of output channels
- `opts` - Options:
  - `:strides` - Convolution stride for downsampling (default: 1)
  - `:name` - Layer name prefix (default: `"res_block"`)
Returns
An Axon node with shape [batch, H', W', channels].
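A usage sketch based on the documented parameters (the input node and name are illustrative):

```elixir
input = Axon.input("input", shape: {nil, 32, 32, 64})

# Channels change 64 -> 128 and strides: 2 halves H and W, so the skip
# connection gets a 1x1 projection.
out = ResNet.residual_block(input, 128, strides: 2, name: "stage2_block1")
# out: [batch, 16, 16, 128]
```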