Mirror of https://github.com/Automattic/harper.git, synced 2025-08-04 18:48:02 +00:00
feat(core): use dropout
commit 31d1988ea0
parent ae765a5816

1 changed file with 16 additions and 6 deletions
|
@@ -1,6 +1,7 @@
 use crate::{UPOS, chunker::Chunker};
 use burn::backend::Autodiff;
 use burn::nn::loss::{MseLoss, Reduction};
+use burn::nn::{Dropout, DropoutConfig};
 use burn::optim::{GradientsParams, Optimizer};
 use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
 use burn::tensor::TensorData;
@@ -24,6 +25,7 @@ struct NpModel<B: Backend> {
     embedding: burn::nn::Embedding<B>,
     lstm: burn::nn::BiLstm<B>,
     linear: burn::nn::Linear<B>,
+    dropout: Dropout,
 }

 impl<B: Backend> NpModel<B> {
@@ -31,13 +33,21 @@ impl<B: Backend> NpModel<B> {
         Self {
             embedding: EmbeddingConfig::new(vocab, embed_dim).init(device),
             lstm: BiLstmConfig::new(embed_dim, embed_dim, false).init(device),
             // Multiply by two because the BiLSTM emits double the hidden parameters
             linear: LinearConfig::new(embed_dim * 2, 1).init(device),
+            dropout: DropoutConfig::new(0.5).init(),
         }
     }

-    fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 2> {
-        let x = self.embedding.forward(input);
-        let (x, _) = self.lstm.forward(x, None);
+    fn forward(&self, input: Tensor<B, 2, Int>, use_dropout: bool) -> Tensor<B, 2> {
+        let mut x = self.embedding.forward(input);
+        if use_dropout {
+            x = self.dropout.forward(x);
+        }
+        let (mut x, _) = self.lstm.forward(x, None);
+        if use_dropout {
+            x = self.dropout.forward(x);
+        }
         let x = self.linear.forward(x);
         x.squeeze::<2>(2)
     }
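
For context on what the new dropout field does: Burn's Dropout zeroes each element with the configured probability p and rescales the survivors by 1 / (1 - p), so the expected activation magnitude is unchanged; in recent Burn versions it is also a no-op on backends without autodiff enabled. A minimal standalone sketch (assuming Burn's ndarray backend feature; this is illustrative, not code from this commit):

use burn::backend::{Autodiff, NdArray};
use burn::nn::DropoutConfig;
use burn::tensor::{Distribution, Tensor};

fn main() {
    let device = Default::default();
    // Same configuration as the commit: drop each element with p = 0.5.
    let dropout = DropoutConfig::new(0.5).init();

    // An autodiff backend stands in for the training setup; the shape is arbitrary.
    let x = Tensor::<Autodiff<NdArray>, 2>::random(
        [4, 8],
        Distribution::Uniform(0.0, 1.0),
        &device,
    );

    // Roughly half the entries come out zeroed; survivors are scaled by
    // 1 / (1 - 0.5) = 2 so the expected value of each activation is preserved.
    println!("{}", dropout.forward(x));
}

The explicit use_dropout flag in forward makes the train/eval distinction visible at each call site rather than relying on backend state alone.
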
@@ -145,7 +155,7 @@ impl<B: Backend + AutodiffBackend> BurnChunker<B> {
             let x_tensor = util.to_tensor(x);
             let y_tensor = util.to_label(y);

-            let logits = model.forward(x_tensor);
+            let logits = model.forward(x_tensor, true);
             total_correct += logits
                 .to_data()
                 .iter()
@@ -206,7 +216,7 @@ impl<B: Backend + AutodiffBackend> BurnChunker<B> {
         for (x, y) in sents.iter().zip(labs.iter()) {
             let x_tensor = self.to_tensor(x);

-            let logits = model.forward(x_tensor);
+            let logits = model.forward(x_tensor, false);
             total_correct += logits
                 .to_data()
                 .iter()
@@ -319,7 +329,7 @@ impl<B: Backend + AutodiffBackend> Chunker for BurnChunker<B> {
         }

         let tensor = self.to_tensor(sentence);
-        let prob = self.model.forward(tensor);
+        let prob = self.model.forward(tensor, false);
         prob.to_data().iter().map(|p: f32| p > 0.5).collect()
     }
 }
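
Taken together, the three updated call sites follow the standard train/eval split: the training loop passes true so dropout regularizes the fit, while validation and the Chunker implementation pass false for deterministic predictions. A condensed, self-contained sketch of the same pattern (hypothetical TinyModel, not the commit's NpModel; again assuming the ndarray feature):

use burn::backend::{Autodiff, NdArray};
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::tensor::{backend::Backend, Distribution, Tensor};

// Miniature of the pattern above: the caller, not backend state,
// decides whether dropout is active.
struct TinyModel<B: Backend> {
    linear: Linear<B>,
    dropout: Dropout,
}

impl<B: Backend> TinyModel<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            linear: LinearConfig::new(8, 1).init(device),
            dropout: DropoutConfig::new(0.5).init(),
        }
    }

    fn forward(&self, input: Tensor<B, 2>, use_dropout: bool) -> Tensor<B, 2> {
        let mut x = input;
        if use_dropout {
            x = self.dropout.forward(x);
        }
        self.linear.forward(x)
    }
}

fn main() {
    type B = Autodiff<NdArray>;
    let device = Default::default();
    let model = TinyModel::<B>::new(&device);
    let x = Tensor::<B, 2>::random([4, 8], Distribution::Uniform(0.0, 1.0), &device);

    let train_out = model.forward(x.clone(), true); // training: dropout active
    let eval_out = model.forward(x, false);         // validation/inference: full network
    println!("{train_out}\n{eval_out}");
}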