Nasty.Semantic.Coreference.Neural.SpanEnumeration (Nasty v0.3.0)
Span enumeration and pruning for end-to-end coreference resolution.
Generates all possible spans up to a maximum length, scores them, and prunes to the top-K candidates. This is the first stage of the span-based end-to-end model.
Workflow
- Enumerate all spans up to max_length
- Compute span representations from LSTM states
- Score spans using a feedforward network
- Keep only top-K highest scoring spans
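To see why the pruning step matters, it helps to count the candidates produced by the enumeration step. The symbols here are notation introduced for this note, not names from the library: n is the sequence length in tokens, L is max_length, and w is a span width.

```latex
\[
N(n, L) \;=\; \sum_{w=1}^{L} (n - w + 1) \;=\; nL - \frac{L(L-1)}{2},
\qquad 1 \le L \le n
\]
```

For example, n = 100 and L = 10 give 100 * 10 - 45 = 955 candidate spans, of which the default top_k of 50 would be kept.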
Example
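alias Nasty.Semantic.Coreference.Neural.SpanEnumeration

# Encode document with BiLSTM (encode_document/1 is not part of this module)
lstm_outputs = encode_document(doc)

# Enumerate and score spans, keeping the 50 best candidates
{:ok, spans} =
  SpanEnumeration.enumerate_and_prune(
    lstm_outputs,
    max_length: 10,
    top_k: 50
  )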
Types
@type span() :: %{
  start_idx: non_neg_integer(),
  end_idx: non_neg_integer(),
  score: float(),
  representation: Nx.Tensor.t()
}
Functions
Build Axon model for span scoring.
Parameters
- opts - Model options
Options
- :hidden_dim - LSTM hidden dimension (default: 256)
- :width_emb_dim - Width embedding dimension (default: 20)
- :scorer_hidden - Scorer hidden layers (default: [256, 128])
- :dropout - Dropout rate (default: 0.3)
Returns
- Axon model
@spec enumerate_and_prune(Nx.Tensor.t(), keyword()) :: {:ok, [span()]}
Enumerate all possible spans and prune to top-K.
Parameters
- lstm_outputs - LSTM hidden states [seq_len, hidden_dim]
- opts - Options
Options
- :max_length - Maximum span length in tokens (default: 10)
- :top_k - Number of spans to keep per sentence (default: 50)
- :scorer_model - Trained span scorer model (optional)
- :scorer_params - Scorer parameters (optional)
Returns
- {:ok, spans} - List of top-K scored spans
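The pruning step can be pictured with a small sketch in plain Elixir. This is not the library's implementation: the PruneSketch module and its top_k/2 function are hypothetical, and the spans are simplified maps that carry the start_idx, end_idx and score fields of the span() type but omit the representation tensor.

```elixir
defmodule PruneSketch do
  # Hypothetical helper, not part of Nasty: keep the k highest-scoring spans.
  # Each span is a map with at least a :score key.
  def top_k(spans, k) when is_list(spans) and is_integer(k) and k >= 0 do
    spans
    |> Enum.sort_by(& &1.score, :desc)
    |> Enum.take(k)
  end
end

# Made-up candidates for illustration only.
candidates = [
  %{start_idx: 0, end_idx: 0, score: 0.1},
  %{start_idx: 0, end_idx: 1, score: 0.9},
  %{start_idx: 1, end_idx: 1, score: 0.4}
]

kept = PruneSketch.top_k(candidates, 2)
```

Sorting the full candidate list is the simplest way to express the idea; a tensor-based implementation could just as well select the top scores directly on the score tensor.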
@spec enumerate_spans(Nx.Tensor.t(), pos_integer()) :: [{non_neg_integer(), non_neg_integer()}]
Enumerate all spans up to max_length.
Returns a list of span index pairs: [{start, end}, ...]
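A minimal sketch of the enumeration itself, under stated assumptions: the SpanEnumSketch module is hypothetical, it takes a plain sequence length instead of an Nx tensor, and it treats the end index as inclusive (as the span_representation docs do for end_idx). The real function may order or bound its output differently.

```elixir
defmodule SpanEnumSketch do
  # Hypothetical stand-in for enumerate_spans/2, not the library's code.
  # Returns every {start, end} pair (end inclusive) whose width is at most
  # max_length and that fits inside a sequence of seq_len tokens.
  def enumerate_spans(seq_len, max_length)
      when is_integer(seq_len) and seq_len >= 0 and
             is_integer(max_length) and max_length > 0 do
    for start_idx <- 0..(seq_len - 1)//1,
        end_idx <- start_idx..min(start_idx + max_length - 1, seq_len - 1)//1 do
      {start_idx, end_idx}
    end
  end
end

spans = SpanEnumSketch.enumerate_spans(4, 2)
# spans == [{0, 0}, {0, 1}, {1, 1}, {1, 2}, {2, 2}, {2, 3}, {3, 3}]
```

The explicit `//1` step keeps the ranges empty, rather than descending, when seq_len is 0.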
@spec span_representation(
        Nx.Tensor.t(),
        non_neg_integer(),
        non_neg_integer(),
        Nx.Tensor.t() | nil
      ) :: Nx.Tensor.t()
Compute span representation from LSTM states.
Representation is concatenation of:
- Start state
- End state
- Attention-weighted average over span
- Span width embedding
Parameters
- lstm_outputs - LSTM hidden states [seq_len, hidden_dim]
- start_idx - Start index
- end_idx - End index (inclusive)
- width_embeddings - Optional width embedding tensor [max_width, width_dim]
Returns
- Span representation tensor [span_dim]
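The concatenation layout can be illustrated without Nx. The sketch below is an illustration only, with several assumptions: the SpanRepSketch module is hypothetical, plain lists of floats stand in for tensors, a uniform average stands in for the attention-weighted average (the attention weights are learned and not described here), and the width embedding row is assumed to be indexed by width minus one.

```elixir
defmodule SpanRepSketch do
  # Hypothetical illustration, not the library's code.
  # states: list of per-token vectors, standing in for [seq_len, hidden_dim].
  # width_embeddings: list of vectors (row i = width i + 1, an assumption) or nil.
  def span_representation(states, start_idx, end_idx, width_embeddings \\ nil)
      when start_idx >= 0 and start_idx <= end_idx do
    span_states = Enum.slice(states, start_idx..end_idx)
    start_state = List.first(span_states)
    end_state = List.last(span_states)

    # Uniform mean per dimension: a placeholder for the attention-weighted average.
    n = length(span_states)
    average = Enum.zip_with(span_states, fn column -> Enum.sum(column) / n end)

    width_vec =
      case width_embeddings do
        nil -> []
        table -> Enum.fetch!(table, end_idx - start_idx)
      end

    start_state ++ end_state ++ average ++ width_vec
  end
end

states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
rep = SpanRepSketch.span_representation(states, 0, 2)
# rep == [1.0, 2.0, 5.0, 6.0, 3.0, 4.0]
```

Under these assumptions the output length is three times the state dimension plus the width embedding dimension, or three times the state dimension when no width embeddings are passed.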