Nasty.Semantic.Coreference.Neural.E2ETrainer (Nasty v0.3.0)
Training pipeline for end-to-end span-based coreference resolution.
Trains the model with joint optimization of:
- Span detection (mention vs non-mention)
- Pairwise coreference (coreferent vs not)
Loss = span_weight * span_loss + coref_weight * coref_loss
Includes early stopping based on the CoNLL F1 score on the dev set.
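The joint objective above can be sketched numerically. This is an illustrative stand-in, not the trainer's internals; the module and function names are hypothetical, and only the weights and the formula come from these docs.

```elixir
# Hypothetical sketch of the joint loss. `span_loss` and `coref_loss`
# stand in for the per-batch loss terms the trainer computes.
defmodule JointLossSketch do
  # Weighted sum of the two loss terms, using the documented defaults
  # (:span_loss_weight 0.3, :coref_loss_weight 0.7).
  def combine(span_loss, coref_loss, opts \\ []) do
    span_w = Keyword.get(opts, :span_loss_weight, 0.3)
    coref_w = Keyword.get(opts, :coref_loss_weight, 0.7)
    span_w * span_loss + coref_w * coref_loss
  end
end
```

With the defaults, a span loss of 0.8 and a coref loss of 0.2 combine to 0.3 * 0.8 + 0.7 * 0.2 = 0.38, so the coreference term dominates the gradient signal.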
Example
    # Train model
    {:ok, models, params, history} = E2ETrainer.train(
      train_data,
      dev_data,
      vocab,
      epochs: 25,
      batch_size: 16,
      learning_rate: 0.0005
    )

    # Save models
    E2ETrainer.save_models(models, params, vocab, "priv/models/en/e2e_coref")
Summary
Functions
Load trained models from disk.
Save trained models to disk.
Train end-to-end coreference model.
Types
Functions
Load trained models from disk.
Parameters
base_path - Base path where models were saved
Returns
{:ok, models, params, vocab} - Loaded models
{:error, reason} - Load error
Save trained models to disk.
Parameters
models - Model structures
params - Model parameters
vocab - Vocabulary map
base_path - Base path for saving (directory will be created)
Returns
:ok - Success
{:error, reason} - Save error
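As a minimal sketch of how such save/load round-tripping can work in Elixir, the snippet below uses plain Erlang term serialization. This is NOT Nasty's actual on-disk format; the module name and layout are assumptions for illustration only.

```elixir
# Illustrative term-based persistence (hypothetical, not Nasty's format).
defmodule PersistenceSketch do
  # Serialize any term to disk, creating the parent directory if needed.
  def save(term, path) do
    File.mkdir_p!(Path.dirname(path))
    File.write!(path, :erlang.term_to_binary(term))
    :ok
  end

  # Read the term back; returns {:error, reason} if the file is missing.
  def load(path) do
    case File.read(path) do
      {:ok, bin} -> {:ok, :erlang.binary_to_term(bin)}
      {:error, reason} -> {:error, reason}
    end
  end
end
```

The `{:ok, ...}` / `{:error, reason}` shapes mirror the return conventions documented for load_models and save_models above.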
@spec train([map()], [map()], map(), keyword()) :: {:ok, models(), params(), map()} | {:error, term()}
Train end-to-end coreference model.
Parameters
train_data - Training data (spans + labels)
dev_data - Development data (for early stopping)
vocab - Vocabulary map
opts - Training options
Options
:epochs - Number of epochs (default: 25)
:batch_size - Batch size (default: 16)
:learning_rate - Learning rate (default: 0.0005)
:hidden_dim - LSTM hidden dimension (default: 256)
:dropout - Dropout rate (default: 0.3)
:patience - Early stopping patience (default: 3)
:span_loss_weight - Weight for span loss (default: 0.3)
:coref_loss_weight - Weight for coref loss (default: 0.7)
:max_span_width - Maximum span width (default: 10)
:top_k_spans - Keep top K spans per sentence (default: 50)
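To illustrate how :patience could drive early stopping on the dev CoNLL F1 score, here is a small sketch of one epoch-level decision step. The module and function are hypothetical; the trainer's actual loop is not shown in these docs.

```elixir
# Illustrative early-stopping decision (hypothetical, not the real trainer).
defmodule EarlyStoppingSketch do
  # Takes this epoch's dev F1, the best F1 so far, how many epochs have
  # passed without improvement, and the patience limit.
  # Returns {:continue | :stop, best_f1, epochs_since_improvement}.
  def step(dev_f1, best_f1, stale, patience) do
    cond do
      # New best dev score: keep training and reset the counter.
      dev_f1 > best_f1 -> {:continue, dev_f1, 0}
      # No improvement for `patience` consecutive epochs: stop.
      stale + 1 >= patience -> {:stop, best_f1, stale + 1}
      # No improvement yet, but still within patience.
      true -> {:continue, best_f1, stale + 1}
    end
  end
end
```

With the default patience of 3, training stops after three consecutive epochs without a new best dev F1.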
Returns
{:ok, models, params, history} - Trained models and history
{:error, reason} - Training error
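The defaults listed under Options can be combined with caller overrides using the standard Keyword.merge/2 pattern. This is one common way an Elixir library applies such defaults; whether E2ETrainer does exactly this internally is an assumption.

```elixir
# The documented defaults for train/4, as a keyword list.
defaults = [
  epochs: 25,
  batch_size: 16,
  learning_rate: 0.0005,
  hidden_dim: 256,
  dropout: 0.3,
  patience: 3,
  span_loss_weight: 0.3,
  coref_loss_weight: 0.7,
  max_span_width: 10,
  top_k_spans: 50
]

# Caller-supplied options win; everything else keeps its default.
opts = Keyword.merge(defaults, epochs: 10, learning_rate: 1.0e-3)
```

Here `opts` carries `epochs: 10` from the caller while `batch_size: 16` and the other defaults pass through unchanged.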