diff --git a/crates/compiler/derive/src/lib.rs b/crates/compiler/derive/src/lib.rs index b947a677ed..4d9bf1ce50 100644 --- a/crates/compiler/derive/src/lib.rs +++ b/crates/compiler/derive/src/lib.rs @@ -3,6 +3,8 @@ use std::iter::once; use std::sync::{Arc, Mutex}; +use roc_can::abilities::SpecializationLambdaSets; +use roc_can::expr::Expr; use roc_can::pattern::Pattern; use roc_can::{def::Def, module::ExposedByModule}; use roc_collections::MutMap; @@ -31,7 +33,7 @@ pub fn synth_var(subs: &mut Subs, content: Content) -> Variable { /// Map of [`DeriveKey`]s to their derived symbols. #[derive(Debug, Default)] pub struct DerivedModule { - map: MutMap, + map: MutMap, subs: Subs, derived_ident_ids: IdentIds, @@ -45,14 +47,25 @@ pub struct StolenFromDerived { pub ident_ids: IdentIds, } +pub(crate) struct DerivedBody { + pub body: Expr, + pub body_type: Variable, + /// mapping of lambda set region -> the specialization lambda set for this derived body + pub specialization_lambda_sets: SpecializationLambdaSets, +} + fn build_derived_body( derived_subs: &mut Subs, derived_ident_ids: &mut IdentIds, exposed_by_module: &ExposedByModule, derived_symbol: Symbol, derive_key: DeriveKey, -) -> Def { - let (body, var) = match derive_key { +) -> (Def, SpecializationLambdaSets) { + let DerivedBody { + body, + body_type, + specialization_lambda_sets, + } = match derive_key { DeriveKey::ToEncoder(to_encoder_key) => { let mut env = encoding::Env { subs: derived_subs, @@ -64,13 +77,15 @@ fn build_derived_body( DeriveKey::Decoding => todo!(), }; - Def { + let def = Def { loc_pattern: Loc::at_zero(Pattern::Identifier(derived_symbol)), loc_expr: Loc::at_zero(body), - expr_var: var, - pattern_vars: once((derived_symbol, var)).collect(), + expr_var: body_type, + pattern_vars: once((derived_symbol, body_type)).collect(), annotation: None, - } + }; + + (def, specialization_lambda_sets) } impl DerivedModule { @@ -78,7 +93,7 @@ impl DerivedModule { &mut self, exposed_by_module: &ExposedByModule, key: DeriveKey, - ) -> &(Symbol, Def) { + ) -> &(Symbol, Def, SpecializationLambdaSets) { #[cfg(debug_assertions)] { debug_assert!(!self.stolen, "attempting to add to stolen symbols!"); @@ -106,7 +121,7 @@ impl DerivedModule { }; let derived_symbol = Symbol::new(DERIVED_MODULE, ident_id); - let derived_def = build_derived_body( + let (derived_def, specialization_lsets) = build_derived_body( &mut self.subs, &mut self.derived_ident_ids, exposed_by_module, @@ -114,11 +129,13 @@ impl DerivedModule { key.clone(), ); - (derived_symbol, derived_def) + (derived_symbol, derived_def, specialization_lsets) }) } - pub fn iter_all(&self) -> impl Iterator { + pub fn iter_all( + &self, + ) -> impl Iterator { self.map.iter() }