Axon.Layers (Axon v0.5.1)

Functional implementations of common neural network layer operations.

Layers are the building blocks of neural networks. These functional implementations can be used to express higher-level constructs using fundamental building blocks. Neural network layers are stateful with respect to their parameters. These implementations do not manage that state themselves; instead, they delegate this responsibility to the caller.

Basic neural networks can be seen as a composition of functions:

input
|> dense(w1, b1)
|> relu()
|> dense(w2, b2)
|> softmax()

These kinds of models are often referred to as deep feedforward networks or multilayer perceptrons (MLPs) because information flows forward through the network with no feedback connections. Mathematically, a feedforward network can be represented as:

$$f(x) = f^{(3)}(f^{(2)}(f^{(1)}(x)))$$

The same pattern emerges if we condense the call stack of the previous example:

softmax(dense(relu(dense(input, w1, b1)), w2, b2))

The chain structure shown here is the most common structure used in neural networks. You can consider each function $f^{(n)}$ as a layer in the neural network; for example, $f^{(2)}$ is the second layer in the network. The number of function calls in the chain is the depth of the network, which is where the term deep learning comes from.

Neural networks are often written as the mapping:

$$y = f(x; \theta)$$

Where $x$ is the input to the neural network and $\theta$ is the set of learned parameters. In Elixir, you would write this:

y = model(input, params)

From the previous example, params would represent the collection:

{w1, b1, w2, b2}

where w1 and w2 are layer kernels, and b1 and b2 are layer biases.
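To make the composition concrete, here is a minimal sketch in plain Elixir that uses nested lists in place of Nx tensors. The module name MLPSketch and its helpers are hypothetical and only illustrate y = model(input, params) with params = {w1, b1, w2, b2}; they are not part of Axon. Kernels are assumed to be shaped [input_features][output_features], matching the dense layer described below.

```elixir
defmodule MLPSketch do
  # Hypothetical module, not part of Axon: plain lists stand in for tensors.

  # Dense layer on a batch of rows: y = x . W + b, where the kernel W is
  # shaped [input_features][output_features] and b has one value per output.
  def dense(input, kernel, bias) do
    columns = transpose(kernel)

    Enum.map(input, fn row ->
      Enum.zip_with(columns, bias, fn column, b -> dot(row, column) + b end)
    end)
  end

  # Element-wise rectified linear unit.
  def relu(input), do: Enum.map(input, fn row -> Enum.map(row, &max(&1, 0.0)) end)

  # Row-wise softmax; subtracting the row maximum keeps :math.exp/1 stable.
  def softmax(input) do
    Enum.map(input, fn row ->
      m = Enum.max(row)
      exps = Enum.map(row, &:math.exp(&1 - m))
      total = Enum.sum(exps)
      Enum.map(exps, &(&1 / total))
    end)
  end

  # y = model(input, params), with params = {w1, b1, w2, b2}.
  def model(input, {w1, b1, w2, b2}) do
    input
    |> dense(w1, b1)
    |> relu()
    |> dense(w2, b2)
    |> softmax()
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()

  defp transpose(matrix), do: matrix |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end
```

With identity kernels and zero biases, MLPSketch.model([[1.0, -1.0]], params) passes [1.0, 0.0] into the softmax and returns one row of probabilities that sums to 1.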

Summary

Layers: Linear

Functional implementation of a bilinear layer.

Functional implementation of a dense layer.

Computes embedding by treating kernel matrix as a lookup table for discrete tokens.

Layers: Dropout

Functional implementation of an alpha dropout layer.

Functional implementation of a dropout layer.

Functional implementation of a feature alpha dropout layer.

Functional implementation of an n-dimensional spatial dropout layer.

Layers: Pooling

Functional implementation of general dimensional adaptive average pooling.

Functional implementation of general dimensional adaptive power average pooling.

Functional implementation of general dimensional adaptive max pooling.

A general dimensional functional average pooling layer.

Functional implementation of global average pooling which averages across the spatial dimensions of the input such that the only remaining dimensions are the batch and feature dimensions.

Functional implementation of global LP pooling, which computes $\sqrt[p]{\sum_{x \in X} x^{p}}$ across the spatial dimensions of the input.

Functional implementation of global max pooling which computes maximums across the spatial dimensions of the input such that the only remaining dimensions are the batch and feature dimensions.

Functional implementation of a general dimensional power average pooling layer.

Functional implementation of a general dimensional max pooling layer.

Layers: Normalization

Functional implementation of batch normalization.

Functional implementation of group normalization.

Functional implementation of instance normalization.

Functional implementation of layer normalization.

Layers: Shape

Flattens input to shape of {batch, units} by folding all non-batch dimensions into a single dimension.

Resizes a batch of tensors to the given shape using one of a number of sampling methods.

Functions: Convolutional

Functional implementation of a general dimensional convolutional layer.

Functional implementation of a general dimensional transposed convolutional layer.

Functional implementation of a general dimensional depthwise convolution.

Functional implementation of a 2-dimensional separable depthwise convolution.

Functional implementation of a 3-dimensional separable depthwise convolution.

Layers: Linear

bilinear(input1, input2, kernel, bias \\ 0, opts \\ [])

Functional implementation of a bilinear layer.

Bilinear transformation of the input such that:

$$y = x_1^{T}Ax_2 + b$$

Parameter Shapes

  • input1 - {batch_size, ..., input1_features}
  • input2 - {batch_size, ..., input2_features}
  • kernel - {out_features, input1_features, input2_features}

Output Shape

{batch_size, ..., out_features}

Examples

iex> inp1 = Nx.iota({3, 2}, type: {:f, 32})
iex> inp2 = Nx.iota({3, 4}, type: {:f, 32})
iex> kernel = Nx.iota({1, 2, 4}, type: {:f, 32})
iex> bias = Nx.tensor(1.0)
iex> Axon.Layers.bilinear(inp1, inp2, kernel, bias)
#Nx.Tensor<
  f32[3][1]
  [
    [39.0],
    [455.0],
    [1319.0]
  ]
>
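For readers who want to trace the arithmetic, the following hypothetical BilinearSketch module computes $y = x_1^{T}Ax_2 + b$ with plain nested lists rather than Nx tensors. It assumes a scalar bias, as in the example above, and is only an illustration of the formula, not Axon's implementation.

```elixir
defmodule BilinearSketch do
  # Hypothetical module, not part of Axon.
  # input1: [batch][in1_features], input2: [batch][in2_features],
  # kernel: [out_features][in1_features][in2_features], bias: a scalar.
  def bilinear(input1, input2, kernel, bias) do
    Enum.zip_with(input1, input2, fn x1, x2 ->
      # One output feature per kernel slice A.
      Enum.map(kernel, fn a -> dot(vec_mat(x1, a), x2) + bias end)
    end)
  end

  # x^T A: a sum of the rows of A, each weighted by the matching entry of x.
  defp vec_mat(x, a) do
    x
    |> Enum.zip_with(a, fn w, row -> Enum.map(row, &(&1 * w)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
end
```

Feeding it the same iota-style values as the example above reproduces [[39.0], [455.0], [1319.0]].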

dense(input, kernel, bias \\ 0, opts \\ [])

Functional implementation of a dense layer.

Linear transformation of the input such that:

$$y = xW + b$$

A dense layer or fully connected layer transforms the input using the given kernel matrix and bias to compute:

Nx.dot(input, kernel) + bias

Typically, both kernel and bias are learnable parameters trained using gradient-based optimization.

Parameter Shapes

  • input - {batch_size, *, input_features}
  • kernel - {input_features, output_features}
  • bias - {} or {output_features}

Output Shape

{batch_size, *, output_features}

Examples

iex> input = Nx.tensor([[1.0, 0.5, 1.0, 0.5], [0.0, 0.0, 0.0, 0.0]], type: {:f, 32})
iex> kernel = Nx.tensor([[0.2], [0.3], [0.5], [0.8]], type: {:f, 32})
iex> bias = Nx.tensor([1.0], type: {:f, 32})
iex> Axon.Layers.dense(input, kernel, bias)
#Nx.Tensor<
  f32[2][1]
  [
    [2.25],
    [1.0]
  ]
>

embedding(input, kernel, arg3 \\ [])

Computes embedding by treating kernel matrix as a lookup table for discrete tokens.

input is a tensor of discrete integer indices, typically representing tokens (e.g. words, characters, etc.) from a vocabulary. kernel is a matrix of shape {vocab_size, embedding_size} from which the dense embeddings are drawn: each index in input selects the corresponding row of kernel.

Parameter Shapes

  • input - {batch_size, ..., seq_len}
  • kernel - {vocab_size, embedding_size}

Examples

iex> input = Nx.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
iex> kernels = Nx.tensor([
...>  [0.46299999952316284, 0.5562999844551086, 0.18170000612735748],
...>  [0.9801999926567078, 0.09780000150203705, 0.5333999991416931],
...>  [0.6980000138282776, 0.9240999817848206, 0.23479999601840973],
...>  [0.31929999589920044, 0.42250001430511475, 0.7865999937057495],
...>  [0.5519000291824341, 0.5662999749183655, 0.20559999346733093],
...>  [0.1898999959230423, 0.9311000108718872, 0.8356000185012817],
...>  [0.6383000016212463, 0.8794000148773193, 0.5282999873161316],
...>  [0.9523000121116638, 0.7597000002861023, 0.08250000327825546],
...>  [0.6622999906539917, 0.02329999953508377, 0.8205999732017517],
...>  [0.9855999946594238, 0.36419999599456787, 0.5372999906539917]
...> ])
iex> Axon.Layers.embedding(input, kernels)
#Nx.Tensor<
  f32[2][4][3]
  [
    [
      [0.9801999926567078, 0.09780000150203705, 0.5333999991416931],
      [0.6980000138282776, 0.9240999817848206, 0.23479999601840973],
      [0.5519000291824341, 0.5662999749183655, 0.20559999346733093],
      [0.1898999959230423, 0.9311000108718872, 0.8356000185012817]
    ],
    [
      [0.5519000291824341, 0.5662999749183655, 0.20559999346733093],
      [0.31929999589920044, 0.42250001430511475, 0.7865999937057495],
      [0.6980000138282776, 0.9240999817848206, 0.23479999601840973],
      [0.9855999946594238, 0.36419999599456787, 0.5372999906539917]
    ]
  ]
>
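The lookup in the example above can be sketched with plain lists: every integer token selects one row of the kernel, so an input of shape {batch_size, seq_len} yields an output of shape {batch_size, seq_len, embedding_size}. EmbeddingSketch is a hypothetical name used only for illustration; Enum.at/2 is linear-time on lists, so this is not how a tensor library would do it.

```elixir
defmodule EmbeddingSketch do
  # Hypothetical module, not part of Axon.
  # input: [batch][seq_len] of integer token ids;
  # kernel: [vocab_size][embedding_size].
  def embedding(input, kernel) do
    # Each token id is used as a row index into the kernel.
    Enum.map(input, fn tokens -> Enum.map(tokens, &Enum.at(kernel, &1)) end)
  end
end
```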

Layers: Dropout

alpha_dropout(input, key, opts \\ [])

Functional implementation of an alpha dropout layer.

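Alpha dropout is a variant of dropout that preserves the mean and variance of its input, so an input with zero mean and unit standard deviation keeps those statistics after dropout. It randomly masks some elements and then rescales and shifts the result to maintain self-normalization.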

Options

  • :rate - dropout rate. Used to determine the probability that a connection will be dropped. Required.

  • :noise_shape - input noise shape. Shape of mask which can be useful for broadcasting mask across feature channels or other dimensions. Defaults to shape of input tensor.
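A minimal sketch of the idea, assuming the standard alpha dropout construction from the self-normalizing networks (SELU) paper: dropped elements are set to SELU's negative saturation value $\alpha' = -\lambda\alpha$ rather than 0, and the result is passed through $a \cdot x + b$ with $a = ((1 - rate)(1 + rate \cdot \alpha'^2))^{-1/2}$ and $b = -a \cdot \alpha' \cdot rate$. Whether Axon computes its constants exactly this way is an assumption. AlphaDropoutSketch is hypothetical, and an explicit boolean mask replaces the random key so the result is deterministic.

```elixir
defmodule AlphaDropoutSketch do
  # Hypothetical module, not part of Axon.

  # SELU's negative saturation value, -lambda * alpha (about -1.7581).
  defp alpha_p, do: -1.0507009873554805 * 1.6732632423543772

  # input: flat list of values; keep_mask: one boolean per element standing
  # in for a random Bernoulli mask; rate: probability of dropping an element.
  def alpha_dropout(input, keep_mask, rate) do
    # Affine constants chosen so a zero-mean, unit-variance input keeps zero
    # mean and unit variance in expectation after masking.
    a = :math.pow((1 - rate) * (1 + rate * alpha_p() * alpha_p()), -0.5)
    b = -a * alpha_p() * rate

    Enum.zip_with(input, keep_mask, fn x, keep ->
      # Dropped elements become alpha_p instead of 0, then all are rescaled.
      value = if keep, do: x, else: alpha_p()
      a * value + b
    end)
  end
end
```

With rate 0.0 the constants reduce to a = 1 and b = 0, so the input passes through unchanged.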

dropout(input, key, opts \\ [])

Functional implementation of a dropout layer.

Masks each element of the input tensor with probability rate and scales the remaining elements by a factor of $\frac{1}{1 - rate}$, so the expected value of each element is unchanged.

Dropout is a form of regularization that helps prevent overfitting by keeping a model from becoming too reliant on particular connections. Dropout can loosely be thought of as training an ensemble of models with random connections masked.

Options

  • :rate - dropout rate. Used to determine the probability that a connection will be dropped. Required.

  • :noise_shape - input noise shape. Shape of mask which can be useful for broadcasting mask across feature channels or other dimensions. Defaults to shape of input tensor.
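The masking and rescaling described above can be sketched on a flat list. To keep the result deterministic, this hypothetical DropoutSketch takes one pre-drawn uniform sample per element instead of a random key; an element is dropped when its sample falls below rate.

```elixir
defmodule DropoutSketch do
  # Hypothetical module, not part of Axon.
  # input: flat list of values; uniform: one pre-drawn sample in [0, 1) per
  # element, standing in for the random key; rate: drop probability.
  def dropout(input, uniform, rate) when rate >= 0 and rate < 1 do
    scale = 1 / (1 - rate)

    Enum.zip_with(input, uniform, fn x, u ->
      # Dropped with probability `rate`; survivors are scaled by
      # 1 / (1 - rate) so the expected value is unchanged.
      if u < rate, do: 0.0, else: x * scale
    end)
  end
end
```

In a non-deterministic setting the samples could come from Enum.map(input, fn _ -> :rand.uniform() end).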

feature_alpha_dropout(input, key, opts \\ [])

Functional implementation of a feature alpha dropout layer.

Feature alpha dropout applies dropout in the same manner as spatial dropout; however, it also enforces self-normalization by setting masked inputs to the negative saturation value of the SELU activation function and scaling the unmasked inputs.

Options

  • :rate - dropout rate. Used to determine the probability that a connection will be dropped. Required.

  • :noise_shape - input noise shape. Shape of mask which can be useful for broadcasting mask across feature channels or other dimensions. Defaults to shape of input tensor.

spatial_dropout(input, key, opts \\ [])

Functional implementation of an n-dimensional spatial dropout layer.

Applies a mask to entire feature maps instead of individual elements. This is done by computing a mask with size 1 along every spatial dimension of the input tensor and broadcasting it across those spatial dimensions, so each feature map is either kept or dropped as a whole.

Options

  • :rate - dropout rate. Used to determine the probability that a connection will be dropped. Required.

  • :noise_shape - input noise shape. Shape of mask which can be useful for broadcasting mask across feature channels or other dimensions. Defaults to shape of input tensor.
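The difference from element-wise dropout is that one keep/drop decision is made per channel and then broadcast over every spatial position. The hypothetical SpatialDropoutSketch below shows this for a single channels-first example stored as [channels][positions], with one pre-drawn uniform sample per channel in place of the random key.

```elixir
defmodule SpatialDropoutSketch do
  # Hypothetical module, not part of Axon.
  # feature_maps: one channels-first example, [channels][spatial positions].
  # uniform_per_channel: one pre-drawn sample in [0, 1) per channel, standing
  # in for the random key. rate: probability of dropping a whole feature map.
  def spatial_dropout(feature_maps, uniform_per_channel, rate) do
    scale = 1 / (1 - rate)

    Enum.zip_with(feature_maps, uniform_per_channel, fn map, u ->
      # One decision per channel, broadcast over every spatial position.
      if u < rate do
        Enum.map(map, fn _ -> 0.0 end)
      else
        Enum.map(map, &(&1 * scale))
      end
    end)
  end
end
```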

Layers: Pooling

adaptive_avg_pool(input, opts \\ [])

Functional implementation of general dimensional adaptive average pooling.

Adaptive pooling allows you to specify the desired output size of the transformed input. This will automatically adapt the window size and strides to obtain the desired output size. It will then perform average pooling using the calculated window size and strides.

Adaptive pooling can be useful when working on multiple inputs with different spatial input shapes. You can guarantee the output of an adaptive pooling operation is always the same size regardless of input shape.

Options

  • :output_size - spatial output size. Must be a tuple with size equal to the number of spatial dimensions in the input tensor. Required.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.
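One way to derive the window size and strides from the requested output size is sketched below for a single spatial dimension. The rule used here, stride = div(input_size, output_size) and window = input_size - (output_size - 1) * stride, is an assumption chosen so that the last window ends exactly at the end of the input; Axon's exact rule may differ. AdaptivePoolSketch is a hypothetical name.

```elixir
defmodule AdaptivePoolSketch do
  # Hypothetical module, not part of Axon. 1-D only.
  # Assumed rule: stride = div(n, output_size) and
  # window = n - (output_size - 1) * stride.
  def adaptive_avg_pool(values, output_size) do
    n = length(values)
    stride = div(n, output_size)
    window = n - (output_size - 1) * stride

    for i <- 0..(output_size - 1) do
      slice = Enum.slice(values, i * stride, window)
      Enum.sum(slice) / window
    end
  end
end
```

Inputs of length 10 and 7 both produce 3 outputs, which is the property that makes adaptive pooling useful for inputs with varying spatial sizes.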

adaptive_lp_pool(input, opts \\ [])

Functional implementation of general dimensional adaptive power average pooling.

Computes:

$$f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$$

Adaptive pooling allows you to specify the desired output size of the transformed input. This will automatically adapt the window size and strides to obtain the desired output size. It will then perform power average pooling using the calculated window size and strides.

Adaptive pooling can be useful when working on multiple inputs with different spatial input shapes. You can guarantee the output of an adaptive pooling operation is always the same size regardless of input shape.

Options

  • :norm - $p$ from above equation. Defaults to 2.

  • :output_size - spatial output size. Must be a tuple with size equal to the number of spatial dimensions in the input tensor. Required.

adaptive_max_pool(input, opts \\ [])

Functional implementation of general dimensional adaptive max pooling.

Adaptive pooling allows you to specify the desired output size of the transformed input. This will automatically adapt the window size and strides to obtain the desired output size. It will then perform max pooling using the calculated window size and strides.

Adaptive pooling can be useful when working on multiple inputs with different spatial input shapes. You can guarantee the output of an adaptive pooling operation is always the same size regardless of input shape.

Options

  • :output_size - spatial output size. Must be a tuple with size equal to the number of spatial dimensions in the input tensor. Required.

avg_pool(input, opts \\ [])

A general dimensional functional average pooling layer.

Pooling is applied to the spatial dimensions of the input tensor. Average pooling returns the average of all elements in each valid window of the input tensor. It is often used after convolutional layers to downsample their output.

Options

  • :kernel_size - window size. Its rank must match the number of spatial dimensions of the input tensor. Required.

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :window_dilations - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by window_dilations - 1. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.
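For a single spatial dimension with :valid padding, the windowing reduces to chunking the input with a step equal to the stride and discarding incomplete windows. The hypothetical AvgPoolSketch below shows that; it ignores dilation and the other padding modes.

```elixir
defmodule AvgPoolSketch do
  # Hypothetical module, not part of Axon. 1-D, :valid padding, no dilation:
  # only windows that fit entirely inside the input are used.
  def avg_pool(values, kernel_size, strides) do
    values
    |> Enum.chunk_every(kernel_size, strides, :discard)
    |> Enum.map(fn window -> Enum.sum(window) / kernel_size end)
  end
end
```

Replacing Enum.sum(window) / kernel_size with Enum.max(window) gives the corresponding max pooling sketch.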

global_avg_pool(input, opts \\ [])

Functional implementation of global average pooling which averages across the spatial dimensions of the input such that the only remaining dimensions are the batch and feature dimensions.

Assumes data is configured in a channels-first like format.

Parameter Shapes

  • input - {batch_size, features, s1, ..., sN}

Options

  • :keep_axes - whether to keep each reduced axis with size 1. Defaults to false.

Examples

iex> Axon.Layers.global_avg_pool(Nx.iota({3, 2, 3}, type: {:f, 32}), channels: :first)
#Nx.Tensor<
  f32[3][2]
  [
    [1.0, 4.0],
    [7.0, 10.0],
    [13.0, 16.0]
  ]
>

iex> Axon.Layers.global_avg_pool(Nx.iota({1, 3, 2, 2}, type: {:f, 32}), channels: :first, keep_axes: true)
#Nx.Tensor<
  f32[1][3][1][1]
  [
    [
      [
        [1.5]
      ],
      [
        [5.5]
      ],
      [
        [9.5]
      ]
    ]
  ]
>

global_lp_pool(input, opts \\ [])

Functional implementation of global LP pooling which computes the following function across spatial dimensions of the input:

$$f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$$

Where $p$ is given by the keyword argument :norm. As $p$ approaches infinity, it becomes equivalent to max pooling.

Assumes data is configured in a channels-first like format.

Parameter Shapes

  • input - {batch_size, features, s1, ..., sN}

Options

  • :keep_axes - whether to keep each reduced axis with size 1. Defaults to false.
  • :norm - $p$ in above function. Defaults to 2

Examples

iex> Axon.Layers.global_lp_pool(Nx.iota({3, 2, 3}, type: {:f, 32}), norm: 1, channels: :first)
#Nx.Tensor<
  f32[3][2]
  [
    [3.0, 12.0],
    [21.0, 30.0],
    [39.0, 48.0]
  ]
>

iex> Axon.Layers.global_lp_pool(Nx.iota({1, 3, 2, 2}, type: {:f, 16}), keep_axes: true, channels: :first)
#Nx.Tensor<
  f16[1][3][1][1]
  [
    [
      [
        [3.7421875]
      ],
      [
        [11.2265625]
      ],
      [
        [19.125]
      ]
    ]
  ]
>

global_max_pool(input, opts \\ [])

Functional implementation of global max pooling which computes maximums across the spatial dimensions of the input such that the only remaining dimensions are the batch and feature dimensions.

Assumes data is configured in a channels-first like format.

Parameter Shapes

  • input - {batch_size, features, s1, ..., sN}

Options

  • :keep_axes - whether to keep each reduced axis with size 1. Defaults to false.

Examples

iex> Axon.Layers.global_max_pool(Nx.iota({3, 2, 3}, type: {:f, 32}), channels: :first)
#Nx.Tensor<
  f32[3][2]
  [
    [2.0, 5.0],
    [8.0, 11.0],
    [14.0, 17.0]
  ]
>

iex> Axon.Layers.global_max_pool(Nx.iota({1, 3, 2, 2}, type: {:f, 32}), keep_axes: true, channels: :first)
#Nx.Tensor<
  f32[1][3][1][1]
  [
    [
      [
        [3.0]
      ],
      [
        [7.0]
      ],
      [
        [11.0]
      ]
    ]
  ]
>

lp_pool(input, opts \\ [])

Functional implementation of a general dimensional power average pooling layer.

Pooling is applied to the spatial dimensions of the input tensor. Power average pooling computes the following function on each valid window of the input tensor:

$$f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$$

Where $p$ is given by the keyword argument :norm. As $p$ approaches infinity, it becomes equivalent to max pooling.

Options

  • :norm - $p$ from above equation. Defaults to 2.

  • :kernel_size - window size. Its rank must match the number of spatial dimensions of the input tensor. Required.

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to size of kernel.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :window_dilations - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by window_dilations - 1. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

Examples

iex> t = Nx.tensor([[[0.9450, 0.4684, 1.8146], [1.2663, 0.4354, -0.0781], [-0.4759, 0.3251, 0.8742]]], type: {:f, 32})
iex> Axon.Layers.lp_pool(t, kernel_size: 2, norm: 2, channels: :first)
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [1.0547149181365967],
      [1.3390626907348633],
      [0.5763426423072815]
    ]
  ]
>
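The example above can be checked by hand with a one-dimensional sketch of $f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$ over :valid windows, with strides defaulting to the kernel size as the options describe. LpPoolSketch is hypothetical and works in 64-bit floats, so its results match the f32 output only approximately. It assumes an even integer p or non-negative inputs, since :math.pow/2 raises for a negative base with a fractional exponent.

```elixir
defmodule LpPoolSketch do
  # Hypothetical module, not part of Axon. 1-D, :valid padding, no dilation.
  # Assumes even integer p or non-negative inputs (see :math.pow/2).
  def lp_pool(values, kernel_size, norm, strides \\ nil) do
    values
    |> Enum.chunk_every(kernel_size, strides || kernel_size, :discard)
    |> Enum.map(fn window ->
      window
      |> Enum.map(&:math.pow(&1, norm))
      |> Enum.sum()
      |> :math.pow(1 / norm)
    end)
  end
end
```

For the first row of the example, the only complete window is [0.9450, 0.4684], giving roughly 1.05471.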

max_pool(input, opts \\ [])

Functional implementation of a general dimensional max pooling layer.

Pooling is applied to the spatial dimensions of the input tensor. Max pooling returns the maximum element in each valid window of the input tensor. It is often used after convolutional layers to downsample their output.

Options

  • :kernel_size - window size. Its rank must match the number of spatial dimensions of the input tensor. Required.

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to size of kernel.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :window_dilations - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by window_dilations - 1. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

Examples

iex> t = Nx.tensor([[
...> [0.051500000059604645, -0.7042999863624573, -0.32899999618530273],
...> [-0.37130001187324524, 1.6191999912261963, -0.11829999834299088],
...> [0.7099999785423279, 0.7282999753952026, -0.18639999628067017]]], type: {:f, 32})
iex> Axon.Layers.max_pool(t, kernel_size: 2, channels: :first)
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [0.051500000059604645],
      [1.6191999912261963],
      [0.7282999753952026]
    ]
  ]
>

Layers: Normalization

batch_norm(input, gamma, beta, ra_mean, ra_var, opts \\ [])

Functional implementation of batch normalization.

Normalizes the input by calculating mean and variance of the input tensor along every dimension but the given :channel_index, and then scaling according to:

$$y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta$$

gamma and beta are often trainable parameters. If training? is true, this method will compute a new mean and variance, and return the updated ra_mean and ra_var. Otherwise, it will just compute batch norm from the given ra_mean and ra_var.

Options

  • :epsilon - numerical stability term. $\epsilon$ in the above formulation.

  • :channel_index - channel index used to determine reduction axes for mean and variance calculation.

  • :momentum - momentum to use for EMA update.

  • :training? - if true, uses training mode batch norm. Defaults to false.
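The normalization formula above can be sketched for a single feature with plain lists. The hypothetical BatchNormSketch separates the two steps: batch_stats/1 computes the batch mean and biased (population) variance, which is what training mode would use, while normalize/6 applies the formula with whichever statistics are passed in, for example ra_mean and ra_var in inference mode. The choice of biased variance is an assumption, and the running-average update controlled by :momentum is not shown.

```elixir
defmodule BatchNormSketch do
  # Hypothetical module, not part of Axon.
  # y = (x - mean) / sqrt(var + epsilon) * gamma + beta for one feature.
  # In inference mode mean and var would be ra_mean and ra_var.
  def normalize(values, mean, var, gamma, beta, epsilon) do
    denom = :math.sqrt(var + epsilon)
    Enum.map(values, fn x -> (x - mean) / denom * gamma + beta end)
  end

  # Mean and biased (population) variance of one feature across the batch.
  def batch_stats(values) do
    n = length(values)
    mean = Enum.sum(values) / n
    var = Enum.sum(Enum.map(values, fn x -> (x - mean) * (x - mean) end)) / n
    {mean, var}
  end
end
```

Normalizing a batch with its own statistics, gamma = 1 and beta = 0 yields values with mean close to 0 and variance close to 1.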

group_norm(input, gamma, beta, opts \\ [])

Functional implementation of group normalization.

Normalizes the input by splitting the channel dimension into :num_groups groups and then calculating the mean and variance over the channels and spatial positions within each group, separately for every example in the batch.

$$y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta$$

gamma and beta are often trainable parameters. This method does not maintain an EMA of mean and variance.

Options

  • :num_groups - Number of groups.

  • :epsilon - numerical stability term. $\epsilon$ in the above formulation.

  • :channel_index - channel index used to determine reduction axes and group shape for mean and variance calculation.
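As an illustration of the grouping, the hypothetical GroupNormSketch below handles one channels-first example stored as [channels][positions]. It assumes channels are split into contiguous groups and that gamma and beta hold one value per channel and are applied after normalization; both are assumptions about layout rather than a description of Axon's internals.

```elixir
defmodule GroupNormSketch do
  # Hypothetical module, not part of Axon.
  # channels: one example laid out channels-first as [channels][positions].
  # gamma, beta: one value per channel. num_groups must divide the channel count.
  def group_norm(channels, gamma, beta, num_groups, epsilon) do
    group_size = div(length(channels), num_groups)

    normalized =
      channels
      |> Enum.chunk_every(group_size)
      |> Enum.flat_map(&normalize_group(&1, epsilon))

    # Per-channel scale and shift, applied after normalization.
    Enum.zip_with([normalized, gamma, beta], fn [channel, g, b] ->
      Enum.map(channel, &(&1 * g + b))
    end)
  end

  # Mean and (biased) variance are taken over every value in the group,
  # i.e. across all of its channels and spatial positions.
  defp normalize_group(group, epsilon) do
    values = List.flatten(group)
    n = length(values)
    mean = Enum.sum(values) / n
    var = Enum.sum(Enum.map(values, &((&1 - mean) * (&1 - mean)))) / n
    denom = :math.sqrt(var + epsilon)

    Enum.map(group, fn channel -> Enum.map(channel, &((&1 - mean) / denom)) end)
  end
end
```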

instance_norm(input, gamma, beta, ra_mean, ra_var, opts \\ [])

Functional implementation of instance normalization.

Normalizes the input by calculating mean and variance of the input tensor along the spatial dimensions of the input.

$$y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta$$

gamma and beta are often trainable parameters. If training? is true, this method will compute a new mean and variance, and return the updated ra_mean and ra_var. Otherwise, it will just compute instance norm from the given ra_mean and ra_var.

Options

  • :epsilon - numerical stability term. $\epsilon$ in the above formulation.

  • :channel_index - channel index used to determine reduction axes for mean and variance calculation.

  • :momentum - momentum to use for EMA update.

  • :training? - if true, uses training mode instance norm. Defaults to false.

layer_norm(input, gamma, beta, opts \\ [])

Functional implementation of layer normalization.

Normalizes the input by calculating mean and variance of the input tensor along the given feature dimension :channel_index.

$$y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta$$

gamma and beta are often trainable parameters. This method does not maintain an EMA of mean and variance.

Options

  • :epsilon - numerical stability term. $\epsilon$ in the above formulation.

  • :channel_index - channel index used to determine reduction axes for mean and variance calculation.

Layers: Shape

flatten(input, opts \\ [])

Flattens input to shape of {batch, units} by folding all non-batch dimensions into a single dimension.

Examples

iex> Axon.Layers.flatten(Nx.iota({1, 2, 2}, type: {:f, 32}))
#Nx.Tensor<
  f32[1][4]
  [
    [0.0, 1.0, 2.0, 3.0]
  ]
>

resize(input, opts \\ [])

Resizes a batch of tensors to the given shape using one of a number of sampling methods.

Requires the input option :size, which should be a tuple specifying the resized spatial dimensions of the input tensor. The input tensor must be at least rank 3, with fixed batch and channel dimensions. Resizing will upsample or downsample using the given resize method.

Supported resize methods are :nearest, :bilinear, :bicubic, :lanczos3, :lanczos5.

Examples

iex> img = Nx.iota({1, 1, 3, 3}, type: {:f, 32})
iex> Axon.Layers.resize(img, size: {4, 4}, channels: :first)
#Nx.Tensor<
  f32[1][1][4][4]
  [
    [
      [
        [0.0, 1.0, 1.0, 2.0],
        [3.0, 4.0, 4.0, 5.0],
        [3.0, 4.0, 4.0, 5.0],
        [6.0, 7.0, 7.0, 8.0]
      ]
    ]
  ]
>
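The output above is consistent with nearest-neighbor sampling where output index i reads input index floor((i + 0.5) * in_size / out_size), the half-pixel convention. The hypothetical ResizeSketch below applies that rule along one axis and then along the other; that Axon uses exactly this rule is an assumption inferred from the example, not a statement about its implementation.

```elixir
defmodule ResizeSketch do
  # Hypothetical module, not part of Axon.
  # Nearest-neighbor resize along one axis using the half-pixel rule
  # index = floor((i + 0.5) * in_size / out_size).
  def resize_nearest(values, out_size) do
    in_size = length(values)

    for i <- 0..(out_size - 1) do
      index = min(floor((i + 0.5) * in_size / out_size), in_size - 1)
      Enum.at(values, index)
    end
  end

  # Nearest-neighbor resizing is separable: resize each row, then the rows.
  def resize_nearest_2d(rows, {out_h, out_w}) do
    rows
    |> Enum.map(&resize_nearest(&1, out_w))
    |> resize_nearest(out_h)
  end
end
```

Applied to the 3x3 grid 0.0..8.0 with a target of {4, 4}, this reproduces the four rows shown above.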

Error cases

iex> img = Nx.iota({1, 1, 3, 3}, type: {:f, 32})
iex> Axon.Layers.resize(img, size: {4, 4}, method: :foo)
** (ArgumentError) expected :method to be either of :nearest, :bilinear, :bicubic, :lanczos3, :lanczos5, got: :foo

Functions: Convolutional

conv(input, kernel, bias \\ 0, opts \\ [])

Functional implementation of a general dimensional convolutional layer.

Convolutional layers can be described as applying a convolution over an input signal composed of several input planes. Intuitively, the kernel holds output_channels filters, each of which slides over the input tensor to extract features from it.

Convolutional layers are most commonly used in computer vision, but can also be useful when working with sequences and other input signals.

Parameter Shapes

  • input - {batch_size, input_channels, input_spatial0, ..., input_spatialN}
  • kernel - {output_channels, input_channels, kernel_spatial0, ..., kernel_spatialN}
  • bias - {} or {output_channels}

Options

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :input_dilation - input dilation factor. Equivalent to applying interior padding on the input. The amount of interior padding applied is given by input_dilation - 1. Defaults to 1 or no dilation.

  • :kernel_dilation - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by kernel_dilation - 1. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

Examples

One-dimensional convolution

iex> input = Nx.tensor([[[0.1294, -0.6638, 1.0251]], [[ 0.9182,  1.1512, -1.6149]]], type: {:f, 32})
iex> kernel = Nx.tensor([[[-1.5475, 1.2425]], [[0.1871, 0.5458]], [[-0.4488,  0.8879]]], type: {:f, 32})
iex> bias = Nx.tensor([0.7791, 0.1676, 1.5971], type: {:f, 32})
iex> Axon.Layers.conv(input, kernel, bias, channels: :first)
#Nx.Tensor<
  f32[2][3][2]
  [
    [
      [-0.24591797590255737, 3.08001708984375],
      [-0.1704912781715393, 0.6029025316238403],
      [0.9496372938156128, 2.80519962310791]
    ],
    [
      [0.7885514497756958, -3.0088953971862793],
      [0.9677201509475708, -0.4984228312969208],
      [2.207162380218506, -0.3534282445907593]
    ]
  ]
>

Two-dimensional convolution

iex> input = Nx.tensor([[[[-1.0476, -0.5041], [-0.9336, 1.5907]]]], type: {:f, 32})
iex> kernel = Nx.tensor([
...>  [[[0.7514, 0.7356], [1.3909,  0.6800]]],
...>  [[[-0.3450,  0.4551], [-0.6275, -0.9875]]],
...>  [[[1.8587, 0.4722], [0.6058, -1.0301]]]
...> ], type: {:f, 32})
iex> bias = Nx.tensor([1.9564, 0.2822, -0.5385], type: {:f, 32})
iex> Axon.Layers.conv(input, kernel, bias, channels: :first)
#Nx.Tensor<
  f32[1][3][1][1]
  [
    [
      [
        [0.5815491676330566]
      ],
      [
        [-0.5707762241363525]
      ],
      [
        [-4.927865028381348]
      ]
    ]
  ]
>

Three-dimensional convolution

iex> input = Nx.tensor([[[[[-0.6497], [1.0939]], [[-2.5465], [0.7801]]]]], type: {:f, 32})
iex> kernel = Nx.tensor([
...>  [[[[ 0.7390], [-0.0927]], [[-0.8675], [-0.9209]]]],
...>  [[[[-0.6638], [0.4341]], [[0.6368], [1.1846]]]]
...> ], type: {:f, 32})
iex> bias = Nx.tensor([-0.4101,  0.1776], type: {:f, 32})
iex> Axon.Layers.conv(input, kernel, bias, channels: :first)
#Nx.Tensor<
  f32[1][2][1][1][1]
  [
    [
      [
        [
          [0.49906185269355774]
        ]
      ],
      [
        [
          [0.38622811436653137]
        ]
      ]
    ]
  ]
>
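The one-dimensional example can be reproduced with a short sketch for a single channels-first example with :valid padding, stride 1 and no dilation. Like the example output, it computes cross-correlation, meaning the kernel is not flipped: the first output, 0.1294 * -1.5475 + -0.6638 * 1.2425 + 0.7791, is about -0.2459. Conv1dSketch is a hypothetical name and uses 64-bit floats, so it agrees with the f32 output only to a few decimal places.

```elixir
defmodule Conv1dSketch do
  # Hypothetical module, not part of Axon. One example, channels-first,
  # :valid padding, stride 1, no dilation.
  # input: [in_channels][length]; kernel: [out_channels][in_channels][kernel_size];
  # bias: one value per output channel.
  def conv(input, kernel, bias) do
    kernel_size = kernel |> hd() |> hd() |> length()
    out_length = (input |> hd() |> length()) - kernel_size + 1

    Enum.zip_with(kernel, bias, fn filter, b ->
      for pos <- 0..(out_length - 1) do
        # Sum the windowed dot products over all input channels, then add bias.
        filter
        |> Enum.zip(input)
        |> Enum.map(fn {weights, channel} ->
          channel |> Enum.slice(pos, kernel_size) |> dot(weights)
        end)
        |> Enum.sum()
        |> Kernel.+(b)
      end
    end)
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
end
```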

conv_transpose(input, kernel, bias \\ 0, opts \\ [])

Functional implementation of a general dimensional transposed convolutional layer.

Note: This layer is currently implemented as a fractionally strided convolution by padding the input tensor. Please open an issue if you'd like this behavior changed.

Transposed convolutions are sometimes (incorrectly) referred to as deconvolutions because they "reverse" the spatial dimensions of a normal convolution. Transposed convolutions are a form of upsampling: they produce larger spatial dimensions than the input tensor. They can be thought of as a convolution in reverse, and are sometimes implemented as the backward pass of a normal convolution.

Options

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :input_dilation - input dilation factor. Equivalent to applying interior padding on the input. The amount of interior padding applied is given by input_dilation - 1. Defaults to 1 or no dilation.

  • :kernel_dilation - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by kernel_dilation - 1. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

Examples

iex> input = Nx.iota({1, 3, 3}, type: {:f, 32})
iex> kernel = Nx.iota({6, 3, 2}, type: {:f, 32})
iex> bias = Nx.tensor(1.0, type: {:f, 32})
iex> Axon.Layers.conv_transpose(input, kernel, bias, channels: :first)
#Nx.Tensor<
  f32[1][6][4]
  [
    [
      [40.0, 79.0, 94.0, 43.0],
      [94.0, 205.0, 256.0, 133.0],
      [148.0, 331.0, 418.0, 223.0],
      [202.0, 457.0, 580.0, 313.0],
      [256.0, 583.0, 742.0, 403.0],
      [310.0, 709.0, 904.0, 493.0]
    ]
  ]
>

depthwise_conv(inputs, kernel, bias \\ 0, opts \\ [])

Functional implementation of a general dimensional depthwise convolution.

Depthwise convolutions apply a single convolutional filter to each input channel. This is done by setting feature_group_size equal to the number of input channels, which splits the output_channels into input_channels groups and convolves each group of kernel channels over its corresponding input channel.

Parameter Shapes

  • input - {batch_size, input_channels, input_spatial0, ..., input_spatialN}
  • kernel - {output_channels, 1, kernel_spatial0, ..., kernel_spatialN}
  • bias - {output_channels} or {}

output_channels must be a multiple of the input channels.

Options

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :input_dilation - input dilation factor. Equivalent to applying interior padding on the input. The amount of interior padding applied is given by input_dilation - 1. Defaults to 1 or no dilation.

  • :kernel_dilation - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by kernel_dilation - 1. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

separable_conv2d(input, k1, b1, k2, b2, opts \\ [])

Functional implementation of a 2-dimensional separable depthwise convolution.

The 2-d depthwise separable convolution performs two depthwise convolutions, each over one spatial dimension of the input.

Parameter Shapes

  • input - {batch_size, input_channels, input_spatial0, input_spatial1}
  • k1 - {output_channels, 1, kernel_spatial0, 1}
  • b1 - {output_channels} or {}
  • k2 - {output_channels, 1, 1, kernel_spatial1}
  • b2 - {output_channels} or {}

output_channels must be a multiple of input_channels.
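
Why splitting the kernel works can be checked on a single channel. The sketch below, assuming no bias, stride 1, and :valid padding, compares a kernel_spatial0 x 1 pass followed by a 1 x kernel_spatial1 pass against a single pass with their outer-product (rank-1) kernel. SeparableSketch and its functions are hypothetical helpers, not Axon APIs.

```elixir
defmodule SeparableSketch do
  # Hypothetical helpers (not part of Axon). Single channel, no bias,
  # stride 1, :valid padding. Images and kernels are lists of rows.
  def correlate2d(image, kernel) do
    kh = length(kernel)
    kw = length(hd(kernel))
    oh = length(image) - kh + 1
    ow = length(hd(image)) - kw + 1

    for i <- 0..(oh - 1) do
      for j <- 0..(ow - 1) do
        for {krow, di} <- Enum.with_index(kernel),
            {w, dj} <- Enum.with_index(krow),
            reduce: 0 do
          acc -> acc + w * (image |> Enum.at(i + di) |> Enum.at(j + dj))
        end
      end
    end
  end

  # Outer product of a column vector and a row vector.
  def outer(col, row), do: for(a <- col, do: for(b <- row, do: a * b))
end

image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
col = [1, 2]
row = [1, 2]

# kernel_spatial0 x 1 pass, then 1 x kernel_spatial1 pass.
two_pass =
  image
  |> SeparableSketch.correlate2d(Enum.map(col, &[&1]))
  |> SeparableSketch.correlate2d([row])

# One pass with the outer-product (rank-1) kernel.
one_pass = SeparableSketch.correlate2d(image, SeparableSketch.outer(col, row))

{two_pass, one_pass}
# => {[[33, 42], [60, 69]], [[33, 42], [60, 69]]}
```

With nonzero biases the two results differ by a constant offset, since the first bias is carried through the second pass.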

Options

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :input_dilation - input dilation factor. Equivalent to applying interior padding on the input. The amount of interior padding applied is given by input_dilation - 1. Defaults to 1 or no dilation.

  • :kernel_dilation - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by kernel_dilation - 1. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

separable_conv3d(input, k1, b1, k2, b2, k3, b3, opts \\ [])

Functional implementation of a 3-dimensional separable depthwise convolution.

The 3-d depthwise separable convolution performs three depthwise convolutions, each over one spatial dimension of the input.

Parameter Shapes

  • input - {batch_size, input_channels, input_spatial0, input_spatial1, input_spatial2}
  • k1 - {output_channels, 1, kernel_spatial0, 1, 1}
  • b1 - {output_channels} or {}
  • k2 - {output_channels, 1, 1, kernel_spatial1, 1}
  • b2 - {output_channels} or {}
  • k3 - {output_channels, 1, 1, 1, kernel_spatial2}
  • b3 - {output_channels} or {}

output_channels must be a multiple of input_channels.
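
One practical consequence of this factorization is a smaller kernel. The sketch below counts kernel weights (biases excluded) for a full 3-D depthwise kernel of shape {output_channels, 1, kernel_spatial0, kernel_spatial1, kernel_spatial2} versus k1, k2, and k3 together. SeparableParams.kernel_weights is a hypothetical helper, not part of Axon.

```elixir
defmodule SeparableParams do
  # Hypothetical helper (not part of Axon): kernel weight counts, biases
  # excluded, for one full 3-D depthwise kernel versus the three factorized
  # kernels k1, k2, k3 whose shapes are listed above.
  def kernel_weights(output_channels, {ks0, ks1, ks2}) do
    %{
      full: output_channels * ks0 * ks1 * ks2,
      separable: output_channels * (ks0 + ks1 + ks2)
    }
  end
end

SeparableParams.kernel_weights(8, {3, 3, 3})
# => %{full: 216, separable: 72}
```

The full kernel grows with the product of the spatial sizes, while the factorized kernels grow with their sum, at the cost of only representing kernels that factor this way.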

Options

  • :strides - kernel strides. Can be a scalar or a list whose length matches the number of spatial dimensions in the input tensor. Defaults to 1.

  • :padding - zero padding on the input. Can be one of :valid, :same or a general padding configuration without interior padding for each spatial dimension of the input.

  • :input_dilation - input dilation factor. Equivalent to applying interior padding on the input. The amount of interior padding applied is given by input_dilation - 1. Defaults to 1 or no dilation.

  • :kernel_dilation - kernel dilation factor. Equivalent to applying interior padding on the kernel. The amount of interior padding applied is given by kernel_dilation - 1. Defaults to 1 or no dilation.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

Functions

conv_lstm(input, hidden_state, input_kernel, hidden_kernel, bias \\ [], opts \\ [])

conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ [])

ConvLSTM Cell.

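When combined with Axon.Layers.*_unroll, implements a ConvLSTM-based RNN. Because it replaces the LSTM's dense matrix products with convolutions, it typically needs far fewer parameters than a fully connected LSTM on spatial inputs.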

Options

  • :strides - convolution strides. Defaults to 1.

  • :padding - convolution padding. Defaults to :same.

dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)

Dynamically unrolls an RNN.

Unrolls implement a scan operation: they apply a transformation along the leading axis of input_sequence while carrying some state. In this instance, cell_fn is an RNN cell function such as lstm_cell or gru_cell.

This function uses a defn while loop and thus may be more efficient for long sequences.
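
The scan semantics can be sketched with Enum.map_reduce/3. This is a hypothetical plain-Elixir illustration, not Axon's defn implementation: it assumes the cell takes (input, carry) and returns {new_carry, output}, omits the input kernel, recurrent kernel, and bias that Axon passes to cell_fn, and returns {final_carry, outputs}.

```elixir
defmodule UnrollSketch do
  # Hypothetical sketch (not Axon's implementation): scans over the time steps
  # of `sequence`, threading `carry` through `cell_fn`. The cell is assumed to
  # return {new_carry, output}; kernels and bias are omitted.
  def unroll(cell_fn, sequence, carry) do
    {outputs, final_carry} =
      Enum.map_reduce(sequence, carry, fn x, acc ->
        {new_acc, output} = cell_fn.(x, acc)
        {output, new_acc}
      end)

    {final_carry, outputs}
  end
end

# Toy cell: the carry is a running sum, the output is twice the carry.
cell = fn x, h ->
  h2 = h + x
  {h2, 2 * h2}
end

UnrollSketch.unroll(cell, [1, 2, 3, 4], 0)
# => {10, [2, 6, 12, 20]}
```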

gru(input, hidden_state, input_kernel, hidden_kernel, bias \\ [], opts \\ [])

gru_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &Axon.Activations.sigmoid/1, activation_fn \\ &Axon.Activations.tanh/1)

GRU Cell.

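When combined with Axon.Layers.*_unroll, implements a GRU-based RNN. A GRU computes three weight blocks (reset gate, update gate, candidate) instead of the LSTM's four, so it is typically more memory efficient than a traditional LSTM.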
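
For reference, one common GRU formulation, in which the reset gate scales the recurrent term after its weight product, can be written as a scalar step in plain Elixir. This is a hypothetical sketch with illustrative weight names, not Axon's gru_cell; the optional gate_fn and activation_fn arguments mirror the defaults in the gru_cell signature (sigmoid and tanh).

```elixir
defmodule GRUSketch do
  # Hypothetical scalar GRU step (not Axon's gru_cell). Weights are a map of
  # plain numbers with illustrative names.
  def sigmoid(v), do: 1.0 / (1.0 + :math.exp(-v))

  def step(x, h, w, gate_fn \\ &sigmoid/1, activation_fn \\ &:math.tanh/1) do
    # Reset and update gates.
    r = gate_fn.(w.wir * x + w.whr * h + w.br)
    z = gate_fn.(w.wiz * x + w.whz * h + w.bz)
    # Candidate state: the reset gate scales the recurrent contribution.
    n = activation_fn.(w.win * x + w.bin + r * (w.whn * h + w.bhn))
    # Interpolate between the candidate and the previous hidden state.
    (1.0 - z) * n + z * h
  end
end

zeros = Map.new([:wir, :whr, :br, :wiz, :whz, :bz, :win, :bin, :whn, :bhn], &{&1, 0.0})
GRUSketch.step(1.0, 2.0, zeros)
```

With every weight set to zero, both gates evaluate to 0.5 and the candidate to 0.0, so the step halves the previous hidden state and returns 1.0.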

hard_sigmoid(input, opts \\ [])

hard_silu(input, opts \\ [])

leaky_relu(input, opts \\ [])

log_softmax(input, opts \\ [])

log_sumexp(input, opts \\ [])

lstm(input, hidden_state, input_kernel, hidden_kernel, bias \\ [], opts \\ [])

lstm_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &Axon.Activations.sigmoid/1, activation_fn \\ &Axon.Activations.tanh/1)

LSTM Cell.

When combined with Axon.Layers.*_unroll, implements an LSTM-based RNN.
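
The standard LSTM update can likewise be sketched as a scalar step in plain Elixir. This is a hypothetical illustration with made-up weight names, not Axon's lstm_cell; it assumes the carry is {c, h} and that the step returns {new_carry, output}, with gate_fn and activation_fn defaulting to sigmoid and tanh as in the lstm_cell signature.

```elixir
defmodule LSTMSketch do
  # Hypothetical scalar LSTM step (not Axon's lstm_cell). The carry is {c, h};
  # weights are a map of plain numbers with illustrative names.
  def sigmoid(v), do: 1.0 / (1.0 + :math.exp(-v))

  def step(x, {c, h}, w, gate_fn \\ &sigmoid/1, activation_fn \\ &:math.tanh/1) do
    i = gate_fn.(w.wii * x + w.whi * h + w.bi)        # input gate
    f = gate_fn.(w.wif * x + w.whf * h + w.bf)        # forget gate
    g = activation_fn.(w.wig * x + w.whg * h + w.bg)  # candidate cell value
    o = gate_fn.(w.wio * x + w.who * h + w.bo)        # output gate
    new_c = f * c + i * g
    new_h = o * activation_fn.(new_c)
    {{new_c, new_h}, new_h}
  end
end

zeros =
  Map.new(
    [:wii, :whi, :bi, :wif, :whf, :bf, :wig, :whg, :bg, :wio, :who, :bo],
    &{&1, 0.0}
  )

LSTMSketch.step(1.0, {2.0, 0.0}, zeros)
```

With all weights zero, every gate is 0.5 and the candidate is 0.0, so this call yields a new cell state of 1.0 and a hidden state (and output) of 0.5 * :math.tanh(1.0).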

multiply(inputs, opts \\ [])

padding_config_transform(config, channels)

softmax(input, opts \\ [])

static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)

Statically unrolls an RNN.

Unrolls implement a scan operation: they apply a transformation along the leading axis of input_sequence while carrying some state. In this instance, cell_fn is an RNN cell function such as lstm_cell or gru_cell.

This function inlines every step of the unrolled sequence, so the entire operation appears as part of the compilation graph. Since the graph grows with the sequence length, this makes it better suited to shorter sequences.
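
As a worked illustration in generic notation (cell stands for cell_fn with its kernels and bias fixed, and $h_t$ for the carry), a three-step static unroll writes every cell application out explicitly, so the compiled graph contains three copies of the cell:

$$(h_1, y_1) = \mathrm{cell}(x_1, h_0), \quad (h_2, y_2) = \mathrm{cell}(x_2, h_1), \quad (h_3, y_3) = \mathrm{cell}(x_3, h_2)$$

A while-loop unroll instead compiles a single copy of the cell and executes it once per step.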

subtract(inputs, opts \\ [])
