Nasty.Semantic.Coreference.Neural.Trainer (Nasty v0.3.0)
Training pipeline for neural coreference resolution.
Trains mention encoder and pair scorer models end-to-end using binary cross-entropy loss with early stopping on dev set.
Example
# Train models
{:ok, models, params, history} =
  Trainer.train(
    training_data,
    dev_data,
    vocab,
    epochs: 20,
    batch_size: 32,
    learning_rate: 0.001
  )
# Save trained models
Trainer.save_models(models, params, vocab, "priv/models/en/coref_neural")
Summary
Functions
Evaluate models on dataset.
Load trained models from disk.
Save trained models to disk.
Train neural coreference models.
Types
@type history() :: %{
        train_loss: [float()],
        train_acc: [float()],
        dev_loss: [float()],
        dev_acc: [float()],
        best_epoch: pos_integer()
      }
@type training_data() :: [
        {Nasty.AST.Semantic.Mention.t(), Nasty.AST.Semantic.Mention.t(), 0 | 1}
      ]
Functions
Evaluate models on dataset.
Parameters
models - Trained models
params - Model parameters
data - Evaluation data
vocab - Vocabulary map
Returns
Map with loss and accuracy
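A usage sketch, assuming an `evaluate/4` whose arity matches the parameter list above; the `:loss` and `:accuracy` map keys are inferred from the return description, not confirmed by the source:

```elixir
# Hypothetical keys :loss and :accuracy, inferred from "Map with loss and accuracy".
metrics = Trainer.evaluate(models, params, dev_data, vocab)
IO.puts("dev loss: #{metrics.loss}, dev accuracy: #{metrics.accuracy}")
```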
Load trained models from disk.
Parameters
base_path- Base path without extension
Returns
{:ok, models, params, vocab} - Loaded models
{:error, reason} - Load error
Save trained models to disk.
Parameters
models - Models to save
params - Model parameters
vocab - Vocabulary
base_path - Base path without extension
Example
Trainer.save_models(models, params, vocab, "priv/models/en/coref")
# Creates:
# priv/models/en/coref_encoder.axon
# priv/models/en/coref_scorer.axon
# priv/models/en/coref_vocab.etf
Train neural coreference models.
@spec train(training_data(), training_data(), map(), keyword()) ::
        {:ok, models(), params(), history()} | {:error, term()}
Parameters
training_data - List of {mention1, mention2, label} tuples
dev_data - Development set for early stopping
vocab - Vocabulary map
opts - Training options
Options
:epochs - Number of training epochs (default: 20)
:batch_size - Batch size (default: 32)
:learning_rate - Learning rate (default: 0.001)
:hidden_dim - LSTM hidden dimension (default: 128)
:dropout - Dropout rate (default: 0.3)
:patience - Early stopping patience (default: 3)
:clip_norm - Gradient clipping norm (default: 5.0)
Returns
{:ok, models, params, history} - Trained models and history
{:error, reason} - Training error
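The options above can be combined in one call; a sketch that trains with explicit early-stopping and gradient-clipping settings, then inspects the returned history (field names taken from the history() type):

```elixir
{:ok, models, params, history} =
  Trainer.train(
    training_data,
    dev_data,
    vocab,
    epochs: 20,
    patience: 3,      # stop if dev loss fails to improve for 3 epochs
    clip_norm: 5.0    # clip gradient norm to stabilize LSTM training
  )

# history.best_epoch marks the epoch whose dev metrics were best.
IO.puts("best epoch: #{history.best_epoch}")
IO.puts("final dev accuracy: #{List.last(history.dev_acc)}")
```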