TabNet - Attentive Interpretable Tabular Learning.
TabNet uses sequential attention to select which features to reason about at each decision step. This provides instance-wise feature selection, making the model inherently interpretable while maintaining high performance on tabular data.
Architecture
       Input [batch, input_size]
                   |
                   v
+--------------------------------------+
|              Initial BN              |
+--------------------------------------+
                   |
                   v
+--------------------------------------+
| Step 1:                              |
|   Attention: select features via     |
|   sparse mask M = sparsemax(...)     |
|   Transform: process selected feats  |
|   Split: h_step -> decision + next   |
+--------------------------------------+
                   | (repeat num_steps)
                   v
+--------------------------------------+
| Aggregate: sum decision outputs      |
+--------------------------------------+
                   |
                   v
Output [batch, hidden_size or num_classes]
Feature Selection
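At each step, TabNet uses an attentive transformer to produce a sparse mask that selects the relevant input features. The relaxation factor gamma (the :relaxation_factor option) controls how freely features attended to in earlier steps can be selected again: with gamma = 1 a feature can be used in at most one step, and larger values allow more reuse.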
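The mask and the feature-reuse bookkeeping can be written out as follows. This follows the formulation in the referenced paper (Arik & Pfister); whether this module implements it term for term is an assumption. Here M[i] is the mask at step i, P[i] is the prior scale, a[i-1] is the processed feature vector passed on from the previous step, h_i is a learned mapping, and gamma is the relaxation factor.

```latex
\begin{aligned}
M[i] &= \operatorname{sparsemax}\bigl(P[i-1] \odot h_i(a[i-1])\bigr), \\
P[i] &= \prod_{j=1}^{i} \bigl(\gamma - M[j]\bigr), \qquad P[0] = \mathbf{1}.
\end{aligned}
```

As a worked example with four features and gamma = 1.5: if the first mask is M[1] = (0.7, 0.3, 0, 0), then P[1] = (0.8, 1.2, 1.5, 1.5), so the heavily used first feature is down-weighted at the next step while unused features are favored. With gamma = 1 the same mask gives P[1] = (0.3, 0.7, 1, 1), and a feature that received the full mask weight of 1 would get a prior of 0 and could not be selected again.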
Usage
model = TabNet.build(
  input_size: 128,
  hidden_size: 64,
  num_steps: 3,
  relaxation_factor: 1.5,
  num_classes: 10
)
References
- Arik & Pfister, "TabNet: Attentive Interpretable Tabular Learning" (AAAI 2021)
- https://arxiv.org/abs/1908.07442
Summary
Types
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:input_size, pos_integer()}
  | {:num_classes, pos_integer() | nil}
  | {:num_steps, pos_integer()}
  | {:relaxation_factor, float()}
Options for build/1.
Functions
Build a TabNet model.
Options
- :input_size - Input feature dimension (required)
- :hidden_size - Hidden dimension for processing (default: 64)
- :num_steps - Number of sequential attention steps (default: 3)
- :relaxation_factor - Controls feature reuse across steps (default: 1.5)
- :num_classes - If provided, adds classification head (default: nil)
- :dropout - Dropout rate (default: 0.0)
Returns
An Axon model: [batch, input_size] -> [batch, hidden_size or num_classes]
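The defaults listed under Options could be applied along the following lines. This is a minimal sketch, not the library's code: the module TabNetOptionsSketch and its normalize/1 function are hypothetical names introduced for illustration, and Keyword.validate!/2 requires Elixir 1.13 or later.

```elixir
defmodule TabNetOptionsSketch do
  @moduledoc "Hypothetical helper; illustrates the documented defaults only."

  # Defaults as documented in the Options list.
  @defaults [
    hidden_size: 64,
    num_steps: 3,
    relaxation_factor: 1.5,
    num_classes: nil,
    dropout: 0.0
  ]

  @spec normalize(keyword()) :: keyword()
  def normalize(opts) do
    # Rejects unknown keys and fills in defaults for missing ones.
    opts = Keyword.validate!(opts, [:input_size | @defaults])

    # :input_size has no default, so it must be given explicitly.
    _ = Keyword.fetch!(opts, :input_size)

    opts
  end
end

opts = TabNetOptionsSketch.normalize(input_size: 128, num_steps: 5)
IO.inspect(opts[:num_steps])   # 5
IO.inspect(opts[:hidden_size]) # 64
```

Passing an unknown key raises ArgumentError, and omitting :input_size raises KeyError, which makes misspelled options fail early rather than being silently ignored.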
@spec output_size(keyword()) :: pos_integer()
Get the output size of a TabNet model.
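Based on the documented output shape ([batch, hidden_size or num_classes]), the rule for the output size can be sketched as below. The module name TabNetSketch is hypothetical and this is a stand-in, not the library's implementation; it assumes that :num_classes takes precedence when set and that :hidden_size falls back to its documented default of 64.

```elixir
defmodule TabNetSketch do
  @moduledoc "Hypothetical stand-in; mirrors the documented output shape."

  @default_hidden_size 64

  @spec output_size(keyword()) :: pos_integer()
  def output_size(opts \\ []) do
    case Keyword.get(opts, :num_classes) do
      # No classification head: the output width is the hidden size.
      nil -> Keyword.get(opts, :hidden_size, @default_hidden_size)
      # Classification head present: one output per class.
      num_classes -> num_classes
    end
  end
end

IO.inspect(TabNetSketch.output_size(hidden_size: 32))                  # 32
IO.inspect(TabNetSketch.output_size(hidden_size: 32, num_classes: 10)) # 10
```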