Nasty.Semantic.Coreference.Neural.E2ETrainer (Nasty v0.3.0)


Training pipeline for end-to-end span-based coreference resolution.

Trains the model with joint optimization of:

  1. Span detection (mention vs non-mention)
  2. Pairwise coreference (coreferent vs not)

loss = span_loss_weight * span_loss + coref_loss_weight * coref_loss

Includes early stopping based on the CoNLL F1 score on the dev set.
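The weighted loss and patience-based early stopping described above can be sketched in plain Elixir. This is an illustrative sketch only; the module and function names below are hypothetical, not part of E2ETrainer's internals, and the defaults mirror the :span_loss_weight, :coref_loss_weight, and :patience options:

```elixir
# Hypothetical sketch; not E2ETrainer's actual implementation.
defmodule TrainingSketch do
  # Joint loss: weighted sum of the span and coreference losses
  # (defaults mirror :span_loss_weight 0.3 and :coref_loss_weight 0.7).
  def combined_loss(span_loss, coref_loss, span_weight \\ 0.3, coref_weight \\ 0.7) do
    span_weight * span_loss + coref_weight * coref_loss
  end

  # Early stopping: f1_history is newest-first; stop once the most
  # recent `patience` epochs have all failed to beat the best earlier
  # dev-set F1 (defaults mirror :patience 3).
  def stop_early?(f1_history, patience \\ 3) do
    {recent, earlier} = Enum.split(f1_history, patience)

    length(recent) == patience and earlier != [] and
      Enum.max(recent) <= Enum.max(earlier)
  end
end
```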

Example

# Train model
{:ok, models, params, history} = E2ETrainer.train(
  train_data,
  dev_data,
  vocab,
  epochs: 25,
  batch_size: 16,
  learning_rate: 0.0005
)

# Save models
E2ETrainer.save_models(models, params, vocab, "priv/models/en/e2e_coref")

Summary

Functions

load_models(base_path)
  Load trained models from disk.

save_models(models, params, vocab, base_path)
  Save trained models to disk.

train(train_data, dev_data, vocab, opts \\ [])
  Train end-to-end coreference model.

Types

models()

@type models() :: %{
  encoder: Axon.t(),
  span_scorer: Axon.t(),
  pair_scorer: Axon.t(),
  width_embeddings: Axon.t(),
  config: map()
}

params()

@type params() :: %{
  encoder: map(),
  span_scorer: map(),
  pair_scorer: map(),
  width_embeddings: map()
}

Functions

load_models(base_path)

@spec load_models(Path.t()) :: {:ok, models(), params(), map()} | {:error, term()}

Load trained models from disk.

Parameters

  • base_path - Base path where models were saved

Returns

  • {:ok, models, params, vocab} - Loaded model structures, parameters, and vocabulary
  • {:error, reason} - Load error
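Going by the spec above, a load call might be handled like this (the path is illustrative, matching the save example earlier in this page):

```elixir
# Load models previously written by save_models/4; the path is illustrative.
case E2ETrainer.load_models("priv/models/en/e2e_coref") do
  {:ok, models, params, vocab} ->
    # models and params are ready for inference; vocab is the vocabulary map.
    {models, params, vocab}

  {:error, reason} ->
    raise "failed to load coreference models: #{inspect(reason)}"
end
```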

save_models(models, params, vocab, base_path)

@spec save_models(models(), params(), map(), Path.t()) :: :ok | {:error, term()}

Save trained models to disk.

Parameters

  • models - Model structures
  • params - Model parameters
  • vocab - Vocabulary map
  • base_path - Base path for saving (directory will be created)

Returns

  • :ok - Success
  • {:error, reason} - Save error

train(train_data, dev_data, vocab, opts \\ [])

@spec train([map()], [map()], map(), keyword()) ::
  {:ok, models(), params(), map()} | {:error, term()}

Train end-to-end coreference model.

Parameters

  • train_data - Training data (spans + labels)
  • dev_data - Development data (for early stopping)
  • vocab - Vocabulary map
  • opts - Training options

Options

  • :epochs - Number of epochs (default: 25)
  • :batch_size - Batch size (default: 16)
  • :learning_rate - Learning rate (default: 0.0005)
  • :hidden_dim - LSTM hidden dimension (default: 256)
  • :dropout - Dropout rate (default: 0.3)
  • :patience - Early stopping patience (default: 3)
  • :span_loss_weight - Weight for span loss (default: 0.3)
  • :coref_loss_weight - Weight for coref loss (default: 0.7)
  • :max_span_width - Maximum span width (default: 10)
  • :top_k_spans - Keep top K spans per sentence (default: 50)

Returns

  • {:ok, models, params, history} - Trained models and history
  • {:error, reason} - Training error
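The :max_span_width and :top_k_spans options bound the candidate-span search. A minimal sketch of that enumeration and pruning in plain Elixir, with hypothetical function names (not E2ETrainer's API) and defaults mirroring the options above:

```elixir
# Hypothetical sketch of candidate-span handling; not E2ETrainer's API.
defmodule SpanSketch do
  # Enumerate every span of at most `max_width` tokens in a sentence
  # of `n` tokens, as 0-indexed {start, stop} pairs (inclusive).
  # Default mirrors :max_span_width 10.
  def candidate_spans(n, max_width \\ 10) do
    for start <- 0..(n - 1),
        stop <- start..min(start + max_width - 1, n - 1),
        do: {start, stop}
  end

  # Keep only the `top_k` highest-scoring spans per sentence,
  # as the :top_k_spans option does (default 50).
  def prune(scored_spans, top_k \\ 50) do
    scored_spans
    |> Enum.sort_by(fn {_span, score} -> score end, :desc)
    |> Enum.take(top_k)
  end
end
```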