feat(core): use dropout

This commit is contained in:
Elijah Potter 2025-07-01 15:08:34 +00:00
parent ae765a5816
commit 31d1988ea0

View file

@ -1,6 +1,7 @@
use crate::{UPOS, chunker::Chunker};
use burn::backend::Autodiff;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Dropout, DropoutConfig};
use burn::optim::{GradientsParams, Optimizer};
use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
use burn::tensor::TensorData;
@ -24,6 +25,7 @@ struct NpModel<B: Backend> {
embedding: burn::nn::Embedding<B>,
lstm: burn::nn::BiLstm<B>,
linear: burn::nn::Linear<B>,
dropout: Dropout,
}
impl<B: Backend> NpModel<B> {
@ -31,13 +33,21 @@ impl<B: Backend> NpModel<B> {
Self {
embedding: EmbeddingConfig::new(vocab, embed_dim).init(device),
lstm: BiLstmConfig::new(embed_dim, embed_dim, false).init(device),
// Multiply by two because the BiLSTM concatenates its forward and backward outputs, doubling the feature width
linear: LinearConfig::new(embed_dim * 2, 1).init(device),
dropout: DropoutConfig::new(0.5).init(),
}
}
fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 2> {
let x = self.embedding.forward(input);
let (x, _) = self.lstm.forward(x, None);
/// Scores a batch of token-id sequences.
///
/// The ids are embedded, passed through the bidirectional LSTM, and projected by
/// the linear layer; the trailing axis (index 2) of the projection is then
/// squeezed away so the result is a rank-2 tensor.
///
/// When `use_dropout` is `true`, the model's dropout layer is applied twice: once
/// to the embeddings and once to the LSTM output. When it is `false`, both
/// dropout steps are skipped and the activations pass through untouched.
fn forward(&self, input: Tensor<B, 2, Int>, use_dropout: bool) -> Tensor<B, 2> {
    let embedded = self.embedding.forward(input);
    let embedded = if use_dropout {
        self.dropout.forward(embedded)
    } else {
        embedded
    };

    // The LSTM's returned state is not needed here; only its output sequence is used.
    let (encoded, _) = self.lstm.forward(embedded, None);
    let encoded = if use_dropout {
        self.dropout.forward(encoded)
    } else {
        encoded
    };

    self.linear.forward(encoded).squeeze::<2>(2)
}
@ -145,7 +155,7 @@ impl<B: Backend + AutodiffBackend> BurnChunker<B> {
let x_tensor = util.to_tensor(x);
let y_tensor = util.to_label(y);
let logits = model.forward(x_tensor);
let logits = model.forward(x_tensor, true);
total_correct += logits
.to_data()
.iter()
@ -206,7 +216,7 @@ impl<B: Backend + AutodiffBackend> BurnChunker<B> {
for (x, y) in sents.iter().zip(labs.iter()) {
let x_tensor = self.to_tensor(x);
let logits = model.forward(x_tensor);
let logits = model.forward(x_tensor, false);
total_correct += logits
.to_data()
.iter()
@ -319,7 +329,7 @@ impl<B: Backend + AutodiffBackend> Chunker for BurnChunker<B> {
}
let tensor = self.to_tensor(sentence);
let prob = self.model.forward(tensor);
let prob = self.model.forward(tensor, false);
prob.to_data().iter().map(|p: f32| p > 0.5).collect()
}
}