Axon.Recurrent (Axon v0.1.0)

Functional implementations of common recurrent neural network routines.

Recurrent Neural Networks are commonly used for working with sequences of data where there is some level of dependence between outputs at different timesteps.

This module contains three RNN cell functions, along with functions to "unroll" a cell over an entire sequence. Each cell function returns a tuple:

{new_carry, output}

where new_carry is the updated carry state and output is the output for a single timestep. To apply an RNN across multiple timesteps, use either static_unroll or dynamic_unroll.

Unrolling an RNN is equivalent to a map_reduce or scan starting from an initial carry state and ending with a final carry state and an output sequence.
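The scan described above can be sketched in plain Elixir with a hypothetical helper and a toy cell function (this is illustrative only, not Axon's implementation, which operates on Nx tensors inside defn):

```elixir
defmodule UnrollSketch do
  # Unrolling as a reduce over timesteps: thread the carry through each
  # step and collect the per-timestep outputs.
  def unroll(cell_fn, inputs, carry) do
    {final_carry, outputs} =
      Enum.reduce(inputs, {carry, []}, fn input, {carry, outputs} ->
        {new_carry, output} = cell_fn.(input, carry)
        {new_carry, [output | outputs]}
      end)

    {final_carry, Enum.reverse(outputs)}
  end
end
```

For example, a running-sum "cell" such as `fn x, acc -> {acc + x, acc + x} end` applied to `[1, 2, 3]` with an initial carry of `0` yields `{6, [1, 3, 6]}`: a final carry state and an output sequence, exactly the shape of result the unroll functions below produce.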

All of the functions in this module are implemented as numerical functions and can be JIT or AOT compiled with any supported Nx compiler.

Functions

conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ [])

ConvLSTM Cell.

When combined with Axon.Recurrent.*_unroll, implements a ConvLSTM-based RNN, which replaces the dense transformations of a standard LSTM with convolutions, making it well suited to spatiotemporal inputs.

Options

  • :strides - convolution strides. Defaults to 1.

  • :padding - convolution padding. Defaults to :same.
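A single-timestep call might look like the following sketch; the parameter names and shapes here are illustrative assumptions, not verified against Axon v0.1.0's exact expectations:

```elixir
# Hedged sketch: one ConvLSTM step. `input` is one timestep of
# spatiotemporal data; `carry` is the recurrent state tuple; the kernel
# and bias parameters are assumed to be initialized elsewhere.
{new_carry, output} =
  Axon.Recurrent.conv_lstm_cell(
    input,
    carry,
    input_kernel,
    hidden_kernel,
    bias,
    strides: 1,
    padding: :same
  )
```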


dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)

Dynamically unrolls an RNN.

Unrolling implements a scan operation, which applies a transformation along the leading axis of input_sequence while carrying some state. Here, cell_fn is an RNN cell function such as lstm_cell or gru_cell.

This function uses a defn while-loop, and thus may be more efficient for long sequences.
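For example, dynamically unrolling an LSTM over a full sequence might look like this sketch (the parameter tuples are assumed to be initialized elsewhere, and the capture relies on lstm_cell's default gate and activation functions):

```elixir
# Hedged sketch: scan lstm_cell over the leading (time) axis of
# `input_sequence`, threading `initial_carry` through each step.
{final_carry, output_sequence} =
  Axon.Recurrent.dynamic_unroll(
    &Axon.Recurrent.lstm_cell/5,
    input_sequence,
    initial_carry,
    input_kernel,
    recurrent_kernel,
    bias
  )
```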

gru_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &sigmoid/1, activation_fn \\ &tanh/1)

GRU Cell.

When combined with Axon.Recurrent.*_unroll, implements a GRU-based RNN, which is more memory efficient than a traditional LSTM because it uses fewer gates and maintains a single hidden state rather than separate cell and hidden states.


lstm_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &sigmoid/1, activation_fn \\ &tanh/1)

LSTM Cell.

When combined with Axon.Recurrent.*_unroll, implements an LSTM-based RNN.
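The gate_fn and activation_fn parameters make the cell's nonlinearities pluggable. A single-step call overriding the defaults might look like this sketch (parameter shapes are illustrative assumptions; the activation functions shown do exist in Axon.Activations):

```elixir
# Hedged sketch: one LSTM step with a custom gate function. `carry` is
# assumed to be the {cell_state, hidden_state} tuple threaded between
# timesteps.
{new_carry, output} =
  Axon.Recurrent.lstm_cell(
    input,
    carry,
    input_kernel,
    hidden_kernel,
    bias,
    &Axon.Activations.hard_sigmoid/1,
    &Axon.Activations.tanh/1
  )
```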


static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)

Statically unrolls an RNN.

Unrolling implements a scan operation, which applies a transformation along the leading axis of input_sequence while carrying some state. Here, cell_fn is an RNN cell function such as lstm_cell or gru_cell.

This function inlines the unrolling of the sequence, so the entire operation appears as part of the compilation graph. This makes it suitable for shorter sequences.
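Static unrolling takes the same arguments as dynamic_unroll, so the two are interchangeable at the call site; only the compilation strategy differs. A sketch with gru_cell (parameter tuples assumed to be initialized elsewhere):

```elixir
# Hedged sketch: each timestep is inlined into the compilation graph,
# trading compile time and graph size for potentially faster execution
# on short sequences.
{final_carry, output_sequence} =
  Axon.Recurrent.static_unroll(
    &Axon.Recurrent.gru_cell/5,
    input_sequence,
    initial_carry,
    input_kernel,
    recurrent_kernel,
    bias
  )
```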