Nasty.Semantic.Coreference.Neural.SpanEnumeration (Nasty v0.3.0)

Span enumeration and pruning for end-to-end coreference resolution.

Generates every candidate span up to a maximum length, scores each one, and keeps only the top-K candidates. This is the first stage of the span-based end-to-end model.

Workflow

  1. Enumerate all spans up to max_length
  2. Compute span representations from the LSTM states
  3. Score the spans with a feedforward network
  4. Keep only the top-K highest-scoring spans
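
The workflow can be sketched on plain Elixir lists, without Nx or Axon. This is only an illustration under stated assumptions, not the module's implementation: PruneSketch, run/4, and score_fun are hypothetical names, the scoring function stands in for the learned feedforward scorer, and step 2 (span representations) is skipped because the toy scorer only looks at the indices.

```elixir
defmodule PruneSketch do
  # Hypothetical helper: mirrors steps 1, 3 and 4 of the workflow on plain lists.
  # score_fun is a stand-in for the learned scorer (an assumption, not Nasty's API).
  def run(seq_len, max_length, top_k, score_fun)
      when seq_len > 0 and max_length > 0 do
    # Step 1: every {start, end} pair with at most max_length tokens (end inclusive)
    candidates =
      for start_idx <- 0..(seq_len - 1),
          end_idx <- start_idx..min(start_idx + max_length - 1, seq_len - 1) do
        {start_idx, end_idx}
      end

    candidates
    # Step 3: attach a score to each candidate
    |> Enum.map(fn {s, e} -> %{start_idx: s, end_idx: e, score: score_fun.(s, e)} end)
    # Step 4: sort by descending score and keep the best top_k
    |> Enum.sort_by(& &1.score, :desc)
    |> Enum.take(top_k)
  end
end

# Toy scorer that simply prefers spans near the end of the sequence
kept = PruneSketch.run(3, 2, 2, fn s, e -> (s + e) * 1.0 end)
IO.inspect(kept)
```

With a sequence of 3 tokens and max_length 2 there are 5 candidates, of which the two best-scoring are kept.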

Example

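alias Nasty.Semantic.Coreference.Neural.SpanEnumeration

# Encode the document with a BiLSTM. encode_document/1 stands in for the
# encoding step; it is not among the functions documented on this page.
lstm_outputs = encode_document(doc)

# Enumerate candidate spans, score them, and keep the 50 best
{:ok, spans} = SpanEnumeration.enumerate_and_prune(
  lstm_outputs,
  max_length: 10,
  top_k: 50
)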

Summary

Functions

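build_span_scorer(opts \\ [])

Build an Axon model for span scoring.

enumerate_and_prune(lstm_outputs, opts \\ [])

Enumerate all possible spans and prune to top-K.

enumerate_spans(lstm_outputs, max_length)

Enumerate all spans up to max_length.

span_representation(lstm_outputs, start_idx, end_idx, width_embeddings \\ nil)

Compute span representation from LSTM states.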

Types

span()

@type span() :: %{
  start_idx: non_neg_integer(),
  end_idx: non_neg_integer(),
  score: float(),
  representation: Nx.Tensor.t()
}

Functions

build_span_scorer(opts \\ [])

Build an Axon model for span scoring.

Parameters

  • opts - Keyword list of model options (see Options)

Options

  • :hidden_dim - LSTM hidden dimension (default: 256)
  • :width_emb_dim - Width embedding dimension (default: 20)
  • :scorer_hidden - Sizes of the scorer's hidden layers (default: [256, 128])
  • :dropout - Dropout rate (default: 0.3)

Returns

  • Axon model
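
For orientation, a feedforward scorer with these options could look roughly like the sketch below. It is not the code behind build_span_scorer/1. Several details are assumptions: the module name SpanScorerSketch, the input name "span_representation", the ReLU activations, the single-unit output layer, and the input width of 3 * hidden_dim + width_emb_dim (inferred from the four concatenated parts of the span representation). Mix.install is used only so the snippet runs as a standalone script; inside a Mix project, Axon would be a regular dependency instead.

```elixir
# Standalone script setup (assumption: any recent Axon release)
Mix.install([{:axon, "~> 0.6"}])

defmodule SpanScorerSketch do
  # Hypothetical re-creation of a span scorer from the documented options.
  def build(opts \\ []) do
    hidden_dim = Keyword.get(opts, :hidden_dim, 256)
    width_emb_dim = Keyword.get(opts, :width_emb_dim, 20)
    scorer_hidden = Keyword.get(opts, :scorer_hidden, [256, 128])
    dropout = Keyword.get(opts, :dropout, 0.3)

    # Assumed input width: start state + end state + attended state + width embedding
    span_dim = 3 * hidden_dim + width_emb_dim

    input = Axon.input("span_representation", shape: {nil, span_dim})

    scorer_hidden
    |> Enum.reduce(input, fn units, layer ->
      # One hidden block per entry in :scorer_hidden
      layer
      |> Axon.dense(units, activation: :relu)
      |> Axon.dropout(rate: dropout)
    end)
    # Final projection to a single mention score per span
    |> Axon.dense(1)
  end
end

model = SpanScorerSketch.build(scorer_hidden: [64])
```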

enumerate_and_prune(lstm_outputs, opts \\ [])

@spec enumerate_and_prune(
  Nx.Tensor.t(),
  keyword()
) :: {:ok, [span()]}

Enumerate all possible spans and prune to top-K.

Parameters

  • lstm_outputs - LSTM hidden states [seq_len, hidden_dim]
  • opts - Keyword list of options (see Options)

Options

  • :max_length - Maximum span length in tokens (default: 10)
  • :top_k - Number of spans to keep per sentence (default: 50)
  • :scorer_model - Trained span scorer model (optional)
  • :scorer_params - Scorer parameters (optional)

Returns

  • {:ok, spans} - List of top-K scored spans

enumerate_spans(lstm_outputs, max_length)

@spec enumerate_spans(Nx.Tensor.t(), pos_integer()) :: [
  {non_neg_integer(), non_neg_integer()}
]

Enumerate all spans up to max_length.

Returns a list of span index pairs: [{start, end}, ...]
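
The enumeration can be illustrated in plain Elixir. The sketch below makes several assumptions: indices are zero-based, the end index is inclusive (as in span_representation/4), and max_length counts tokens. It takes the sequence length directly instead of a tensor, and EnumerateSketch is a hypothetical name, not part of Nasty. The count/2 helper gives the closed-form number of candidates, which shows why pruning is needed afterwards.

```elixir
defmodule EnumerateSketch do
  # All {start, end} pairs covering at most max_length tokens (end inclusive).
  def enumerate(seq_len, max_length) when seq_len > 0 and max_length > 0 do
    for start_idx <- 0..(seq_len - 1),
        end_idx <- start_idx..min(start_idx + max_length - 1, seq_len - 1) do
      {start_idx, end_idx}
    end
  end

  # Closed form: with w = min(seq_len, max_length) there are
  # seq_len - k + 1 spans of width k, summed over k = 1..w.
  def count(seq_len, max_length) when seq_len > 0 and max_length > 0 do
    w = min(seq_len, max_length)
    w * seq_len - div(w * (w - 1), 2)
  end
end

IO.inspect(EnumerateSketch.enumerate(3, 2))
# A 100-token document with max_length 10
IO.inspect(EnumerateSketch.count(100, 10))
```

Under these assumptions a 100-token input with max_length 10 already yields 955 candidate spans.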

span_representation(lstm_outputs, start_idx, end_idx, width_embeddings \\ nil)

@spec span_representation(
  Nx.Tensor.t(),
  non_neg_integer(),
  non_neg_integer(),
  Nx.Tensor.t() | nil
) :: Nx.Tensor.t()

Compute span representation from LSTM states.

The representation is the concatenation of:

  • Start state
  • End state
  • Attention-weighted average over span
  • Span width embedding

Parameters

  • lstm_outputs - LSTM hidden states [seq_len, hidden_dim]
  • start_idx - Start index
  • end_idx - End index (inclusive)
  • width_embeddings - Optional width embedding tensor [max_width, width_dim]

Returns

  • Span representation tensor [span_dim]
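
The concatenation can be illustrated with plain lists instead of Nx tensors. This is a sketch under assumptions, not the function's implementation: SpanRepSketch is a hypothetical name, the per-token attention logits are passed in as given (how the real module obtains them is not documented here), and row 0 of the width table is taken to mean a span of width 1.

```elixir
defmodule SpanRepSketch do
  # states:      one vector (list of floats) per token
  # attn_logits: one unnormalized attention score per token (assumed given)
  # width_table: one vector per span width, row 0 = width 1 (assumption)
  def represent(states, attn_logits, width_table, start_idx, end_idx)
      when start_idx <= end_idx do
    span_states = Enum.slice(states, start_idx..end_idx)
    weights = attn_logits |> Enum.slice(start_idx..end_idx) |> softmax()

    # Attention-weighted average of the token states inside the span
    attended =
      span_states
      |> Enum.zip(weights)
      |> Enum.map(fn {vec, w} -> Enum.map(vec, &(&1 * w)) end)
      |> Enum.zip_with(&Enum.sum/1)

    width_vec = Enum.at(width_table, end_idx - start_idx)

    # [start state | end state | attended state | width embedding]
    List.first(span_states) ++ List.last(span_states) ++ attended ++ width_vec
  end

  # Numerically stable softmax over a list of floats
  defp softmax(logits) do
    peak = Enum.max(logits)
    exps = Enum.map(logits, &:math.exp(&1 - peak))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

states = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
logits = [0.0, 0.0, 0.0]
width_table = [[0.1], [0.2], [0.3]]

rep = SpanRepSketch.represent(states, logits, width_table, 0, 1)
IO.inspect(rep)
```

With a hidden size of 2 and a width embedding of size 1, the result has 3 * 2 + 1 = 7 entries, matching the four concatenated parts.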